StableDiffusion工程理解

通用controlnet网络结构

SD1.5 Controlnet Connection.svg SD1.5 Controlnet Connection.svg

unet行为像全卷积网络

主要在CrossAttnDownBlock2D, CrossAttnUpBlock2D这些模块中,这些模块主要包含了ResnetBlock2D和Transformer2DModel,以下为这两个模块的中间数据流

ResnetBlock2D

这部分temb是如何加到hidden_state中的呢:

...
hidden_states.shape
# torch.Size([2, 320, 64, 64])
temb.shape
# torch.Size([2, 1280])
temb = self.time_emb_proj(temb)[:, :, None, None]
# torch.Size([2, 320, 1, 1])
hidden_states = hidden_states + temb

相当于给每个特征像素加上了这个时间step,所以看着像是全卷积网络,中间网络没有任何的flatten操作。

Transformer2DModel

这部分也好理解,主要是在通道这一层进行操作的(attention操作),虽然尺寸部分被拉直了,但是长度并没有变化,所以最终尺寸仍然不变。

hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
# 随后执行cross attention.

推理框架梳理

主要包含了几个关键的模块:

模块解释具体实现方式
pasd.unet_2d_conditionpasd中修改后的unet结构这里主要是对于unet_2d_block增加了一个额外的输入pixelwise_hidden_states,原始的controlnet和unet upblock连接方式保持不变
pasd.unet_2d_blockpasd中unet实际修改部分具体实现方式是在UpBlock中增加了一层Transformer2DModel,新增的Transformers2DModel是在模块最后的,也即对于UpBlock2D就是在resnet block之后,对于CrossAttnUpBlock2D就是在text_embedding的Transformers2DModel之后,又加了一个Transformers2DModel用于处理文本cross attention之后的hidden_states
pasd.controlnet.ControlNetModelpasd中的controlnet目前看只是增加了一个return_rgbs以及use_rrdb具体用途还未可知。rrdb相关的代码是from basicsr.archs.rrdbnet_arch import RRDB
好像结构上的比较关键的信息也就这么多了,并且实际看下来,对于SD原始pipeline的改动不大,理论上来讲模型转换工程难度不大。

Qualcomm SD1.5框架中的代码理解

SD1.5框架是基于diffusers 0.10.2搞的

关于AdaRounding

AdaRounding全称是Adaptive Rounding,详细的阅读和解析见[[Adaptive Rounding(AdaRound模型后量化算法)]]。

当前内存/显存在模型转换阶段占用比较严重的就是这个部分,问题主要在于需要保存大量的中间变量,需要从实现上优化这部分。

模型结构转换

通常mha表示的是multi-head attention,sha表示的是single-head attention,但这里表示的只是attention实现的方式,而不是指有多少头,两种表达指向的是相同的运算(多头注意力)。只不过mha的多头是通过权重的reshape来实现了,而sha则是通过多个同级的conv2d来实现了,从mha转换到sha的实现为:

def replace_linear_to_convs(self):
	query_dim = self.query_dim
	cross_attention_dim = self.cross_attention_dim
	heads = self.heads
	dim_head = self.dim_head
	bias = self.bias

	self.to_q_convs = nn.ModuleList([nn.Conv2d(query_dim, dim_head, 1, bias=bias) for _ in range(heads)])
	self.to_k_convs = nn.ModuleList([nn.Conv2d(cross_attention_dim, dim_head, 1, bias=bias) for _ in range(heads)])
	self.to_v_convs = nn.ModuleList([nn.Conv2d(cross_attention_dim, dim_head, 1, bias=bias) for _ in range(heads)])

	# copy weights
	for ndx in range(self.heads):
		# with torch.no_grad():
		self.to_q_convs[ndx].weight.data.copy_(self.to_q.weight[ndx*dim_head:(ndx+1)*dim_head, :, None, None])
		self.to_k_convs[ndx].weight.data.copy_(self.to_k.weight[ndx*dim_head:(ndx+1)*dim_head, :, None, None])
		self.to_v_convs[ndx].weight.data.copy_(self.to_v.weight[ndx*dim_head:(ndx+1)*dim_head, :, None, None])
		if bias:
			self.to_q_convs[ndx].bias.data.copy_(self.to_q.bias[ndx*dim_head:(ndx+1)*dim_head])
			self.to_k_convs[ndx].bias.data.copy_(self.to_k.bias[ndx*dim_head:(ndx+1)*dim_head])
			self.to_v_convs[ndx].bias.data.copy_(self.to_v.bias[ndx*dim_head:(ndx+1)*dim_head])

		self.to_q_convs[ndx].to(self.to_q.weight)
		self.to_k_convs[ndx].to(self.to_k.weight)
		self.to_v_convs[ndx].to(self.to_v.weight)

	self.matmul_1 = nn.ModuleList([elementwise_ops.MatMul() for _ in range(heads)])
	self.softmax_1 = nn.ModuleList([torch.nn.Softmax(-1) for _ in range(heads)])
	self.matmul_2 = nn.ModuleList([elementwise_ops.MatMul() for _ in range(heads)])
	self.concat_1 = elementwise_ops.Concat(-1)
	Transformer2DModel.mha_to_sha = True
	if not hasattr(self, 'forward_mha'):
		self.forward_mha = self.forward
		self.forward = self._forward_single_head_convs
