gm.text.Sampler#
- class gemma.gm.text.Sampler(*, model: gemma.gm.nn._transformer.Transformer, params: typing.Mapping[str, typing.Any], tokenizer: gemma.gm.text._tokenizer.Tokenizer = None, sampling: gemma.gm.text._sampling.SamplingMethod = <factory>, forbidden_tokens: collections.abc.Sequence[str | int] | None = None, cache_length: int = 4096, max_out_length: int = 2048)[源代码]
基类:
object
采样器。
这是一个底层 API。对于大多数用例,建议使用
gm.text.ChatSampler
。sampler = Sampler( model=model, params=params, ) output = sampler.sample(prompt)
此采样器
是无状态的(状态必须在调用之间手动传递)
用户必须使用 <start_of_turn> 等手动格式化提示。
BOS(序列开始)标记会自动添加。
- 模型
Gemma Transformer 模型。
- 类型:
gemma.gm.nn._transformer.Transformer
- 参数
模型参数。
- 类型:
Mapping[str, Any]
- 分词器
分词器。
- 类型:
gemma.gm.text._tokenizer.Tokenizer
- 采样
要使用的采样方法。默认为贪婪采样。
- 类型:
gemma.gm.text._sampling.SamplingMethod
- forbidden_tokens
禁止生成的标记列表。如果提供 str,则应映射到词汇表中的单个标记 ID。
- 类型:
collections.abc.Sequence[str | int] | None
- cache_length
要使用的缓存长度。这是对话可以拥有的最大标记数(所有轮次的提示、答案、图像)。将其设置为固定值可以避免轮次之间的重新编译。
- 类型:
int
- max_out_length
单轮输出缓冲区的长度。用于避免触发 jit 重新编译的静态值。除非您的任务中模型生成非常长的输出,否则不应更改。
- 类型:
int
- model: gemma.gm.nn._transformer.Transformer
- params: Mapping[str, Any]
- tokenizer: gemma.gm.text._tokenizer.Tokenizer = None
- sampling: gemma.gm.text._sampling.SamplingMethod
- forbidden_tokens: collections.abc.Sequence[str | int] | None = None
- cache_length: int = 4096
- max_out_length: int = 2048
- sample(
- prompt: str,
- *,
- images: jaxtyping.UInt8[Array, 'N? H W C'] | jaxtyping.UInt8[ndarray, 'N? H W C'] | None = None,
- max_new_tokens: int | None = None,
- sampling: gemma.gm.text._sampling.SamplingMethod = None,
- rng: int | Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
- return_state: Literal[False] = False,
- last_state: gemma.gm.text._sampler_call.SamplingState | None = None,
- sharding: kd.sharding.ShardingTree | None = None,
- sample(
- prompt: collections.abc.Sequence[str],
- *,
- images: collections.abc.Sequence[jaxtyping.UInt8[Array, 'N H W C'] | jaxtyping.UInt8[ndarray, 'N H W C']] | None = None,
- max_new_tokens: int | None = None,
- sampling: gemma.gm.text._sampling.SamplingMethod = None,
- rng: int | Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
- return_state: Literal[False] = False,
- last_state: gemma.gm.text._sampler_call.SamplingState | None = None,
- sharding: kd.sharding.ShardingTree | None = None,
- sample(
- prompt: str | collections.abc.Sequence[str],
- *,
- images: jaxtyping.UInt8[Array, 'B? N? H W C'] | jaxtyping.UInt8[ndarray, 'B? N? H W C'] | None = None,
- max_new_tokens: int | None = None,
- sampling: gemma.gm.text._sampling.SamplingMethod = None,
- rng: int | Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
- return_state: Literal[True] = False,
- last_state: gemma.gm.text._sampler_call.SamplingState | None = None,
- sharding: kd.sharding.ShardingTree | None = None,
从模型中采样一个字符串。
示例
prompt = """<start_of_turn>user I'm hesitating between those two options: Option 1: <start_of_image> Option 2: <start_of_image> Any thoughts ? <end_of_turn> <start_of_turn>model """ sampler.sample(prompt, images=images))
- 参数:
prompt – 要从中采样的提示。可以是单个字符串或字符串列表。
images – 提示的图像。图像应插入到提示中的位置由提示中的 <start_of_image> 标记确定。
max_new_tokens – 要生成的最大新标记数。Transformer 将处理 input_length + max_new_tokens。
sampling – 要使用的采样方法。如果给定,将覆盖默认采样方法。
rng – 用于采样方法的种子。如果为 None,则使用随机种子。可以是种子 int 或 jax.random.PRNGKey 对象。
return_state – 如果为 True,则返回带有输出附加值(logits、缓存等)的 SamplerOutput 对象。
last_state – 当 return_state=True 时,状态可以在对采样器的调用之间传播,用于多轮对话。使用
gm.text.ChatSampler
以获得更简单的 API,它可以为您处理状态。sharding – 如果提供,则根据指定的分片对标记进行分片。用户负责确保标记化的提示与分片兼容。例如,如果 sharding=kd.sharding.FIRST_DIM,则提示的数量必须可被设备数量整除。
- 返回:
采样的输出。