gm.nn.Transformer


class gemma.gm.nn.Transformer(
config: gemma.transformer.TransformerConfig,
return_last_only: bool | None = None,
dtype: numpy.dtype = jax.numpy.bfloat16,
tokens: kontext.Key = '__KEY_REQUIRED__',
images: kontext.Key | None = None,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)

Bases: gemma.transformer.Transformer

Base Transformer class.

return_last_only

If True, only the logits for the last token are computed and returned. Otherwise, the logits for all tokens are returned. Defaults to False.

Type:

bool | None
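To make the effect of return_last_only concrete, here is a minimal JAX sketch. It is not the gemma implementation: the function name project_logits, the tensor shapes, and the plain matrix-multiply projection are all assumptions chosen for illustration. It only shows why slicing the last position before the output projection yields the same values as computing every logit and then keeping the last one.

```python
import jax
import jax.numpy as jnp


def project_logits(hidden, embedding, return_last_only=False):
    """Hypothetical helper (not part of gemma).

    hidden:    [batch, seq_len, d_model] final hidden states.
    embedding: [vocab_size, d_model] output projection weights.
    """
    if return_last_only:
        # Keep only the last position, so the projection runs on
        # [batch, d_model] instead of [batch, seq_len, d_model].
        hidden = hidden[:, -1, :]
    return hidden @ embedding.T


key_h, key_e = jax.random.split(jax.random.PRNGKey(0))
hidden = jax.random.normal(key_h, (2, 5, 8))      # batch=2, seq_len=5, d_model=8
embedding = jax.random.normal(key_e, (16, 8))     # vocab_size=16

all_logits = project_logits(hidden, embedding)                          # [2, 5, 16]
last_logits = project_logits(hidden, embedding, return_last_only=True)  # [2, 16]
```

During autoregressive sampling only the final position is needed to pick the next token, which is the typical reason to skip the other positions.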

dtype

The parameter dtype. Defaults to jnp.bfloat16.

Type:

numpy.dtype
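As a rough illustration of what the parameter dtype implies for memory, the sketch below compares the storage of one weight matrix in float32 and in bfloat16. The matrix shape is arbitrary and the variable names are illustrative only; nothing here calls into gemma.

```python
import jax.numpy as jnp

# An arbitrary stand-in for one parameter matrix.
weights_f32 = jnp.ones((256, 256), dtype=jnp.float32)
weights_bf16 = weights_f32.astype(jnp.bfloat16)

# Bytes used = number of elements * bytes per element.
bytes_f32 = weights_f32.size * weights_f32.dtype.itemsize
bytes_bf16 = weights_bf16.size * weights_bf16.dtype.itemsize
```

bfloat16 stores each value in 2 bytes instead of 4, so the same parameters take half the memory, at the cost of fewer mantissa bits.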

return_last_only: bool | None = None
dtype

alias of jax.numpy.bfloat16

tokens: kontext.Key = '__KEY_REQUIRED__'
images: kontext.Key | None = None
INFO: ClassVar[ModelInfo] = ModelInfo(tokenizer_version=None, default_ckpt=None)
init_cache(
*,
batch_size: int,
dtype: numpy.dtype[Any],
cache_length: int,
) → dict[str, dict[str, jax.Array]]
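The return type above is a two-level dictionary of arrays. The sketch below builds a structure of that shape to show how batch_size, dtype, and cache_length could map onto a key/value cache. It is a hypothetical stand-in, not gemma's init_cache: the function make_empty_cache, the key names ("layer_0", "k", "v", "end_index"), and the extra arguments num_layers, num_kv_heads, and head_dim are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def make_empty_cache(*, batch_size, dtype, cache_length,
                     num_layers=2, num_kv_heads=4, head_dim=8):
    """Hypothetical stand-in for a KV cache; NOT gemma's init_cache.

    Returns dict[str, dict[str, jax.Array]]: one inner dict per layer.
    """
    cache = {}
    for i in range(num_layers):
        kv_shape = (batch_size, cache_length, num_kv_heads, head_dim)
        cache[f"layer_{i}"] = {
            "k": jnp.zeros(kv_shape, dtype=dtype),
            "v": jnp.zeros(kv_shape, dtype=dtype),
            # Assumed bookkeeping: how many positions are filled per sequence.
            "end_index": jnp.zeros((batch_size,), dtype=jnp.int32),
        }
    return cache


cache = make_empty_cache(batch_size=2, dtype=jnp.bfloat16, cache_length=16)
# Map every array in the nested dict to its shape for inspection.
shapes = jax.tree_util.tree_map(lambda x: x.shape, cache)
```

Because the cache is a plain nested dict of arrays, it is a JAX pytree and can be passed through jit-compiled decoding steps and updated functionally.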
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None