gm.data.AddSeq2SeqFields

class gemma.gm.data.AddSeq2SeqFields(*, in_prompt: typing.Any, in_response: typing.Any, out_input: typing.Any, out_target: typing.Any, out_target_mask: typing.Any)

Bases: grain._src.core.transforms.MapTransform

Adds the model input, target, and loss_mask.

Generates the model input, target, and loss_mask from the prompt and response token IDs.

Example

# Input:
{
    'prompt': [10, 11, 12, 13],
    'response': [20, 21, 1],  # Here, response ends with EOS token.
}
# Output:
{
    'input':       [10, 11, 12, 13, 20, 21],
    'target':      [11, 12, 13, 20, 21,  1],
    'target_mask': [ 0,  0,  0,  1,  1,  1],
}

Note

  • Input and target are the same sequence, with target shifted by one token relative to input.

  • The last token of the target is truncated from the input (since there is no target for it).
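The shift-and-mask behaviour shown in the example can be sketched in plain Python. This is an illustrative re-implementation under simple assumptions (token IDs as Python lists, a hypothetical helper named `make_seq2seq_fields`), not the library's actual code:

```python
def make_seq2seq_fields(prompt, response):
    """Hypothetical helper mirroring the documented example (not the gemma code)."""
    # Concatenate prompt and response into one token sequence.
    sequence = list(prompt) + list(response)
    # Mask is 0 over prompt positions and 1 over response positions.
    mask = [0] * len(prompt) + [1] * len(response)
    return {
        # Drop the last token: it has no next token to predict.
        'input': sequence[:-1],
        # Shift by one so target[i] is the token that follows input[i].
        'target': sequence[1:],
        # Shift the mask the same way so it lines up with the target.
        'target_mask': mask[1:],
    }


example = make_seq2seq_fields([10, 11, 12, 13], [20, 21, 1])
print(example)
```

With the prompt and response from the example, this yields the same `input`, `target`, and `target_mask` values listed there.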

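in_prompt

Input key

Type:

Any

in_response

Input key

Type:

Any

out_input

Output key (will be added to the example dict)

Type:

Any

out_target

Output key (will be added to the example dict)

Type:

Any

out_target_mask

Output key (will be added to the example dict)

Type:

Any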

in_prompt: Any
in_response: Any
out_input: Any
out_target: Any
out_target_mask: Any
map(element)

Maps a single element.
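To illustrate how the input and output keys route fields through a per-element map, here is a minimal stand-in written without any library dependency. The class name `Seq2SeqFieldsSketch`, the use of plain string keys, and the choice to preserve existing fields are all assumptions made for this sketch; the real transform may differ:

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Seq2SeqFieldsSketch:
    """Hypothetical stand-in for illustration; not the gemma class."""

    in_prompt: str
    in_response: str
    out_input: str
    out_target: str
    out_target_mask: str

    def map(self, element: dict[str, Any]) -> dict[str, Any]:
        # Read the two input fields using the configured keys.
        prompt = list(element[self.in_prompt])
        response = list(element[self.in_response])
        sequence = prompt + response
        mask = [0] * len(prompt) + [1] * len(response)
        # Assumption: fields already present in the element are kept.
        out = dict(element)
        out[self.out_input] = sequence[:-1]
        out[self.out_target] = sequence[1:]
        out[self.out_target_mask] = mask[1:]
        return out


transform = Seq2SeqFieldsSketch(
    in_prompt='question',
    in_response='answer',
    out_input='tokens',
    out_target='labels',
    out_target_mask='loss_mask',
)
result = transform.map({'question': [10, 11], 'answer': [20, 1]})
print(result)
```

The key names passed to the constructor decide which fields are read and which are written, so the same logic can be pointed at datasets with different field names.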