gm.data.AddSeq2SeqFields#
- class gemma.gm.data.AddSeq2SeqFields(*, in_prompt: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>], in_response: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>], out_input: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>], out_target: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>], out_target_mask: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>])[source]
基类:
grain._src.core.transforms.MapTransform
添加模型的 input、target 和 loss_mask。
从 prompt 和 response 令牌 ID 生成模型的 input、target 和 loss_mask。
示例
# Input: { 'prompt': [10, 11, 12, 13], 'response': [20, 21, 1], # Here, response ends with EOS token. } # Ouptut: { 'input': [10, 11, 12, 13, 20, 21], 'target': [11, 12, 13, 20, 21, 1], 'target_mask': [ 0, 0, 0, 1, 1, 1], }
注意
Input 和 target 是相同的序列,但 target 比 input 偏移一个令牌。
target 的最后一个令牌会从 input 中截断(因为它没有对应的 target)。
- in_prompt
输入键
- 类型:
任意类型
- in_response
输入键
- 类型:
任意类型
- out_input
输出键 (将添加到示例字典中)
- 类型:
任意类型
- out_target
输出键 (将添加到示例字典中)
- 类型:
任意类型
- out_target_mask
输出键 (将添加到示例字典中)
- 类型:
任意类型
- in_prompt: Annotated[Any, <object object at 0x75a909cb7ae0>]
- in_response: Annotated[Any, <object object at 0x75a909cb7ae0>]
- out_input: Annotated[Any, <object object at 0x75a909cb7ae0>]
- out_target: Annotated[Any, <object object at 0x75a909cb7ae0>]
- out_target_mask: Annotated[Any, <object object at 0x75a909cb7ae0>]
- map(element)[source]
映射单个元素。