gm.data.Seq2SeqTask#
- class gemma.gm.data.Seq2SeqTask(*, in_prompt: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>], in_response: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>], out_input: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>], out_target: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>], out_target_mask: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>], drop_inputs: bool = True, tokenizer: gemma.gm.text._tokenizer.Tokenizer, max_length: int, truncate: bool = False, sampling: bool = False)[source]
基类:
grain._src.core.transforms.MapTransform
序列到序列任务。
此任务将
格式化提示和响应以匹配预期的对话模板(即添加 <start_of_turn>user, <end_of_turn>,…)
分词提示和响应。
连接输入和响应以创建模型输入和目标(目标是输入向后移动一个 token)。
创建损失掩码(提示为 0,响应为 1)
将输入和目标填充/截断到最大长度。
示例
# Input: { 'prompt': 'Hello! I would love to visit France.', 'response': 'Bonjour ! J'adorerais visiter la France.', } # Ouptut: { 'input': i32['max_length'], 'target': i32['max_length 1'], 'target_mask': bool['max_length 1'], }
注意
输入和目标是同一个序列,目标向后移动一个 token。
目标的最后一个 token 从输入中截断(因为它没有目标)
- in_prompt
输入键
- 类型:
Any
- in_response
输入键
- 类型:
Any
- out_input
输出键(将添加到示例字典中)
- 类型:
Any
- out_target
输出键(将添加到示例字典中)
- 类型:
Any
- out_target_mask
输出键(将添加到示例字典中)
- 类型:
Any
- drop_inputs
如果为 True,则从输出中删除输入键。
- 类型:
bool
- max_length
序列的最大长度(示例将被填充/截断到此长度)。
- 类型:
int
- truncate
是否将序列截断到最大长度。如果为 False,则长度超过 max_length 的序列将引发错误。
- 类型:
bool
- sampling
如果为 True,数据集将生成原始提示和响应,以便它们可以在
gm.evals.SamplerEvaluator
中使用。- 类型:
bool
- in_prompt: Annotated[Any, <object object at 0x75a909cb7ae0>]
- in_response: Annotated[Any, <object object at 0x75a909cb7ae0>]
- out_input: Annotated[Any, <object object at 0x75a909cb7ae0>]
- out_target: Annotated[Any, <object object at 0x75a909cb7ae0>]
- out_target_mask: Annotated[Any, <object object at 0x75a909cb7ae0>]
- drop_inputs: bool = True
- tokenizer: gemma.gm.text._tokenizer.Tokenizer
- max_length: int
- truncate: bool = False
- sampling: bool = False
- map(element)[source]
映射单个元素。