April 30, 2024
ControlNet won a Best Paper award at ICCV 2023 (paper: Adding Conditional Control to Text-to-Image Diffusion Models). Its goal is to add conditional control to diffusion-based image generation. Unlike the usual text prompt, ControlNet supports fine-grained control in pixel space. The three examples in the figure below use Canny edges, human pose, and depth, respectively, to control generation; the top-left image in each group is the input condition and the rest are the diffusion outputs. The control is clearly effective, image quality shows no obvious degradation, and the results are coherent.
Algorithm Overview
ControlNet is essentially a model fine-tuning method. As shown on the right of the figure below, it adds a trainable copy on top of the original model. The trainable copy takes the original input x plus the condition c, and the outputs of the two models are added together. Both the input and the output of the trainable copy pass through a zero convolution, which keeps the model stable at the start of training: the zero convolution's parameters are initialized to 0, so the influence of the fine-tuning grows gradually from zero.
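As a minimal sketch (illustrative names like `ControlledBlock` and `zero_conv`, not the diffusers code): wrap a frozen block with a trainable copy whose condition input and output both pass through zero-initialized 1×1 convolutions, so at initialization the wrapped block behaves exactly like the original.

```python
import copy

import torch
import torch.nn as nn

def zero_conv(channels: int) -> nn.Conv2d:
    """1x1 convolution with weights and bias initialized to zero."""
    conv = nn.Conv2d(channels, channels, kernel_size=1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv

class ControlledBlock(nn.Module):
    """Hypothetical wrapper: frozen block + trainable copy + zero convs."""

    def __init__(self, block: nn.Module, channels: int):
        super().__init__()
        self.frozen = block                  # original weights, kept frozen
        self.copy = copy.deepcopy(block)     # trainable copy
        self.zero_in = zero_conv(channels)   # zero conv on the condition input
        self.zero_out = zero_conv(channels)  # zero conv on the copy's output
        for p in self.frozen.parameters():
            p.requires_grad_(False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        control = self.zero_out(self.copy(x + self.zero_in(cond)))
        return self.frozen(x) + control      # control == 0 at initialization

block = nn.Conv2d(4, 4, kernel_size=3, padding=1)
wrapped = ControlledBlock(block, channels=4)
x, cond = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
# at init the zero convs silence the copy, so the output equals the frozen path
assert torch.allclose(wrapped(x, cond), block(x))
```

As training progresses, gradients flow through the zero convolutions, their weights move away from zero, and the control signal is blended in gradually.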
The concrete ControlNet architecture for Stable Diffusion is shown below. Only the UNet's encoder blocks and middle block are copied (both structure and weights). The condition image c first passes through a few convolution layers, is added to the original UNet input z_t, and fed into the copy; the output of each ControlNet block is then added to the input of the corresponding UNet decoder block. In the actual code, each ControlNet output can additionally be multiplied by a scale factor to control how strongly the condition is applied.
Since ControlNet is just a fine-tuning method, the training loss with ControlNet added remains the same as the original diffusion loss:
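In the notation of the ControlNet paper (z_t the noised latent, c_t the text prompt, c_f the task-specific condition image), the objective is the standard noise-prediction loss:

```latex
\mathcal{L} = \mathbb{E}_{z_0,\, t,\, c_t,\, c_f,\, \epsilon \sim \mathcal{N}(0,1)}
\Big[ \big\| \epsilon - \epsilon_\theta\big(z_t,\, t,\, c_t,\, c_f\big) \big\|_2^2 \Big]
```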
The complete Diffusion + ControlNet pipeline looks like this:
Diffusers Implementation
The ControlNet implementation in Hugging Face Diffusers lives in three main places:
- ControlNetModel: the ControlNet architecture, i.e. the init and forward implementations
- UNet2DConditionModel: where the ControlNet outputs are added into the UNet
- StableDiffusionControlNetPipeline: the inference pipeline wiring
There are also training example scripts for ControlNet: train_controlnet.py and train_controlnet_sdxl.py.
ControlNetModel
First, the ControlNet output: a list of down-block residuals plus a mid-block residual, to be added into the original UNet.
class ControlNetOutput(BaseOutput):
    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor
Next, the preprocessing of the ControlNet condition image:
class ControlNetConditioningEmbedding(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
        self.blocks = nn.ModuleList([])
        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
        # zero convolution on the output
        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)
        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)
        embedding = self.conv_out(embedding)
        return embedding
As shown, this is mostly a stack of Conv2d layers that turns the condition image into a feature map, with a final zero convolution added via zero_module.
def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
Next comes the initialization of ControlNetModel, which mainly copies structure and weights from the UNet and appends a zero convolution after each block:
@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    controlnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
    load_weights_from_unet: bool = True,
    conditioning_channels: int = 3,
):
    # read some config values from the unet
    transformer_layers_per_block = (
        unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
    )
    encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
    ...
    # create the controlnet instance
    controlnet = cls(...)
    # copy weights from the unet
    if load_weights_from_unet:
        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
        controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
        ...
        # copy the down_blocks weights
        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
        # copy the mid_block weights
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
    return controlnet
def __init__(
    self,
    ...
):
    super().__init__()
    ...
    # input
    conv_in_kernel = 3
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(
        in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
    )
    # initialize time / class embeddings
    ...
    # initialize the condition-image embedding
    self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=block_out_channels[0],
        block_out_channels=conditioning_embedding_out_channels,
        conditioning_channels=conditioning_channels,
    )
    self.down_blocks = nn.ModuleList([])
    self.controlnet_down_blocks = nn.ModuleList([])
    # down
    output_channel = block_out_channels[0]
    # zero convolution appended after the down block
    controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
    controlnet_block = zero_module(controlnet_block)
    self.controlnet_down_blocks.append(controlnet_block)
    self.down_blocks = ...
    # mid
    mid_block_channel = block_out_channels[-1]
    # zero convolution appended after the mid block
    controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
    controlnet_block = zero_module(controlnet_block)
    self.controlnet_mid_block = controlnet_block
    self.mid_block = ...
Finally, the forward implementation. The down and mid blocks behave the same as in the UNet; the differences are the trailing zero convolutions and the multiplication by the conditioning scale:
def forward(
    self,
    ...
) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
    # 1. time
    timesteps = timestep
    emb = ...
    # 2. pre-process
    sample = self.conv_in(sample)
    # extract a feature map from the condition image and add it to sample
    controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
    sample = sample + controlnet_cond
    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
        down_block_res_samples += res_samples
    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(sample, emb)
    # 5. Control net blocks
    # feed the down / mid outputs through the zero convolutions
    controlnet_down_block_res_samples = ()
    for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
        down_block_res_sample = controlnet_block(down_block_res_sample)
        controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
    down_block_res_samples = controlnet_down_block_res_samples
    mid_block_res_sample = self.controlnet_mid_block(sample)
    # 6. scaling
    mid_block_res_sample = mid_block_res_sample * conditioning_scale
    ...
    if not return_dict:
        return (down_block_res_samples, mid_block_res_sample)
    return ControlNetOutput(
        down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
    )
UNet2DConditionModel
On the UNet side, the ControlNet-related handling lives in forward, where the ControlNet outputs are merged in:
def forward(
    self,
    ...
    # residuals produced by ControlNet
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    ...
) -> Union[UNet2DConditionOutput, Tuple]:
    ...
    # check whether ControlNet is in use
    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    ...
    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
        down_block_res_samples += res_samples
    # merge the down outputs: down_block_res_samples = down_block_res_samples + down_block_additional_residuals
    if is_controlnet:
        new_down_block_res_samples = ()
        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
        down_block_res_samples = new_down_block_res_samples
    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(sample, emb)
    # merge the mid output: sample = sample + mid_block_additional_residual
    if is_controlnet:
        sample = sample + mid_block_additional_residual
    # 5. up
    ...
StableDiffusionControlNetPipeline
In the pipeline, the main change is that each denoising step additionally calls the controlnet:
for i, t in enumerate(timesteps):
    ...
    # text embeddings fed to the controlnet
    if guess_mode and self.do_classifier_free_guidance:
        # Infer ControlNet only for the conditional batch.
        control_model_input = latents
        control_model_input = self.scheduler.scale_model_input(control_model_input, t)
        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
    else:
        control_model_input = latent_model_input
        controlnet_prompt_embeds = prompt_embeds
    # resolve the controlnet conditioning scale
    if isinstance(controlnet_keep[i], list):
        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
    else:
        controlnet_cond_scale = controlnet_conditioning_scale
        if isinstance(controlnet_cond_scale, list):
            controlnet_cond_scale = controlnet_cond_scale[0]
        cond_scale = controlnet_cond_scale * controlnet_keep[i]
    # run the controlnet to get down_block_res_samples and mid_block_res_sample
    down_block_res_samples, mid_block_res_sample = self.controlnet(
        control_model_input,
        t,
        encoder_hidden_states=controlnet_prompt_embeds,
        controlnet_cond=image,
        conditioning_scale=cond_scale,
        guess_mode=guess_mode,
        return_dict=False,
    )
    # pass the controlnet outputs to the unet (down_block_res_samples, mid_block_res_sample)
    noise_pred = self.unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
        timestep_cond=timestep_cond,
        cross_attention_kwargs=self.cross_attention_kwargs,
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )[0]
    ...
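One detail worth calling out: controlnet_keep implements the pipeline's control_guidance_start / control_guidance_end options, zeroing the conditioning scale for denoising steps outside the guidance window. A simplified sketch of the bookkeeping for the single-controlnet case (the variable names follow the pipeline, but this is a paraphrase, not the exact code):

```python
# Hypothetical setup: 10 denoising steps, apply control only during the
# first half of the schedule.
num_steps = 10
control_guidance_start, control_guidance_end = 0.0, 0.5

controlnet_keep = []
for i in range(num_steps):
    # keep == 1.0 while step i falls inside [start, end), else 0.0
    keep = 1.0 - float(i / num_steps < control_guidance_start
                       or (i + 1) / num_steps > control_guidance_end)
    controlnet_keep.append(keep)

controlnet_conditioning_scale = 0.8
cond_scales = [controlnet_conditioning_scale * k for k in controlnet_keep]
assert cond_scales == [0.8] * 5 + [0.0] * 5  # control active only for steps 0-4
```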
ControlNet Training
The diffusers repository ships complete ControlNet training examples for both SD1.5 and SDXL; the two differ very little. Let's look at the SD1.5 training first.
First, the individual modules are initialized:
# initialize tokenizer, noise_scheduler, text_encoder, vae, unet
tokenizer = AutoTokenizer.from_pretrained(...)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(...)
vae = AutoencoderKL.from_pretrained(...)
unet = UNet2DConditionModel.from_pretrained(...)

# initialize the controlnet: load pretrained weights, or copy them from the unet
if args.controlnet_model_name_or_path:
    logger.info("Loading existing controlnet weights")
    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
else:
    logger.info("Initializing controlnet weights from unet")
    controlnet = ControlNetModel.from_unet(unet)

# only the controlnet is trained
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
controlnet.train()
Next, the Optimizer, DataLoader, and lr_scheduler are set up:
# Optimizer creation
params_to_optimize = controlnet.parameters()
optimizer = optimizer_class(...)

train_dataset = make_train_dataset(args, tokenizer, accelerator)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    ...
)
lr_scheduler = get_scheduler(...)

# Prepare everything with our `accelerator`.
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    controlnet, optimizer, train_dataloader, lr_scheduler
)
Finally, the noise-and-denoise diffusion training loop:
for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate(controlnet):
        # encode the target image into latents
        latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
        # sample random noise
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # sample a random timestep t
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()
        # add noise via the scheduler
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        # text embeddings
        encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
        # condition image
        controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
        # run the controlnet to get down_block_res_samples and mid_block_res_sample
        down_block_res_samples, mid_block_res_sample = controlnet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_image,
            return_dict=False,
        )
        # run the unet to predict the noise (passing down_block_additional_residuals
        # and mid_block_additional_residual)
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=[
                sample.to(dtype=weight_dtype) for sample in down_block_res_samples
            ],
            mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
            return_dict=False,
        )[0]
        # compute the loss
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        # backpropagate and update the weights
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=args.set_grads_to_none)
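For v_prediction schedulers, the training target above is the "velocity" v_t = sqrt(ᾱ_t)·ε − sqrt(1 − ᾱ_t)·x₀ rather than the raw noise. A simplified sketch mirroring DDPMScheduler.get_velocity (assuming a precomputed alphas_cumprod tensor; the real method also handles devices and dtypes):

```python
import torch

def get_velocity(sample: torch.Tensor, noise: torch.Tensor,
                 alphas_cumprod: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    # v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0
    a = alphas_cumprod[timesteps].sqrt()
    s = (1.0 - alphas_cumprod[timesteps]).sqrt()
    while a.dim() < sample.dim():  # broadcast (B,) over (B, C, H, W)
        a = a.unsqueeze(-1)
        s = s.unsqueeze(-1)
    return a * noise - s * sample

# sanity check: with alpha_bar = 1 (no noise added yet) the velocity is the noise itself
x0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0)
v = get_velocity(x0, eps, torch.tensor([1.0, 0.5]), torch.tensor([0, 0]))
assert torch.allclose(v, eps)
```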
Note that a fraction of the prompts fed to the model during training is set to empty. This encourages the model to mine the information in the control condition itself (edges, depth, etc.), so that it still works well when no prompt is given:
def tokenize_captions(examples, is_train=True):
    captions = []
    for caption in examples[caption_column]:
        if random.random() < args.proportion_empty_prompts:
            captions.append("")  # drop the caption
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            raise ValueError(
                f"Caption column `{caption_column}` should contain either strings or lists of strings."
            )
    inputs = tokenizer(
        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    return inputs.input_ids
A few flags can also be set to reduce GPU memory usage:
- --mixed_precision: enable mixed-precision training (fp16 or bf16)
- --use_8bit_adam: use the 8-bit Adam optimizer
- --enable_xformers_memory_efficient_attention: enable xformers memory-efficient attention
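Putting these together, a launch command might look like the following (a sketch based on the diffusers train_controlnet.py example; model and dataset names are illustrative):

```shell
accelerate launch train_controlnet.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --dataset_name="fusing/fill50k" \
  --output_dir="controlnet-out" \
  --resolution=512 \
  --learning_rate=1e-5 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --enable_xformers_memory_efficient_attention
```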
SDXL ControlNet training is broadly the same as SD1.5; the main difference is the extra embeddings SDXL introduces:
- pooled text embedding: the overall semantics of the prompt
- original size / target size: scale information
- crop top-left coordinates: cropping parameters
def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):
    original_size = (args.resolution, args.resolution)
    target_size = (args.resolution, args.resolution)
    crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
    prompt_batch = batch[args.caption_column]

    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
    )
    add_text_embeds = pooled_prompt_embeds

    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids])

    prompt_embeds = prompt_embeds.to(accelerator.device)
    add_text_embeds = add_text_embeds.to(accelerator.device)
    add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
    add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}