peft.LoRADenseAdapter


class gemma.peft.LoRADenseAdapter(*, rank: int, features: int, dtype: numpy.dtype = <class 'jax.numpy.float64'>, a_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function variance_scaling.<locals>.init>, b_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function zeros>, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: flax.linen.module.Module

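LoRA module.

This module only performs the x @ A @ B computation, where A maps the input down to rank dimensions and B maps them up to features output dimensions. It does not wrap a base layer itself; to add LoRA to an existing nn.Dense layer, use LoRADense.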

rank: int
features: int
dtype

alias of jax.numpy.float64

a_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) jax.Array
b_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) jax.Array

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
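Both a_init and b_init follow JAX's initializer convention, init(key, shape, dtype) -> jax.Array. The sketch below calls a variance-scaling initializer and zeros that way and defines a custom initializer with the same signature. The variance_scaling arguments (1.0, "fan_in", "uniform") and the name small_normal_init are illustrative assumptions; the documented default a_init is a variance_scaling initializer whose exact settings are not listed here.

```python
import jax
import jax.numpy as jnp

# Assumption: illustrative settings, not necessarily the default a_init's.
a_init = jax.nn.initializers.variance_scaling(1.0, "fan_in", "uniform")
b_init = jax.nn.initializers.zeros


def small_normal_init(key, shape, dtype=jnp.float32):
  """Hypothetical custom initializer using the same (key, shape, dtype) call."""
  return 1e-3 * jax.random.normal(key, shape, dtype)


in_features, rank, features = 8, 4, 16
a = a_init(jax.random.PRNGKey(0), (in_features, rank), jnp.float32)
b = b_init(jax.random.PRNGKey(1), (rank, features), jnp.float32)
# zeros ignores its key, so a different key yields an identical array.
b_other_key = b_init(jax.random.PRNGKey(99), (rank, features), jnp.float32)
b_alt = small_normal_init(jax.random.PRNGKey(2), (rank, features))
```

Given the field types (Initializer | Callable[..., Any]), any callable with this signature should be usable as a_init or b_init.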
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None