gm.nn.Output#
- class gemma.gm.nn.Output(
- logits: jaxtyping.Float[Array, '*B L V'] | jaxtyping.Float[ndarray, '*B L V'] | jaxtyping.Float[Array, '*B V'] | jaxtyping.Float[ndarray, '*B V'],
- cache: dict[str, dict[str, jax.Array]] | None,
基类:
object
Gemma 模型的输出。
- logits
模型的预测 logits。
- cache
如果输入缓存不是 None,则为更新后的缓存;否则为 None。
- 类型:
dict[str, dict[str, jax.Array]] | None
- logits: jaxtyping.Float[Array, '*B L V'] | jaxtyping.Float[ndarray, '*B L V'] | jaxtyping.Float[Array, '*B V'] | jaxtyping.Float[ndarray, '*B V']
- cache: dict[str, dict[str, jax.Array]] | None
- replace(**updates)
返回一个新对象,用新值替换指定的字段。