Detailed Analysis of LoRA Principles and Training
diffusers Source Code Analysis
In diffusers, LoRA is added to the UNet through the peft library, which wraps the original linear layers (to_q, to_k, to_v, to_out.0, and so on) with LoRA layers.

The core principle is simply the distributive law of matrix multiplication: the low-rank update can either be applied as a separate branch, W x + s·B A x (where s = lora_alpha / r is a scaling factor), or folded into the base weight, (W + s·B A) x, and both give the same result. The code below makes this concrete.
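As a quick numeric sanity check of this identity, here is a minimal sketch in plain PyTorch (the shapes mirror the rank-4 to_q layer shown further down):

import torch

torch.manual_seed(0)
W = torch.randn(320, 320)   # frozen base weight of to_q
A = torch.randn(4, 320)     # lora_A: down-projection to rank r = 4
B = torch.randn(320, 4)     # lora_B: up-projection back to 320
s = 4 / 4                   # scaling = lora_alpha / r
x = torch.randn(320)

separate_branch = W @ x + s * (B @ (A @ x))   # LoRA kept as an extra branch
merged_weight = (W + s * (B @ A)) @ x         # LoRA folded into the base weight
print(torch.allclose(separate_branch, merged_weight, atol=1e-5))  # True

In diffusers, enabling this on the UNet only takes a LoraConfig and a call to add_adapter: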
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

# set up the LoRA config
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
# then add the adapter to the UNet
unet.add_adapter(unet_lora_config)

This is how the network structure changes. Before:
(attn2): Attention(
  (to_q): Linear(in_features=320, out_features=320, bias=False)
  (to_k): Linear(in_features=768, out_features=320, bias=False)
  (to_v): Linear(in_features=768, out_features=320, bias=False)
  (to_out): ModuleList(
    (0): Linear(in_features=320, out_features=320, bias=True)
    (1): Dropout(p=0.0, inplace=False)
  )
)

After:
(attn2): Attention(
  (to_q): lora.Linear(
    (base_layer): Linear(in_features=320, out_features=320, bias=False)
    (lora_dropout): ModuleDict(
      (default): Identity()
    )
    (lora_A): ModuleDict(
      (default): Linear(in_features=320, out_features=4, bias=False)
    )
    (lora_B): ModuleDict(
      (default): Linear(in_features=4, out_features=320, bias=False)
    )
    (lora_embedding_A): ParameterDict()
    (lora_embedding_B): ParameterDict()
  )
  (to_k): lora.Linear(
    (base_layer): Linear(in_features=768, out_features=320, bias=False)
    (lora_dropout): ModuleDict(
      (default): Identity()
    )
    (lora_A): ModuleDict(
      (default): Linear(in_features=768, out_features=4, bias=False)
    )
    (lora_B): ModuleDict(
      (default): Linear(in_features=4, out_features=320, bias=False)
    )
    (lora_embedding_A): ParameterDict()
    (lora_embedding_B): ParameterDict()
  )
  (to_v): lora.Linear(
    (base_layer): Linear(in_features=768, out_features=320, bias=False)
    (lora_dropout): ModuleDict(
      (default): Identity()
    )
    (lora_A): ModuleDict(
      (default): Linear(in_features=768, out_features=4, bias=False)
    )
    (lora_B): ModuleDict(
      (default): Linear(in_features=4, out_features=320, bias=False)
    )
    (lora_embedding_A): ParameterDict()
    (lora_embedding_B): ParameterDict()
  )
  (to_out): ModuleList(
    (0): lora.Linear(
      (base_layer): Linear(in_features=320, out_features=320, bias=True)
      (lora_dropout): ModuleDict(
        (default): Identity()
      )
      (lora_A): ModuleDict(
        (default): Linear(in_features=320, out_features=4, bias=False)
      )
      (lora_B): ModuleDict(
        (default): Linear(in_features=4, out_features=320, bias=False)
      )
      (lora_embedding_A): ParameterDict()
      (lora_embedding_B): ParameterDict()
    )
    (1): Dropout(p=0.0, inplace=False)
  )
)
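The printout also makes the parameter savings concrete: only lora_A and lora_B are new, trainable weights. A back-of-the-envelope count for the to_k layer (768 → 320, r = 4), as a sketch:

base_weight  = 768 * 320            # 245,760 parameters in the frozen base Linear
lora_weights = 4 * 768 + 320 * 4    # 4,352 parameters in lora_A + lora_B
print(lora_weights / base_weight)   # ~0.018, i.e. under 2% of the layer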
The following is the code in the peft library that adds the LoRA layers:

# from peft/tuners/lora/layer.py: LoraLayer.update_layer
if r <= 0:
    raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
    lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
    lora_dropout_layer = nn.Identity()
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))

# Actual trainable parameters
if r > 0:
    self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
    self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
    self.scaling[adapter_name] = lora_alpha / r  # the scaling factor s used below

if init_lora_weights == "loftq":
    self.loftq_init(adapter_name)
elif init_lora_weights:
    self.reset_lora_parameters(adapter_name, init_lora_weights)

weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None:
    # the layer is already completely initialized, this is an update
    if weight.dtype.is_floating_point or weight.dtype.is_complex:
        self.to(weight.device, dtype=weight.dtype)
    else:
        self.to(weight.device)
self.set_adapter(self.active_adapters)
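On the training side nothing else changes: in the diffusers LoRA training scripts the base UNet is frozen first, so after add_adapter only the newly injected lora_A / lora_B parameters receive gradients, and only those are saved. A minimal sketch of that pattern (variable names are illustrative):

# freeze the base model, then inject the adapter
unet.requires_grad_(False)
unet.add_adapter(unet_lora_config)

# only LoRA parameters should require gradients now
trainable = [name for name, p in unet.named_parameters() if p.requires_grad]
print(all("lora_" in name for name in trainable))  # True

# extract just the adapter weights for saving
unet_lora_state_dict = get_peft_model_state_dict(unet)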
Merging the LoRA weights into the base weight:

# from peft/tuners/lora/layer.py: Linear
# distributive law of matrix multiplication: W x + s*B A x = (W + s*B A) x
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
    if self.merged:
        warnings.warn(
            f"Already following adapters were merged {','.join(self.merged_adapters)}. "
            f"You are now additionally merging {','.join(self.active_adapters)}."
        )
    if adapter_names is None:
        adapter_names = self.active_adapters
    for active_adapter in adapter_names:
        if active_adapter in self.lora_A.keys():
            base_layer = self.get_base_layer()
            if safe_merge:
                # Note that safe_merge will be slower than the normal merge
                # because of the copy operation.
                orig_weights = base_layer.weight.data.clone()
                orig_weights += self.get_delta_weight(active_adapter)
                if not torch.isfinite(orig_weights).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )
                base_layer.weight.data = orig_weights
            else:
                # add s * (B @ A) directly onto the frozen base weight
                base_layer.weight.data += self.get_delta_weight(active_adapter)
            self.merged_adapters.append(active_adapter)

# from peft/tuners/lora/layer.py: Linear
def get_delta_weight(self, adapter) -> torch.Tensor:
    """
    Compute the delta weight for the given adapter.

    Args:
        adapter (str):
            The name of the adapter for which the delta weight should be computed.
    """
    device = self.lora_B[adapter].weight.device
    dtype = self.lora_B[adapter].weight.dtype

    # In case users wants to merge the adapter weights that are in
    # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
    # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16.
    cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

    weight_A = self.lora_A[adapter].weight
    weight_B = self.lora_B[adapter].weight

    if cast_to_fp32:
        weight_A = weight_A.float()
        weight_B = weight_B.float()

    # @ performs matrix multiplication: delta = s * (B @ A)
    output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]

    if cast_to_fp32:
        output_tensor = output_tensor.to(dtype=dtype)

        # cast back the weights
        self.lora_A[adapter].weight.data = weight_A.to(dtype)
        self.lora_B[adapter].weight.data = weight_B.to(dtype)

    return output_tensor
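For the to_k layer above this works out concretely (a worked example, not taken from the source): weight_B has shape [320, 4] and weight_A has shape [4, 768], so weight_B @ weight_A has shape [320, 768], exactly the shape of the base Linear(768, 320) weight it is added to; and since lora_alpha = r = 4, the scaling factor is 4 / 4 = 1.0, so the delta is applied unscaled.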
Executing the forward pass:

# from peft/tuners/lora/layer.py: Linear
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    previous_dtype = x.dtype

    if self.disable_adapters:
        if self.merged:
            # adapters are disabled: undo any earlier merge so only the base weight is used
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif self.merged:
        # the delta is already folded into base_layer.weight, a single matmul suffices
        result = self.base_layer(x, *args, **kwargs)
    else:
        result = self.base_layer(x, *args, **kwargs)
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            x = x.to(lora_A.weight.dtype)
            # the distributive law at runtime: W x + s * B(A(dropout(x)))
            result += lora_B(lora_A(dropout(x))) * scaling

    result = result.to(previous_dtype)
    return result
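Putting the two paths together, here is a small self-contained check (a sketch using peft's public API on a toy module; the Toy class and variable names are made up for illustration) that the adapter branch and the merged weight give the same output:

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(320, 320, bias=False)

    def forward(self, x):
        return self.to_q(x)

torch.manual_seed(0)
model = get_peft_model(Toy(), LoraConfig(r=4, lora_alpha=4, target_modules=["to_q"]))

# with the default init lora_B starts at zero, so randomize it to make the delta nonzero
for module in model.modules():
    if hasattr(module, "lora_B"):
        nn.init.normal_(module.lora_B["default"].weight)

x = torch.randn(2, 320)
with torch.no_grad():
    out_adapter = model(x)               # base(x) + scaling * B(A(x))
    merged = model.merge_and_unload()    # folds scaling * B @ A into to_q.weight
    out_merged = merged(x)

print(torch.allclose(out_adapter, out_merged, atol=1e-5))  # True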