Lora原理和训练详细解析

diffusers源码分析

diffusers中给unet加lora是利用peft库,然后通过修改原始的to_q, to_k, to_v, to_out.0等的线性层来做的:

主要原理就是使用矩阵乘法的分配律来做的,原理很简单,可以通过以下代码来理解。

from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
# Configure LoRA for the UNet attention projections.
unet_lora_config = LoraConfig(
	r=4,  # rank of the low-rank A/B factorization
	lora_alpha=4,  # scaling numerator; effective scale = lora_alpha / r
	init_lora_weights="gaussian",  # init scheme for the LoRA weights — presumably gaussian for lora_A; confirm against peft docs
	target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # attention Linear layers to wrap
)

# Then call add_adapter on the unet; peft replaces each target Linear with a lora.Linear wrapper.
unet.add_adapter(unet_lora_config)

修改前后网络结构变化: 修改前:

(attn2): Attention(
	(to_q): Linear(in_features=320, out_features=320, bias=False)
	(to_k): Linear(in_features=768, out_features=320, bias=False)
	(to_v): Linear(in_features=768, out_features=320, bias=False)
	(to_out): ModuleList(
	  (0): Linear(in_features=320, out_features=320, bias=True)
	  (1): Dropout(p=0.0, inplace=False)
	)
  )

修改后:

(attn2): Attention(
	(to_q): lora.Linear(
	  (base_layer): Linear(in_features=320, out_features=320, bias=False)
	  (lora_dropout): ModuleDict(
		(default): Identity()
	  )
	  (lora_A): ModuleDict(
		(default): Linear(in_features=320, out_features=4, bias=False)
	  )
	  (lora_B): ModuleDict(
		(default): Linear(in_features=4, out_features=320, bias=False)
	  )
	  (lora_embedding_A): ParameterDict()
	  (lora_embedding_B): ParameterDict()
	)
	(to_k): lora.Linear(
	  (base_layer): Linear(in_features=768, out_features=320, bias=False)
	  (lora_dropout): ModuleDict(
		(default): Identity()
	  )
	  (lora_A): ModuleDict(
		(default): Linear(in_features=768, out_features=4, bias=False)
	  )
	  (lora_B): ModuleDict(
		(default): Linear(in_features=4, out_features=320, bias=False)
	  )
	  (lora_embedding_A): ParameterDict()
	  (lora_embedding_B): ParameterDict()
	)
	(to_v): lora.Linear(
	  (base_layer): Linear(in_features=768, out_features=320, bias=False)
	  (lora_dropout): ModuleDict(
		(default): Identity()
	  )
	  (lora_A): ModuleDict(
		(default): Linear(in_features=768, out_features=4, bias=False)
	  )
	  (lora_B): ModuleDict(
		(default): Linear(in_features=4, out_features=320, bias=False)
	  )
	  (lora_embedding_A): ParameterDict()
	  (lora_embedding_B): ParameterDict()
	)
	(to_out): ModuleList(
	  (0): lora.Linear(
		(base_layer): Linear(in_features=320, out_features=320, bias=True)
		(lora_dropout): ModuleDict(
		  (default): Identity()
		)
		(lora_A): ModuleDict(
		  (default): Linear(in_features=320, out_features=4, bias=False)
		)
		(lora_B): ModuleDict(
		  (default): Linear(in_features=4, out_features=320, bias=False)
		)
		(lora_embedding_A): ParameterDict()
		(lora_embedding_B): ParameterDict()
	  )
	  (1): Dropout(p=0.0, inplace=False)
	)
  )

以下是peft库中添加lora的代码:

# From peft/tuners/lora/layer.py:LoraLayer (update_layer excerpt)

# A LoRA decomposition needs rank >= 1.
if r <= 0:
	raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
# Record per-adapter hyperparameters, keyed by adapter name.
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
# Dropout is applied to the input of lora_A; Identity when disabled.
if lora_dropout > 0.0:
	lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
	lora_dropout_layer = nn.Identity()