encodings和权重的转换

关于权重的转换可以理解,因为mha拆分成多个1x1的conv2d了,所以权重需要按照head进行拆分,关于encodings,每个权重只有一个(可以用文本编辑器打开encodings文件看看)

权重encodings的例子:

{
	...,
	"down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight": [
		{
			"bitwidth": 8,
			"dtype": "int",
			"is_symmetric": "True",
			"max": 0.3505791425704956,
			"min": -0.3533396082600271,
			"offset": -128,
			"scale": 0.0027604656895314616
		}
	],
	"down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight": [
		{
			"bitwidth": 8,
			"dtype": "int",
			"is_symmetric": "True",
			"max": 0.18362492322921753,
			"min": -0.18507078876645547,
			"offset": -128,
			"scale": 0.0014458655372379333
		}
	],
	"down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight": [
		{
			"bitwidth": 8,
			"dtype": "int",
			"is_symmetric": "True",
			"max": 0.9385758638381958,
			"min": -0.94596622497078,
			"offset": -128,
			"scale": 0.007390361132584219
		}
	],
	...
}

replace encodings的实现:

def replace_encodings(source_dir, num_heads=8):
    with open(os.path.join(source_dir, "parameters_mha.encodings"), "rt") as f:
        encodings = json.load(f)

    new_dict = {}
    remove_keys = []
    for key, value in encodings.items():
        for qkv in ["q", "k", "v"]:
            if key.endswith(f"to_{qkv}.weight"):
                remove_keys.append(key)

                for i in range(num_heads):
                    new_key = key.replace(f"to_{qkv}.weight", f"to_{qkv}_convs.{i}.weight")
                    new_dict[new_key] = value

    encodings.update(new_dict)

    with open(os.path.join(source_dir, "parameters_sha.encodings"), "w") as f:
        json.dump(encodings, f, indent=4)

replace权重的实现:

def replace_tensors(source_dir, num_heads=8, dim_head=40):
    state_dict = torch.load(os.path.join(source_dir, "state_dict_mha.pth"))
    new_dict = collections.OrderedDict()
    remove_keys = []
    for key, value in state_dict.items():
        for qkv in ["q", "k", "v"]:
            if key.endswith(f"to_{qkv}.weight"):
                remove_keys.append(key)
                dim_head = int(value.shape[0] / num_heads)
                splited = split_tensors(value, num_heads, dim_head)

                for i, sp in enumerate(splited):
                    new_key = key.replace(f"to_{qkv}.weight", f"to_{qkv}_convs.{i}.weight")
                    new_dict[new_key] = sp

        if key.endswith("to_q.bias"):
            print('unexpected')
            raise Exception(f"Unexpected key {key} in state_dict")

    state_dict.update(new_dict)

    if not os.path.exists(source_dir):
        os.makedirs(source_dir)

    torch.save(state_dict, os.path.join(source_dir, "state_dict_sha.pth"))
推理

redefined_modules/diffusers/models/attention.py:CrossAttention._forward_single_head_convs

多个同级的conv2d推理的结果分别执行单头注意力,随后输出通过concat的方式合并成最终的hidden_states。

另外需要注意,测试下来黄彪训练的pasd模型如果control_image是512的话结果会很差,需要至少768.

pasd的unet输入图像范围是[-1, 1], controlnet输入图像范围是[0, 1]

计划

  • 直接基于Qualcomm Controlnet example修改unet_2d_block,增加pixelwise的cross-attention block。
  • 测试修改之后的效果是否正常,直接fp32跑。
  • 开始尝试进行量化和模型转换。
  • Android端bentchmark:分数和测试集实际效果。