Introduction to ControlNet and Its Diffusers Implementation

April 30, 2024

ControlNet received a best paper award at ICCV 2023 (original paper: Adding Conditional Control to Text-to-Image Diffusion Models). Its purpose is to add conditional control to diffusion-based image generation. Unlike the usual text prompt, ControlNet supports fine-grained control in pixel space. The three examples in the figure below control generation with Canny edges, human pose, and depth respectively; in each group the top-left image is the input condition and the rest are diffusion results. The control is clearly effective, image quality shows no obvious degradation, and the results look coherent overall.
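
For a quick sense of how this is used in practice, here is a minimal inference sketch with Diffusers (the checkpoint names and the edge-map path are illustrative assumptions, not taken from the original post):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# load a Canny-edge ControlNet and attach it to a Stable Diffusion pipeline
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# the condition image is an edge map extracted beforehand (e.g. with cv2.Canny)
canny_image = load_image("canny_edges.png")
image = pipe("a bird", image=canny_image, controlnet_conditioning_scale=1.0).images[0]
image.save("bird_canny.png")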

Algorithm Overview

ControlNet itself is a model fine-tuning method. As shown on the right of the figure below, a trainable copy of the original model is added. The trainable copy takes the original input x plus the condition c, and its output is added to the output of the original (frozen) model. Both the input and the output of the trainable copy pass through a zero convolution, which keeps the model stable at the very start of training (the zero convolution's parameters are initialized to 0, so the influence of the fine-tuned branch grows gradually from zero).
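
A minimal PyTorch sketch of this wiring (a simplified illustration that assumes x and c share the same channel count; this is not the diffusers implementation):

import copy
import torch.nn as nn

def zero_conv(channels):
    # 1x1 convolution with weights and bias initialized to zero
    conv = nn.Conv2d(channels, channels, kernel_size=1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv

class ControlledBlock(nn.Module):
    """A frozen block plus a trainable copy gated by two zero convolutions."""
    def __init__(self, block: nn.Module, channels: int):
        super().__init__()
        self.trainable_copy = copy.deepcopy(block)  # trainable copy, same initialization
        self.frozen = block.requires_grad_(False)   # original weights stay locked
        self.zero_in = zero_conv(channels)          # applied to the condition c
        self.zero_out = zero_conv(channels)         # applied to the copy's output

    def forward(self, x, c):
        y = self.frozen(x)
        # at initialization both zero convs output 0, so y equals the original model's output
        return y + self.zero_out(self.trainable_copy(x + self.zero_in(c)))

# usage: block = ControlledBlock(nn.Conv2d(64, 64, 3, padding=1), channels=64); y = block(x, c)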

The concrete ControlNet structure for Stable Diffusion is shown in the figure below. Only the UNet's encoder blocks and middle block are copied (structure and weights). The condition image c first passes through a few convolution layers, is added to the original UNet input z_t, and the sum is fed into the ControlNet. The output of each ControlNet block is then added to the skip connection feeding the corresponding UNet decoder block. In the actual code, the ControlNet outputs can also be multiplied by a scale factor that controls how strongly they influence the result.

As a fine-tuning method, the training loss with ControlNet added remains the same as the original diffusion loss:
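
In the paper's notation, with z_t the noisy latent at timestep t, c_t the text prompt, and c_f the task-specific condition (e.g. an edge map):

\mathcal{L} = \mathbb{E}_{z_0,\, t,\, c_t,\, c_f,\, \epsilon \sim \mathcal{N}(0,1)} \left[ \left\| \epsilon - \epsilon_\theta(z_t, t, c_t, c_f) \right\|_2^2 \right]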

The complete Diffusion + ControlNet pipeline is shown in the figure below.

Diffusers Implementation

The ControlNet implementation in Hugging Face Diffusers lives mainly in three places:

  • ControlNetModel
  • UNet2DConditionModel
  • StableDiffusionControlNetPipeline

There are also ControlNet training example scripts: train_controlnet.py and train_controlnet_sdxl.py.

ControlNetModel

First, the ControlNet output: it contains a list of down-block residuals and a mid-block residual, which are added into the original UNet:

class ControlNetOutput(BaseOutput):
    # residuals from each ControlNet down block, added to the corresponding UNet skip connections
    down_block_res_samples: Tuple[torch.Tensor]
    # residual from the ControlNet mid block, added to the UNet mid-block output
    mid_block_res_sample: torch.Tensor

Next, the processing of the input condition image for ControlNet:

class ControlNetConditioningEmbedding(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        # zero convolution on the output
        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding

As you can see, it is mainly a stack of Conv2d layers that turns the input image into a feature map, with a zero convolution appended at the end via zero_module:

def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

Next comes the initialization of ControlNetModel, which mainly contains the logic for copying structure and weights from the UNet, plus the zero convolutions appended after each block:

@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    controlnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
    load_weights_from_unet: bool = True,
    conditioning_channels: int = 3,
):
    # read some config values from the unet
    transformer_layers_per_block = (
        unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
    )
    encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
    ...

    # create the controlnet instance
    controlnet = cls(...)

    # copy weights from the unet
    if load_weights_from_unet:
        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
        controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
        ...

        # copy the down_blocks weights
        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())

        # copy the mid_block weights
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

    return controlnet

def __init__(
    self,
    ...
):
    super().__init__()
    ...

    # input
    conv_in_kernel = 3
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(
        in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
    )

    # initialize time and class embeddings
    ...

    # initialize the condition-image embedding
    self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=block_out_channels[0],
        block_out_channels=conditioning_embedding_out_channels,
        conditioning_channels=conditioning_channels,
    )

    self.down_blocks = nn.ModuleList([])
    self.controlnet_down_blocks = nn.ModuleList([])

    # down
    output_channel = block_out_channels[0]

    # zero convolution appended after the down block
    controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
    controlnet_block = zero_module(controlnet_block)
    self.controlnet_down_blocks.append(controlnet_block)

    self.down_blocks = ...

    # mid
    mid_block_channel = block_out_channels[-1]

    # zero convolution appended after the mid block
    controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
    controlnet_block = zero_module(controlnet_block)
    self.controlnet_mid_block = controlnet_block

    self.mid_block = ...

Finally, the forward implementation. The down and mid block logic in the middle is identical to the UNet; the main difference is the extra zero convolutions at the end, whose outputs are multiplied by the scale factor:

def forward(
    self,
    ...
) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
    # 1. time
    timesteps = timestep
    emb = ...

    # 2. pre-process
    sample = self.conv_in(sample)

    # extract a feature map from the input condition image and add it to sample
    controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
    sample = sample + controlnet_cond

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
        down_block_res_samples += res_samples

    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(sample, emb)

    # 5. Control net blocks
    # feed the down and mid outputs above through the zero convolutions
    controlnet_down_block_res_samples = ()
    for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
        down_block_res_sample = controlnet_block(down_block_res_sample)
        controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

    down_block_res_samples = controlnet_down_block_res_samples
    mid_block_res_sample = self.controlnet_mid_block(sample)

    # 6. scaling
    mid_block_res_sample = mid_block_res_sample * conditioning_scale
    ...

    if not return_dict:
        return (down_block_res_samples, mid_block_res_sample)

    return ControlNetOutput(
        down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
    )

UNet2DConditionModel

For the UNet, the relevant part is how forward handles the ControlNet outputs:

def forward(
    self,
    ...
    # ControlNet residual inputs
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    ...
) -> Union[UNet2DConditionOutput, Tuple]:
    ...

    # check whether ControlNet is being used
    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    ...

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
        down_block_res_samples += res_samples

    # down-block output handling: add each down_block_additional_residual to the corresponding down_block_res_sample
    if is_controlnet:
        new_down_block_res_samples = ()

        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = new_down_block_res_samples

    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(sample, emb)

    # mid-block output handling: sample = sample + mid_block_additional_residual
    if is_controlnet:
        sample = sample + mid_block_additional_residual

    # 5. up
    ...

StableDiffusionControlNetPipeline

In the pipeline, the main change is that the denoising loop now calls the controlnet:

for i, t in enumerate(timesteps):
    ... 

    # latents and text embeddings fed to the controlnet (in guess mode, only the conditional batch is used)
    if guess_mode and self.do_classifier_free_guidance:
        # Infer ControlNet only for the conditional batch.
        control_model_input = latents
        control_model_input = self.scheduler.scale_model_input(control_model_input, t)
        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
    else:
        control_model_input = latent_model_input
        controlnet_prompt_embeds = prompt_embeds

    # handle the controlnet conditioning scale
    if isinstance(controlnet_keep[i], list):
        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
    else:
        controlnet_cond_scale = controlnet_conditioning_scale
        if isinstance(controlnet_cond_scale, list):
            controlnet_cond_scale = controlnet_cond_scale[0]
        cond_scale = controlnet_cond_scale * controlnet_keep[i]

    # run the controlnet to get down_block_res_samples and mid_block_res_sample
    down_block_res_samples, mid_block_res_sample = self.controlnet(
        control_model_input,
        t,
        encoder_hidden_states=controlnet_prompt_embeds,
        controlnet_cond=image,
        conditioning_scale=cond_scale,
        guess_mode=guess_mode,
        return_dict=False,
    )

    # pass the controlnet outputs to the unet (down_block_res_samples, mid_block_res_sample)
    noise_pred = self.unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
        timestep_cond=timestep_cond,
        cross_attention_kwargs=self.cross_attention_kwargs,
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )[0]

    ...

ControlNet Training

The diffusers codebase ships complete ControlNet training examples for both SD 1.5 and SDXL; the two differ only slightly. Let's look at the SD 1.5 training implementation first.

First, the initialization of the individual modules:

# initialize the tokenizer, noise_scheduler, text_encoder, vae, and unet
tokenizer = AutoTokenizer.from_pretrained(...)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(...)
vae = AutoencoderKL.from_pretrained(...)
unet = UNet2DConditionModel.from_pretrained(...)

# initialize the controlnet, either loaded from a pretrained checkpoint or copied from the unet
if args.controlnet_model_name_or_path:
    logger.info("Loading existing controlnet weights")
    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
else:
    logger.info("Initializing controlnet weights from unet")
    controlnet = ControlNetModel.from_unet(unet)

# only the controlnet is trained
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
controlnet.train()

Then the optimizer, dataloader, and lr_scheduler are created:

# Optimizer creation
params_to_optimize = controlnet.parameters()
optimizer = optimizer_class(...)

train_dataset = make_train_dataset(args, tokenizer, accelerator)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    ...
)

lr_scheduler = get_scheduler(...)

# Prepare everything with our `accelerator`.
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    controlnet, optimizer, train_dataloader, lr_scheduler
)

Finally, the diffusion training loop with noise injection:

for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate(controlnet):
        # encode the target image into latents
        latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # sample random noise
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # sample a random timestep t for each example
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # add noise to the latents via the scheduler
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # text embeddings
        encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]

        # the input condition image
        controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

        # run the controlnet to produce down_block_res_samples and mid_block_res_sample
        down_block_res_samples, mid_block_res_sample = controlnet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_image,
            return_dict=False,
        )

        # run the unet to predict the noise (passing down_block_additional_residuals and mid_block_additional_residual)
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=[
                sample.to(dtype=weight_dtype) for sample in down_block_res_samples
            ],
            mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
            return_dict=False,
        )[0]

        # compute the loss
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # backpropagate and update the weights
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=args.set_grads_to_none)

Note that a fraction of the prompts fed to ControlNet are replaced with empty strings during training; this helps the model exploit the information in the control condition (e.g. edges, depth) more effectively when no prompt is given:

def tokenize_captions(examples, is_train=True):
    captions = []
    for caption in examples[caption_column]:
        if random.random() < args.proportion_empty_prompts:
            captions.append("")  # replace with an empty prompt
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            raise ValueError(
                f"Caption column `{caption_column}` should contain either strings or lists of strings."
            )
    inputs = tokenizer(
        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    return inputs.input_ids

A few flags can also be set to reduce GPU memory usage:

  • --mixed_precision: enable mixed-precision training (fp16 or bf16)
  • --use_8bit_adam: use the 8-bit Adam optimizer
  • --enable_xformers_memory_efficient_attention: enable xformers memory-efficient attention

SDXL ControlNet training is broadly similar to SD 1.5; the main difference is the additional embeddings SDXL introduces, which are assembled in compute_embeddings below:

  • pooled text embedding: overall semantics of the prompt
  • original size, target size: resolution information
  • crop top-left coord: cropping parameters

def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):
    original_size = (args.resolution, args.resolution)
    target_size = (args.resolution, args.resolution)
    crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
    prompt_batch = batch[args.caption_column]

    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
    )
    add_text_embeds = pooled_prompt_embeds

    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids])

    prompt_embeds = prompt_embeds.to(accelerator.device)
    add_text_embeds = add_text_embeds.to(accelerator.device)
    add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
    add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}