peft.LoRAEinsumAdapter#
- class gemma.peft.LoRAEinsumAdapter(*, rank: int, einsum_str: str, shape: collections.abc.Sequence[int], dtype: numpy.dtype = <class 'jax.numpy.float64'>, a_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function variance_scaling.<locals>.init>, b_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function zeros>, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]
基类:
flax.linen.module.Module
LoRA einsum 模块。
此模块仅执行 x @ A @ B 计算。使用
LoRAEinsum
封装 nn.Einsum 层。- 秩
LoRA 分解的秩。
- 类型:
int
- einsum_str
原始 einsum 操作的 einsum 字符串。应为 inputs,weights->outputs (这将在内部重写为 inputs,a,b->outputs)
- 类型:
str
- 形状
低秩适配前原始权重的形状。应与 einsum_str 中的 weights 形状匹配。
- 类型:
collections.abc.Sequence[int]
- dtype
用于 LoRA 权重的 dtype。
- 类型:
numpy.dtype
- a_init
A 矩阵的初始化器。
- 类型:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- b_init
B 矩阵的初始化器。
- 类型:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- rank: int
- einsum_str: str
- shape: collections.abc.Sequence[int]
- dtype
别名
jax.numpy.float64
- a_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) jax.Array
- b_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) jax.Array
返回一个由零填充的常数数组的初始化器。
key
参数被忽略。>>> import jax, jax.numpy as jnp >>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32)
- setup()[source]
延迟初始化模块(类似于延迟
__init__
)。当模块被绑定时,在调用任何其他方法(如
__call__
)之前,或在访问self
上setup
定义的属性之前,setup
会在模块实例上延迟调用一次。这可能发生在三种情况下
当立即调用
apply()
、init()
或init_and_output()
时。一旦模块通过在另一个模块的
setup
方法中分配给另一个模块的属性而被命名(请参阅__setattr__()
)>>> class MyModule(nn.Module): ... def setup(self): ... submodule = nn.Conv(...) ... # Accessing `submodule` attributes does not yet work here. ... # The following line invokes `self.__setattr__`, which gives ... # `submodule` the name "conv1". ... self.conv1 = submodule ... # Accessing `submodule` attributes or methods is now safe and ... # either causes setup() to be called once.
一旦模块在用
compact()
包装的方法内部构建,立即在调用另一个方法或访问setup
定义的属性之前。
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None