self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
if r > 0:
	# lora_A projects down (in_features -> r); lora_B projects back up (r -> out_features).
	self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
	self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
	# Scale applied to the low-rank update: alpha / r.
	self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights == "loftq":
	self.loftq_init(adapter_name)
elif init_lora_weights:
	self.reset_lora_parameters(adapter_name, init_lora_weights)
# Move the new adapter modules onto the base layer's device (and dtype, when floating/complex).
weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None:
	# the layer is already completely initialized, this is an update
	if weight.dtype.is_floating_point or weight.dtype.is_complex:
		self.to(weight.device, dtype=weight.dtype)
	else:
		self.to(weight.device)
self.set_adapter(self.active_adapters)

merge lora权重:

# From peft/tuners/lora/layer.py:Linear
# Merging relies on the distributive law of matrix multiplication: (W + B·A)x = Wx + (B·A)x.
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
	"""Fold the delta weights of the given (or active) adapters into the base layer, in place."""
	if self.merged:
		warnings.warn(
			f"Already following adapters were merged {','.join(self.merged_adapters)}. "
			f"You are now additionally merging {','.join(self.active_adapters)}."
		)
	if adapter_names is None:
		adapter_names = self.active_adapters

	for active_adapter in adapter_names:
		# Skip names that have no LoRA weights registered on this layer.
		if active_adapter not in self.lora_A.keys():
			continue
		base_layer = self.get_base_layer()
		delta = self.get_delta_weight(active_adapter)
		if safe_merge:
			# Merge into a copy first (slower), so a bad delta never corrupts the base weights.
			candidate = base_layer.weight.data.clone()
			candidate += delta
			if not torch.isfinite(candidate).all():
				raise ValueError(f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken")
			base_layer.weight.data = candidate
		else:
			base_layer.weight.data += delta
		self.merged_adapters.append(active_adapter)

# From peft/tuners/lora/layer.py:Linear
def get_delta_weight(self, adapter) -> torch.Tensor:
	"""
	Compute the delta weight for the given adapter.

	Args:
		adapter (str):
			The name of the adapter for which the delta weight should be computed.
	"""
	a = self.lora_A[adapter].weight
	b = self.lora_B[adapter].weight
	device = b.device
	dtype = b.dtype

	# `@` / matmul is unsupported for fp16 on CPU, so for fp16-on-CPU weights we
	# promote to float32, compute the product, and cast the result back afterwards.
	needs_fp32 = device.type == "cpu" and dtype == torch.float16
	if needs_fp32:
		a = a.float()
		b = b.float()

	# Low-rank product B @ A (transposed when the base layer stores fan_in_fan_out),
	# scaled by alpha / r.
	delta = transpose(b @ a, self.fan_in_fan_out) * self.scaling[adapter]

	if needs_fp32:
		delta = delta.to(dtype=dtype)
		# cast back the weights
		self.lora_A[adapter].weight.data = a.to(dtype)
		self.lora_B[adapter].weight.data = b.to(dtype)
	return delta

执行forward:

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
	"""Base-layer output plus the scaled low-rank update of every active adapter."""
	previous_dtype = x.dtype

	if self.disable_adapters:
		# Adapters are switched off: undo any merged weights, run the plain base layer.
		if self.merged:
			self.unmerge()
		result = self.base_layer(x, *args, **kwargs)
	elif self.merged:
		# Deltas are already folded into the base weights — a single matmul suffices.
		result = self.base_layer(x, *args, **kwargs)
	else:
		# Distributive law of matrix multiplication: (W + B·A)x = Wx + B(A(x)).
		result = self.base_layer(x, *args, **kwargs)
		for name in self.active_adapters:
			if name not in self.lora_A.keys():
				continue
			down = self.lora_A[name]
			up = self.lora_B[name]
			drop = self.lora_dropout[name]
			x = x.to(down.weight.dtype)
			result += up(down(drop(x))) * self.scaling[name]

	return result.to(previous_dtype)