gm.nn.Transformer
- class gemma.gm.nn.Transformer(
- config: gemma.transformer.TransformerConfig,
- return_last_only: bool | None = None,
- dtype: numpy.dtype = jax.numpy.bfloat16,
- tokens: kontext.Key = '__KEY_REQUIRED__',
- images: kontext.Key | None = None,
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
- name: str | None = None,
Bases:
gemma.transformer.Transformer
Base Transformer class.
- return_last_only
If True, compute and return only the logits for the last token. Otherwise, return the logits for all tokens. Defaults to False.
- Type:
bool | None
- dtype
The dtype of the parameters. Defaults to jnp.bfloat16.
- Type:
numpy.dtype
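The effect of return_last_only on the output can be pictured with a small toy sketch. This is not the library's implementation: `select_logits` is a hypothetical helper, the logits are assumed to be indexed as [batch][position][vocab], and padding is ignored (every sequence is assumed to fill its full length). A real model would also skip computing the logits it does not return, which this sketch does not show.

```python
def select_logits(logits, return_last_only=None):
    """Hypothetical helper mimicking the documented return_last_only behaviour.

    `logits` is a nested list indexed as [batch][position][vocab].
    None is treated like False, matching the documented default.
    """
    if return_last_only:
        # Keep only the final position of each sequence: [batch][vocab].
        return [sequence[-1] for sequence in logits]
    # Keep every position: [batch][position][vocab].
    return logits


# Toy input: batch of 2 sequences, 3 positions each, vocabulary of size 2.
toy_logits = [
    [[0.1, 0.9], [0.7, 0.3], [0.2, 0.8]],
    [[0.5, 0.5], [0.6, 0.4], [0.0, 1.0]],
]

all_logits = select_logits(toy_logits)
last_logits = select_logits(toy_logits, return_last_only=True)
```

Here `all_logits` keeps one vector per position, while `last_logits` keeps one vector per sequence, which is typically all that is needed when sampling the next token.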
- return_last_only: bool | None = None
- dtype
alias of jax.numpy.bfloat16
- tokens: kontext.Key = '__KEY_REQUIRED__'
- images: kontext.Key | None = None
- INFO: ClassVar[ModelInfo] = ModelInfo(tokenizer_version=None, default_ckpt=None)
- init_cache(
- *,
- batch_size: int,
- dtype: numpy.dtype[Any],
- cache_length: int,
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None