peft.ModuleInterceptor

peft.ModuleInterceptor#

class gemma.peft.ModuleInterceptor(
replace_module_fn: Callable[[flax.linen.module.Module], flax.linen.module.Module],
)[source]

基类: gemma.peft._interceptors.Interceptor

拦截器,用于捕获所有模块并最终替换它们。

对于每个模块,此拦截器调用 replace_module_fn 函数,该函数返回要使用的模块。

示例

def _replace_dense_by_lora(module):
  if isinstance(module, nn.Dense):
    return peft.LoRADense(rank=3, wrapped=module)
  else:
    return module

# Within the context, the dense layers are replaced by their LoRA version.
with ModuleInterceptor(_replace_dense_by_lora):
  y = model(x)
replace_module_fn: Callable[[flax.linen.module.Module], flax.linen.module.Module]
interceptor(
next_fun,
args,
kwargs,
context: flax.linen.module.InterceptorContext,
)[source]

返回要拦截的方法的名称。