gm.nn.Gemma2_9B

class gemma.gm.nn.Gemma2_9B(
config: transformer.TransformerConfig = TransformerConfig(num_layers=42,
num_embed=256128,
embed_dim=3584,
hidden_dim=14336,
num_heads=16,
head_dim=256,
num_kv_heads=8,
final_logit_softcap=30.0,
use_post_attn_norm=True,
use_post_ffw_norm=True,
attention_types=(AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL) * 21,
max_cache_length=None,
query_pre_attn_norm=QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM,
attn_logits_soft_cap=50.0,
sliding_window_size=4096,
transpose_gating_einsum=True,
use_qk_norm=False,
local_base_frequency=10000,
global_base_frequency=10000,
local_scale_factor=1.0,
global_scale_factor=1.0,
mm_extra_vocab_size=0,
vision_encoder=None),
return_last_only: bool | None = None,
dtype: jnp.dtype = jnp.bfloat16,
tokens: kontext.Key = '__KEY_REQUIRED__',
images: kontext.Key | None = None,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)

Bases: gemma.gm.nn._transformer.Transformer

Gemma2 transformer architecture (9B-parameter variant).
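
A minimal usage sketch, following the pattern from the library's README: gm.ckpts.load_params, gm.ckpts.CheckpointPath, and gm.text.ChatSampler are the checkpoint and sampling helpers shipped with the gm API, and GEMMA2_9B_IT is the default checkpoint listed under INFO below.

from gemma import gm

# Instantiate the model with its default TransformerConfig (shown above).
model = gm.nn.Gemma2_9B()

# Load the default instruction-tuned checkpoint (see INFO below).
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA2_9B_IT)

# ChatSampler wraps tokenization, the KV cache, and decoding.
sampler = gm.text.ChatSampler(model=model, params=params)
print(sampler.chat('Explain the Gemma2 attention pattern.'))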

config: transformer.TransformerConfig = TransformerConfig(num_layers=42, num_embed=256128, embed_dim=3584, hidden_dim=14336, num_heads=16, head_dim=256, num_kv_heads=8, final_logit_softcap=30.0, use_post_attn_norm=True, use_post_ffw_norm=True, attention_types=(AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL) * 21, max_cache_length=None, query_pre_attn_norm=QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM, attn_logits_soft_cap=50.0, sliding_window_size=4096, transpose_gating_einsum=True, use_qk_norm=False, local_base_frequency=10000, global_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
INFO: ClassVar[ModelInfo] = ModelInfo(tokenizer_version=2, default_ckpt=<CheckpointPath.GEMMA2_9B_IT: 'gs://gemma-data/checkpoints/gemma2-9b-it'>)
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None
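
For a raw forward pass, the module can also be applied directly as a Flax module. A sketch, assuming a matching tokenizer is exposed as gm.text.Gemma2Tokenizer (the exact tokenizer class name is an assumption; tokenizer_version=2 above implies a Gemma2 vocabulary) and that the output exposes .logits as in the library's sampling examples:

import jax.numpy as jnp
from gemma import gm

model = gm.nn.Gemma2_9B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA2_9B_IT)

tokenizer = gm.text.Gemma2Tokenizer()  # assumed class name
tokens = jnp.asarray([tokenizer.encode('The capital of France is', add_bos=True)])

# return_last_only=True keeps only the logits for the final position.
out = model.apply({'params': params}, tokens=tokens, return_last_only=True)
print(out.logits.shape)  # expected (1, 256128): batch x vocab (num_embed above)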