参数高效微调 (PEFT)

参数高效微调 (PEFT)#

这个迷你库为 Flax linen 模块添加 LoRA 支持。

LoRA 适配器#

一些 Flax linen nn.Modules 可用于包装现有层

在 Flax 模块内部的使用示例

class MyModel(nn.Module):

  @nn.compact
  def __call__(self, x):
    layer = peft.LoRADense(
        rank=3,
        wrapped=nn.Dense(10),
    )
    return layer(x)

请注意,每个包装器都有一个关联的底层模块,该模块仅执行 x @ A @ B 矩阵乘法。 例如 peft.LoRADense -> peft.LoRADenseAdapter。 在这种情况下,与原始输出的求和必须手动完成。

class MyModel(nn.Module):

  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(10)
    lora = peft.LoRADenseAdapter(rank=3)  # Only do `x @ A @ B`
    return dense(x) + lora(x)

量化#

我们提供了两个新的 API,一个用于应用量化,另一个用于使用量化感知训练和直通估计来训练/优化检查点。

应用#

与 LoRA 相反,我们建议您首先使用模拟训练模型以创建相关的检查点。 然后,量化加载的参数。

params_q = peft.quantize_checkpoint(
  params, method=peft.QuantizationMethod.INT4
)

注意:目前仅支持 peft.QuantizationMethod.INT4 方法。

然后,与为 LoRA 适配器引入的内容类似,我们添加了量化模拟包装器

  • Int4Dense:包装 nn.Dense 层。

  • Int4Einsum:包装 nn.Einsum 层。

class MyModel(nn.Module):

  @nn.compact
  def __call__(self, x):
    layer = peft.Int4Dense(
        wrapped=nn.Dense(10),
        method=peft.QuantizationMethod.Q4_0
    )
    return layer(x)

模拟#

与为 LoRA 适配器引入的内容类似,我们添加了量化模拟包装器

在 Flax 模块内部的使用示例

class MyModel(nn.Module):

  @nn.compact
  def __call__(self, x):
    layer = peft.SimulateQuantizedDense(
        wrapped=nn.Dense(10),
        method=peft.QuantizationMethod.Q4_0
    )
    return layer(x)

模型手术#

该库提供了一些实用程序,通过将其包装版本替换模块来帮助进行模型手术。 例如

def _replace_dense_by_lora(module: nn.Module) -> nn.Module:
  if isinstance(module, nn.Dense):
    return peft.LoRADense(rank=3, wrapped=module)
  else:
    return module

# Within the context, the dense layers are replaced by their LoRA version.
with ModuleInterceptor(_replace_dense_by_lora):
  y = model(x)

关于量化,特别是 Q4_0 的特别说明。 它假设 FFN 的每个第一层的权重都被转置。 这可以通过以下方式实现

def _apply_q4_0_to_dense(module: nn.Module) -> nn.Module:
  if isinstance(module, nn.Dense):
    if 'gating' in module.name.lower():
      return peft.SimulateQuantizedDense(
          wrapped=module,
          method=peft.QuantizationMethod.Q4_0_TRANSPOSE,
      )
    else:
      return peft.SimulateQuantizedDense(
          wrapped=module,
          method=peft.QuantizationMethod.Q4_0,
      )
  else:
    return module

# Within the context, the dense layers are replaced by their LoRA version.
with ModuleInterceptor(_apply_q4_0):
  y = model(x)

参数手术#

用于参数树结构操作

params = {
    'dense': {
        'kernel': 0,
        'bias': 1,
        'lora': {
            'a': 0,
            'b': 1,
        },
    },
    'other': 0,
}

original, lora = peft.split_params(params)
assert original == {
    'dense': {
        'kernel': 0,
        'bias': 1,
    },
    'other': 0,
}
assert lora == {
    'dense': {
        'lora': {
            'a': 0,
            'b': 1,
        },
    },
}

用于融合 LoRA 参数