gm.text.ChatSampler#
- class gemma.gm.text.ChatSampler(*, model: gemma.gm.nn._transformer.Transformer, params: typing.Mapping[str, typing.Any], multi_turn: bool = False, tokenizer: gemma.gm.text._tokenizer.Tokenizer = None, sampling: gemma.gm.text._sampling.SamplingMethod = <factory>, forbidden_tokens: collections.abc.Sequence[str | int] | None = None, cache_length: int | None = 4096, max_out_length: int = 2048, last_state: gemma.gm.text._sampler_call.SamplingState = None, turns: list[gemma.gm.text._template.Turn] = <factory>)[source]
Bases:
object
聊天采样器。
sampler = ChatSampler( model=model, params=params, multi_turn=True, ) output0 = sampler.chat('Write a poem about cats.') output1 = sampler.chat('And about dogs.') output2 = sampler.chat('Which one do you prefer?')
此采样器
是有状态的(KV-缓存状态自动处理)
自动使用 <start_of_turn> 和 <end_of_turn> 标记格式化提示,添加 BOS(序列开始)标记。并从输出中过滤掉 <end_of_turn> 标记。
- model
Gemma Transformer 模型。
- 类型:
gemma.gm.nn._transformer.Transformer
- params
模型参数。
- 类型:
Mapping[str, Any]
- multi_turn
如果为 True,则重用之前的轮次作为上下文。
- 类型:
bool
- tokenizer
分词器。
- 类型:
gemma.gm.text._tokenizer.Tokenizer
- sampling
要使用的采样方法。默认为贪婪采样。
- 类型:
gemma.gm.text._sampling.SamplingMethod
- forbidden_tokens
禁止生成的标记列表。如果提供 str,则应映射到词汇表中的单个标记 ID。
- 类型:
collections.abc.Sequence[str | int] | None
- cache_length
要使用的缓存长度。这是对话可以拥有的最大标记数(所有轮次的提示、答案、图像)。将此设置为固定值可以避免轮次之间重新编译。
- 类型:
int | None
- max_out_length
单轮输出缓冲区的长度。用于避免触发 jit 重新编译的静态值。除非您的任务需要模型生成非常长的输出,否则不应更改此值。
- 类型:
int
- last_state
采样器的最后状态,由采样器自动处理,但为了方便高级用户访问 logits、缓存等或初始化采样器而公开。
- 类型:
gemma.gm.text._sampler_call.SamplingState
- turns
当前的对话。
- 类型:
list[gemma.gm.text._template.Turn]
- model: gemma.gm.nn._transformer.Transformer
- params: Mapping[str, Any]
- multi_turn: bool = False
- tokenizer: gemma.gm.text._tokenizer.Tokenizer = None
- sampling: gemma.gm.text._sampling.SamplingMethod
- forbidden_tokens: collections.abc.Sequence[str | int] | None = None
- cache_length: int | None = 4096
- max_out_length: int = 2048
- last_state: gemma.gm.text._sampler_call.SamplingState = None
- turns: list[gemma.gm.text._template.Turn]
- property sampler: gemma.gm.text._sampler.Sampler
返回底层采样器。
- chat(
- prompt: str,
- *,
- images: jaxtyping.UInt8[Array, 'N? H W C'] | jaxtyping.UInt8[ndarray, 'N? H W C'] | None = None,
- sampling: gemma.gm.text._sampling.SamplingMethod | None = None,
- rng: int | Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
- max_new_tokens: int | None = None,
- multi_turn: bool | None = None,
从模型中采样一个字符串。
示例
prompt = """I'm hesitating between those two options: Option 1: <start_of_image> Option 2: <start_of_image> Any thoughts ?""" sampler.sample(prompt, images=[image1, image2]))
- 参数:
prompt – 要从中采样的提示。可以是单个字符串或字符串列表。
images – 提示的图像。图像应插入到提示中的位置由提示中的 <start_of_image> 标记确定。
sampling – 要使用的采样方法。如果给定,将覆盖默认采样方法。
rng – 用于采样方法的种子。如果为 None,则使用随机种子。可以是种子 int 或 jax.random.PRNGKey 对象。
max_new_tokens – 如果给定,将在这么多标记后停止采样。用于在调试时进行更快的迭代。默认情况下,采样将持续到找到 <end_of_turn> 标记,或直到 max_out_length 缓冲区被填满。
multi_turn – 如果为 True,则重用之前的轮次作为上下文。覆盖 multi_turn 属性。
- 返回:
采样的输出。