gm.nn.Output

gm.nn.Output#

class gemma.gm.nn.Output(
logits: jaxtyping.Float[Array, '*B L V'] | jaxtyping.Float[ndarray, '*B L V'] | jaxtyping.Float[Array, '*B V'] | jaxtyping.Float[ndarray, '*B V'],
cache: dict[str, dict[str, jax.Array]] | None,
)[source]

基类: object

Gemma 模型的输出。

logits

模型的预测 logits。

类型:

jaxtyping.Float[Array, ‘*B L V’] | jaxtyping.Float[ndarray, ‘*B L V’] | jaxtyping.Float[Array, ‘*B V’] | jaxtyping.Float[ndarray, ‘*B V’]

cache

如果输入缓存不是 None,则为更新后的缓存;否则为 None。

类型:

dict[str, dict[str, jax.Array]] | None

logits: jaxtyping.Float[Array, '*B L V'] | jaxtyping.Float[ndarray, '*B L V'] | jaxtyping.Float[Array, '*B V'] | jaxtyping.Float[ndarray, '*B V']
cache: dict[str, dict[str, jax.Array]] | None
replace(**updates)

返回一个新对象,用新值替换指定的字段。