peft.LoRADense
- class gemma.peft.LoRADense(*, rank: int, wrapped: flax.linen.linear.Dense, dtype: numpy.dtype = <class 'jax.numpy.float64'>, a_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function variance_scaling.<locals>.init>, b_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function zeros>, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
- Bases: flax.linen.module.Module
- Wrapper around nn.Dense which adds a LoRA adapter.
 - rank: int
 - wrapped: flax.linen.linear.Dense
 - dtype
- Alias of jax.numpy.float64.
 - a_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) jax.Array
 - b_init(shape: collections.abc.Sequence[int | typing.Any], dtype: typing.Any = <class 'jax.numpy.float64'>) jax.Array
- An initializer that returns a constant array full of zeros. The key argument is ignored.

      >>> import jax, jax.numpy as jnp
      >>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
      Array([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)
 - name: str | None = None
 - parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
 - scope: Scope | None = None
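LoRADense pairs a wrapped nn.Dense with a low-rank adapter configured by rank, a_init and b_init. The sketch below illustrates, in plain JAX, the standard LoRA computation the class is named after: y = x W + b + (x A) B, with A of shape (in_features, rank) and B of shape (rank, out_features). It is not the gemma.peft implementation. The helper names init_lora_params and lora_dense_apply are hypothetical, the factor shapes follow the usual LoRA convention, the exact variance_scaling arguments used for A are an assumption (the signature above only shows that a_init is some variance_scaling initializer), and any scaling factor the real module might apply is omitted. B starts at zero, matching the documented b_init default, so the adapted layer initially reproduces the wrapped layer's output.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration only; not part of gemma.peft.


def init_lora_params(key, in_features, out_features, rank, dtype=jnp.float32):
    """Create stand-in base Dense parameters plus LoRA factors A and B."""
    k_w, k_a, k_b = jax.random.split(key, 3)
    # Stand-in for the wrapped nn.Dense kernel and bias (frozen during LoRA training).
    kernel = jax.nn.initializers.lecun_normal()(k_w, (in_features, out_features), dtype)
    bias = jnp.zeros((out_features,), dtype)
    # Assumption: scale/mode/distribution chosen for illustration; the real
    # a_init default is a variance_scaling initializer with unknown arguments.
    a_init = jax.nn.initializers.variance_scaling(1.0, "fan_in", "uniform")
    a = a_init(k_a, (in_features, rank), dtype)
    # B is all zeros, like the documented b_init default (the key is ignored).
    b = jax.nn.initializers.zeros(k_b, (rank, out_features), dtype)
    return {"kernel": kernel, "bias": bias, "a": a, "b": b}


def lora_dense_apply(params, x):
    """Base Dense output plus the low-rank update (x @ A) @ B."""
    base = x @ params["kernel"] + params["bias"]
    delta = (x @ params["a"]) @ params["b"]
    return base + delta


params = init_lora_params(jax.random.key(0), in_features=8, out_features=4, rank=2)
x = jnp.ones((3, 8))
y = lora_dense_apply(params, x)
base_only = x @ params["kernel"] + params["bias"]

# With B == 0 the adapter contributes nothing, so outputs match exactly.
print(bool(jnp.array_equal(y, base_only)))  # True

# Pretend B has been trained; the update A @ B can then be folded into the kernel.
params_trained = dict(params, b=jnp.full((2, 4), 0.1))
merged_kernel = params_trained["kernel"] + params_trained["a"] @ params_trained["b"]
y_merged = x @ merged_kernel + params_trained["bias"]
print(bool(jnp.allclose(lora_dense_apply(params_trained, x), y_merged, atol=1e-5)))
```

Only A and B are trainable in this scheme, which adds rank * (in_features + out_features) parameters (24 here) instead of retraining the full in_features * out_features kernel (32 here). Zero-initializing B while A is random is what lets fine-tuning start from the wrapped layer's original behavior.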