gm.text.SamplingMethod

gm.text.SamplingMethod#

class gemma.gm.text.SamplingMethod[source]

基类:abc.ABC

采样方法的基础类。

abstractmethod get_next_tokens(
logits: jaxtyping.Float[Array, '*B V'] | jaxtyping.Float[ndarray, '*B V'],
rng: jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array,
) jaxtyping.Int[Array, '*B'] | jaxtyping.Int[ndarray, '*B'][source]

返回要生成的下一个 token。

参数:
  • logits – Logits,由模型返回(即 softmax 之前)。

  • rng – 随机密钥。

返回:

要生成的下一个 token。