gm.text.ChatSampler

gm.text.ChatSampler#

class gemma.gm.text.ChatSampler(*, model: gemma.gm.nn._transformer.Transformer, params: typing.Mapping[str, typing.Any], multi_turn: bool = False, tokenizer: gemma.gm.text._tokenizer.Tokenizer = None, sampling: gemma.gm.text._sampling.SamplingMethod = <factory>, forbidden_tokens: collections.abc.Sequence[str | int] | None = None, cache_length: int | None = 4096, max_out_length: int = 2048, last_state: gemma.gm.text._sampler_call.SamplingState = None, turns: list[gemma.gm.text._template.Turn] = <factory>)[source]

Bases: object

聊天采样器。

sampler = ChatSampler(
    model=model,
    params=params,
    multi_turn=True,
)

output0 = sampler.chat('Write a poem about cats.')
output1 = sampler.chat('And about dogs.')
output2 = sampler.chat('Which one do you prefer?')

此采样器

  • 是有状态的(KV-缓存状态自动处理)

  • 自动使用 <start_of_turn><end_of_turn> 标记格式化提示,添加 BOS(序列开始)标记。并从输出中过滤掉 <end_of_turn> 标记。

model

Gemma Transformer 模型。

类型:

gemma.gm.nn._transformer.Transformer

params

模型参数。

类型:

Mapping[str, Any]

multi_turn

如果为 True,则重用之前的轮次作为上下文。

类型:

bool

tokenizer

分词器。

类型:

gemma.gm.text._tokenizer.Tokenizer

sampling

要使用的采样方法。默认为贪婪采样。

类型:

gemma.gm.text._sampling.SamplingMethod

forbidden_tokens

禁止生成的标记列表。如果提供 str,则应映射到词汇表中的单个标记 ID。

类型:

collections.abc.Sequence[str | int] | None

cache_length

要使用的缓存长度。这是对话可以拥有的最大标记数(所有轮次的提示、答案、图像)。将此设置为固定值可以避免轮次之间重新编译。

类型:

int | None

max_out_length

单轮输出缓冲区的长度。用于避免触发 jit 重新编译的静态值。除非您的任务需要模型生成非常长的输出,否则不应更改此值。

类型:

int

last_state

采样器的最后状态,由采样器自动处理,但为了方便高级用户访问 logits、缓存等或初始化采样器而公开。

类型:

gemma.gm.text._sampler_call.SamplingState

turns

当前的对话。

类型:

list[gemma.gm.text._template.Turn]

model: gemma.gm.nn._transformer.Transformer
params: Mapping[str, Any]
multi_turn: bool = False
tokenizer: gemma.gm.text._tokenizer.Tokenizer = None
sampling: gemma.gm.text._sampling.SamplingMethod
forbidden_tokens: collections.abc.Sequence[str | int] | None = None
cache_length: int | None = 4096
max_out_length: int = 2048
last_state: gemma.gm.text._sampler_call.SamplingState = None
turns: list[gemma.gm.text._template.Turn]
property sampler: gemma.gm.text._sampler.Sampler

返回底层采样器。

chat(
prompt: str,
*,
images: jaxtyping.UInt8[Array, 'N? H W C'] | jaxtyping.UInt8[ndarray, 'N? H W C'] | None = None,
sampling: gemma.gm.text._sampling.SamplingMethod | None = None,
rng: int | Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
max_new_tokens: int | None = None,
multi_turn: bool | None = None,
) str[source]

从模型中采样一个字符串。

示例

prompt = """I'm hesitating between those two options:

Option 1: <start_of_image>
Option 2: <start_of_image>

Any thoughts ?"""
sampler.sample(prompt, images=[image1, image2]))
参数:
  • prompt – 要从中采样的提示。可以是单个字符串或字符串列表。

  • images – 提示的图像。图像应插入到提示中的位置由提示中的 <start_of_image> 标记确定。

  • sampling – 要使用的采样方法。如果给定,将覆盖默认采样方法。

  • rng – 用于采样方法的种子。如果为 None,则使用随机种子。可以是种子 intjax.random.PRNGKey 对象。

  • max_new_tokens – 如果给定,将在这么多标记后停止采样。用于在调试时进行更快的迭代。默认情况下,采样将持续到找到 <end_of_turn> 标记,或直到 max_out_length 缓冲区被填满。

  • multi_turn – 如果为 True,则重用之前的轮次作为上下文。覆盖 multi_turn 属性。

返回:

采样的输出。