peft.LoRAEinsumAdapter

peft.LoRAEinsumAdapter#

class gemma.peft.LoRAEinsumAdapter(*, rank: int, einsum_str: str, shape: collections.abc.Sequence[int], dtype: numpy.dtype = <class 'jax.numpy.float64'>, a_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function variance_scaling.<locals>.init>, b_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function zeros>, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

基类:flax.linen.module.Module

LoRA einsum 模块。

此模块仅执行 x @ A @ B 计算。使用 LoRAEinsum 封装 nn.Einsum 层。

LoRA 分解的秩。

类型:

int

einsum_str

原始 einsum 操作的 einsum 字符串。应为 inputs,weights->outputs (这将在内部重写为 inputs,a,b->outputs)

类型:

str

形状

低秩适配前原始权重的形状。应与 einsum_str 中的 weights 形状匹配。

类型:

collections.abc.Sequence[int]

dtype

用于 LoRA 权重的 dtype。

类型:

numpy.dtype

a_init

A 矩阵的初始化器。

类型:

jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]

b_init

B 矩阵的初始化器。

类型:

jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]

rank: int
einsum_str: str
shape: collections.abc.Sequence[int]
dtype

别名 jax.numpy.float64

a_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) jax.Array
b_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) jax.Array

返回一个由零填充的常数数组的初始化器。

key 参数被忽略。

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
setup()[source]

延迟初始化模块(类似于延迟 __init__)。

当模块被绑定时,在调用任何其他方法(如 __call__)之前,或在访问 selfsetup 定义的属性之前,setup 会在模块实例上延迟调用一次。

这可能发生在三种情况下

  1. 当立即调用 apply()init()init_and_output() 时。

  2. 一旦模块通过在另一个模块的 setup 方法中分配给另一个模块的属性而被命名(请参阅 __setattr__()

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    
    ...     # Accessing `submodule` attributes does not yet work here.
    
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # either causes setup() to be called once.
    
  3. 一旦模块在用 compact() 包装的方法内部构建,立即在调用另一个方法或访问 setup 定义的属性之前。

name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None