gm.text.Sampler

gm.text.Sampler#

class gemma.gm.text.Sampler(*, model: gemma.gm.nn._transformer.Transformer, params: typing.Mapping[str, typing.Any], tokenizer: gemma.gm.text._tokenizer.Tokenizer = None, sampling: gemma.gm.text._sampling.SamplingMethod = <factory>, forbidden_tokens: collections.abc.Sequence[str | int] | None = None, cache_length: int = 4096, max_out_length: int = 2048)[源代码]

基类: object

采样器。

这是一个底层 API。对于大多数用例,建议使用 gm.text.ChatSampler

sampler = Sampler(
    model=model,
    params=params,
)

output = sampler.sample(prompt)

此采样器

  • 是无状态的(状态必须在调用之间手动传递)

  • 用户必须使用 <start_of_turn> 等手动格式化提示。

  • BOS(序列开始)标记会自动添加。

模型

Gemma Transformer 模型。

类型:

gemma.gm.nn._transformer.Transformer

参数

模型参数。

类型:

Mapping[str, Any]

分词器

分词器。

类型:

gemma.gm.text._tokenizer.Tokenizer

采样

要使用的采样方法。默认为贪婪采样。

类型:

gemma.gm.text._sampling.SamplingMethod

forbidden_tokens

禁止生成的标记列表。如果提供 str,则应映射到词汇表中的单个标记 ID。

类型:

collections.abc.Sequence[str | int] | None

cache_length

要使用的缓存长度。这是对话可以拥有的最大标记数(所有轮次的提示、答案、图像)。将其设置为固定值可以避免轮次之间的重新编译。

类型:

int

max_out_length

单轮输出缓冲区的长度。用于避免触发 jit 重新编译的静态值。除非您的任务中模型生成非常长的输出,否则不应更改。

类型:

int

model: gemma.gm.nn._transformer.Transformer
params: Mapping[str, Any]
tokenizer: gemma.gm.text._tokenizer.Tokenizer = None
sampling: gemma.gm.text._sampling.SamplingMethod
forbidden_tokens: collections.abc.Sequence[str | int] | None = None
cache_length: int = 4096
max_out_length: int = 2048
sample(
prompt: str,
*,
images: jaxtyping.UInt8[Array, 'N? H W C'] | jaxtyping.UInt8[ndarray, 'N? H W C'] | None = None,
max_new_tokens: int | None = None,
sampling: gemma.gm.text._sampling.SamplingMethod = None,
rng: int | Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
return_state: Literal[False] = False,
last_state: gemma.gm.text._sampler_call.SamplingState | None = None,
sharding: kd.sharding.ShardingTree | None = None,
) str[源代码]
sample(
prompt: collections.abc.Sequence[str],
*,
images: collections.abc.Sequence[jaxtyping.UInt8[Array, 'N H W C'] | jaxtyping.UInt8[ndarray, 'N H W C']] | None = None,
max_new_tokens: int | None = None,
sampling: gemma.gm.text._sampling.SamplingMethod = None,
rng: int | Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
return_state: Literal[False] = False,
last_state: gemma.gm.text._sampler_call.SamplingState | None = None,
sharding: kd.sharding.ShardingTree | None = None,
) list[str]
sample(
prompt: str | collections.abc.Sequence[str],
*,
images: jaxtyping.UInt8[Array, 'B? N? H W C'] | jaxtyping.UInt8[ndarray, 'B? N? H W C'] | None = None,
max_new_tokens: int | None = None,
sampling: gemma.gm.text._sampling.SamplingMethod = None,
rng: int | Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
return_state: Literal[True] = False,
last_state: gemma.gm.text._sampler_call.SamplingState | None = None,
sharding: kd.sharding.ShardingTree | None = None,
) gemma.gm.text._sampler.SamplerOutput

从模型中采样一个字符串。

示例

prompt = """<start_of_turn>user
I'm hesitating between those two options:

Option 1: <start_of_image>
Option 2: <start_of_image>

Any thoughts ?
<end_of_turn>
<start_of_turn>model
"""
sampler.sample(prompt, images=images))
参数:
  • prompt – 要从中采样的提示。可以是单个字符串或字符串列表。

  • images – 提示的图像。图像应插入到提示中的位置由提示中的 <start_of_image> 标记确定。

  • max_new_tokens – 要生成的最大新标记数。Transformer 将处理 input_length + max_new_tokens

  • sampling – 要使用的采样方法。如果给定,将覆盖默认采样方法。

  • rng – 用于采样方法的种子。如果为 None,则使用随机种子。可以是种子 intjax.random.PRNGKey 对象。

  • return_state – 如果为 True,则返回带有输出附加值(logits、缓存等)的 SamplerOutput 对象。

  • last_state – 当 return_state=True 时,状态可以在对采样器的调用之间传播,用于多轮对话。使用 gm.text.ChatSampler 以获得更简单的 API,它可以为您处理状态。

  • sharding – 如果提供,则根据指定的分片对标记进行分片。用户负责确保标记化的提示与分片兼容。例如,如果 sharding=kd.sharding.FIRST_DIM,则提示的数量必须可被设备数量整除。

返回:

采样的输出。