参数高效微调 (PEFT)#
这个迷你库为 Flax linen 模块添加 LoRA 支持。
LoRA 适配器#
一些 Flax linen nn.Modules
可用于包装现有层
LoRADense
:包装nn.Dense
层。LoRAEinsum
:包装nn.Einsum
层。
在 Flax 模块内部的使用示例
class MyModel(nn.Module):
@nn.compact
def __call__(self, x):
layer = peft.LoRADense(
rank=3,
wrapped=nn.Dense(10),
)
return layer(x)
请注意,每个包装器都有一个关联的底层模块,该模块仅执行 x @ A @ B
矩阵乘法。 例如 peft.LoRADense
-> peft.LoRADenseAdapter
。 在这种情况下,与原始输出的求和必须手动完成。
class MyModel(nn.Module):
@nn.compact
def __call__(self, x):
dense = nn.Dense(10)
lora = peft.LoRADenseAdapter(rank=3) # Only do `x @ A @ B`
return dense(x) + lora(x)
量化#
我们提供了两个新的 API,一个用于应用量化,另一个用于使用量化感知训练和直通估计来训练/优化检查点。
应用#
与 LoRA 相反,我们建议您首先使用模拟训练模型以创建相关的检查点。 然后,量化加载的参数。
params_q = peft.quantize_checkpoint(
params, method=peft.QuantizationMethod.INT4
)
注意:目前仅支持 peft.QuantizationMethod.INT4 方法。
然后,与为 LoRA 适配器引入的内容类似,我们添加了量化模拟包装器
Int4Dense
:包装nn.Dense
层。Int4Einsum
:包装nn.Einsum
层。
class MyModel(nn.Module):
@nn.compact
def __call__(self, x):
layer = peft.Int4Dense(
wrapped=nn.Dense(10),
method=peft.QuantizationMethod.Q4_0
)
return layer(x)
模拟#
与为 LoRA 适配器引入的内容类似,我们添加了量化模拟包装器
SimulateQuantizedDense
:包装nn.Dense
层。SimulateQuantizedEinsum
:包装nn.Einsum
层。
在 Flax 模块内部的使用示例
class MyModel(nn.Module):
@nn.compact
def __call__(self, x):
layer = peft.SimulateQuantizedDense(
wrapped=nn.Dense(10),
method=peft.QuantizationMethod.Q4_0
)
return layer(x)
模型手术#
该库提供了一些实用程序,通过将其包装版本替换模块来帮助进行模型手术。 例如
def _replace_dense_by_lora(module: nn.Module) -> nn.Module:
if isinstance(module, nn.Dense):
return peft.LoRADense(rank=3, wrapped=module)
else:
return module
# Within the context, the dense layers are replaced by their LoRA version.
with ModuleInterceptor(_replace_dense_by_lora):
y = model(x)
关于量化,特别是 Q4_0 的特别说明。 它假设 FFN 的每个第一层的权重都被转置。 这可以通过以下方式实现
def _apply_q4_0_to_dense(module: nn.Module) -> nn.Module:
if isinstance(module, nn.Dense):
if 'gating' in module.name.lower():
return peft.SimulateQuantizedDense(
wrapped=module,
method=peft.QuantizationMethod.Q4_0_TRANSPOSE,
)
else:
return peft.SimulateQuantizedDense(
wrapped=module,
method=peft.QuantizationMethod.Q4_0,
)
else:
return module
# Within the context, the dense layers are replaced by their LoRA version.
with ModuleInterceptor(_apply_q4_0):
y = model(x)
参数手术#
用于参数树结构操作
peft.split_params
:将参数嵌套字典拆分为 2 个树:一个仅包含原始参数,另一个仅包含 LoRA 参数。peft.merge_params
:split_params
的逆操作。 将 2 个树合并为单个树。
params = {
'dense': {
'kernel': 0,
'bias': 1,
'lora': {
'a': 0,
'b': 1,
},
},
'other': 0,
}
original, lora = peft.split_params(params)
assert original == {
'dense': {
'kernel': 0,
'bias': 1,
},
'other': 0,
}
assert lora == {
'dense': {
'lora': {
'a': 0,
'b': 1,
},
},
}
用于融合 LoRA 参数
peft.fuse_params
:将 LoRA 参数融合到原始参数权重中。peft.unfuse_params
:fuse_params
的逆操作,恢复 LoRA 参数。