gm.nn.Gemma3_1B
class gemma.gm.nn.Gemma3_1B(
    config: transformer.TransformerConfig = TransformerConfig(
        num_layers=26,
        num_embed=262144,
        embed_dim=1152,
        hidden_dim=6912,
        num_heads=4,
        head_dim=256,
        num_kv_heads=1,
        final_logit_softcap=None,
        use_post_attn_norm=True,
        use_post_ffw_norm=True,
        attention_types=(
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>,
            <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>,
        ),
        max_cache_length=None,
        query_pre_attn_norm=<QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM: 1>,
        attn_logits_soft_cap=None,
        sliding_window_size=512,
        transpose_gating_einsum=True,
        use_qk_norm=True,
        local_base_frequency=10000,
        global_base_frequency=1000000,
        local_scale_factor=1.0,
        global_scale_factor=1.0,
        mm_extra_vocab_size=0,
        vision_encoder=None,
    ),
    return_last_only: bool | None = None,
    dtype: jnp.dtype = <class 'jax.numpy.bfloat16'>,
    tokens: kontext.Key = '__KEY_REQUIRED__',
    images: kontext.Key | None = None,
    parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
    name: str | None = None,
)
Bases: gemma.gm.nn._transformer.Transformer
Gemma3 transformer architecture.
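Example: a minimal usage sketch. The checkpoint member `gm.ckpts.CheckpointPath.GEMMA3_1B_IT` and the `gm.text.ChatSampler` helper are taken from the library README and are assumptions here; adjust them to the checkpoint and release you actually use.

```python
from gemma import gm

# Instantiate the 1B Gemma3 transformer with the default TransformerConfig
# shown in the signature above.
model = gm.nn.Gemma3_1B()

# Load pretrained parameters. The exact CheckpointPath member is an assumption
# based on the library README; substitute the checkpoint you have available.
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_1B_IT)

# High-level chat sampling helper.
sampler = gm.text.ChatSampler(model=model, params=params)
print(sampler.chat('In one sentence, what is sliding-window attention?'))
```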
- config: transformer.TransformerConfig = TransformerConfig(...). Defaults to the configuration shown in the class signature above: 26 layers, embed_dim=1152, hidden_dim=6912, 4 query heads and 1 KV head with head_dim=256, a repeating pattern of five LOCAL_SLIDING layers per GLOBAL layer, sliding_window_size=512, and no vision encoder.
- INFO: ClassVar[ModelInfo] = ModelInfo(tokenizer_version=3, default_ckpt=None)
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: Scope | None = None
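For a raw forward pass through the module, `tokens` is the one required input (the `'__KEY_REQUIRED__'` sentinel in the signature only marks it as mandatory), `images` is optional, and `return_last_only=True` restricts the returned logits to the final position. The sketch below is hedged: the `Gemma3Tokenizer` name, the `.logits` output field, and the decode call are assumptions based on the library's documented API and may differ across versions.

```python
import jax.numpy as jnp
from gemma import gm

model = gm.nn.Gemma3_1B()
# Assumed checkpoint member, as in the example above.
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_1B_IT)

# Tokenizer matching the tokenizer_version=3 recorded in INFO.
tokenizer = gm.text.Gemma3Tokenizer()
prompt = jnp.asarray(tokenizer.encode('The capital of France is', add_bos=True))

# Forward pass; with return_last_only=True the logits cover only the last token.
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,
)

# Greedy next-token prediction (the decode signature may vary across versions).
next_token = int(jnp.argmax(out.logits, axis=-1))
print(tokenizer.decode([next_token]))
```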