gm.data.ContrastiveTask
- class gemma.gm.data.ContrastiveTask(*, in_prompt: Any, in_chosen: Any, in_rejected: Any, out_tokens: Any, out_targets: Any, out_mask: Any, tokenizer: gemma.gm.text._tokenizer.Tokenizer, max_length: int, truncate: bool = False, drop_inputs: bool = True)
Bases:
grain._src.core.transforms.MapTransform
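Creates contrastive model inputs for DPO-style losses.
Input
{ 'prompt': 'How much are 2+2 ?', 'chosen': 'Yes, this is 4.', 'rejected': 'Of course, 2+2 is 42.', }
Output
{ 'tokens': i32['2 max_length'], 'mask': bool['2 max_length'], }
In the output, the [chosen, rejected] token IDs are stacked in that order, so the leading dimension of size 2 indexes the chosen sequence first and the rejected sequence second.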
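The stacking described above can be illustrated with a small, library-free sketch. Everything in it is an assumption made for illustration: `toy_encode` is a hypothetical whitespace tokenizer (the real transform uses a Gemma tokenizer), padding uses ID 0, the mask is assumed to mark the response tokens, and nested Python lists stand in for the `i32` and `bool` arrays.

```python
PAD_ID = 0  # Assumed padding ID for this sketch only.

def toy_encode(text, vocab):
    """Hypothetical tokenizer: one integer ID per whitespace-separated word."""
    return [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]

def make_row(prompt_ids, response_ids, max_length):
    """Builds one padded (tokens, mask) row for a prompt + response pair.

    Assumes the pair fits in max_length. The mask is assumed to be True on
    response tokens only; the library's mask semantics may differ.
    """
    padding = max_length - len(prompt_ids) - len(response_ids)
    tokens = prompt_ids + response_ids + [PAD_ID] * padding
    mask = [False] * len(prompt_ids) + [True] * len(response_ids) + [False] * padding
    return tokens, mask

example = {
    'prompt': 'How much are 2+2 ?',
    'chosen': 'Yes, this is 4.',
    'rejected': 'Of course, 2+2 is 42.',
}
max_length = 16
vocab = {}
prompt_ids = toy_encode(example['prompt'], vocab)

# Row 0 is the chosen sequence, row 1 the rejected one.
rows = [
    make_row(prompt_ids, toy_encode(example[key], vocab), max_length)
    for key in ('chosen', 'rejected')
]
tokens = [row_tokens for row_tokens, _ in rows]
mask = [row_mask for _, row_mask in rows]

print(len(tokens), len(tokens[0]))  # 2 16
```

Both rows share the same prompt prefix and differ only in the response, which is what a DPO-style loss compares.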
- in_prompt
Input key for the prompt.
- Type:
Any
- in_chosen
Input key for the chosen (preferred) response.
- Type:
Any
- in_rejected
Input key for the rejected response.
- Type:
Any
- out_tokens
Output key (added to the example dict).
- Type:
Any
- out_mask
Output key (added to the example dict).
- Type:
Any
- tokenizer
The tokenizer to use.
- Type:
gemma.gm.text._tokenizer.Tokenizer
- max_length
Maximum sequence length (examples are padded or truncated to this length).
- Type:
int
- truncate
Whether to truncate sequences to the maximum length. If False, sequences longer than max_length raise an error.
- Type:
bool
- drop_inputs
If True, removes the input keys from the output.
- Type:
bool
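The interaction between max_length and truncate can be sketched as a small helper. This is a minimal illustration, not the library's code: the function name `pad_or_truncate`, the padding ID 0, truncation from the right, and the use of `ValueError` are all assumptions.

```python
def pad_or_truncate(ids, max_length, *, truncate=False, pad_id=0):
    """Returns a list of exactly max_length IDs.

    Assumptions for this sketch: right-side truncation, right-side padding
    with pad_id, and ValueError for over-long input when truncate is False.
    """
    ids = list(ids)
    if len(ids) > max_length:
        if not truncate:
            raise ValueError(
                f'Sequence length {len(ids)} exceeds max_length={max_length}.'
            )
        ids = ids[:max_length]
    return ids + [pad_id] * (max_length - len(ids))

short = pad_or_truncate([5, 6, 7], 5)                            # padded
clipped = pad_or_truncate([1, 2, 3, 4, 5, 6], 4, truncate=True)  # truncated

try:
    pad_or_truncate([1, 2, 3, 4, 5, 6], 4)  # truncate=False is the default
    raised = False
except ValueError:
    raised = True
```

Leaving truncate at its default of False surfaces over-long examples early instead of silently dropping part of a response.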
- in_prompt: Any
- in_chosen: Any
- in_rejected: Any
- out_tokens: Any
- out_targets: Any
- out_mask: Any
- tokenizer: gemma.gm.text._tokenizer.Tokenizer
- max_length: int
- truncate: bool = False
- drop_inputs: bool = True
- map(element)
Maps a single element.
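As a rough picture of the per-element contract, the sketch below shows how a map transform might read the configured input keys, write an output key, and honor drop_inputs. `ToyContrastiveMap` is a hypothetical stand-in, not the library class: it skips tokenization entirely (plain strings stand in for token IDs) and does not subclass grain's MapTransform.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class ToyContrastiveMap:
    """Hypothetical stand-in that only demonstrates key handling."""
    in_prompt: str
    in_chosen: str
    in_rejected: str
    out_tokens: str
    drop_inputs: bool = True

    def map(self, element):
        """Maps a single example dict to a new dict."""
        prompt = element[self.in_prompt]
        # Placeholder for tokenization: [chosen, rejected] in that order.
        pair = [
            prompt + ' ' + element[self.in_chosen],
            prompt + ' ' + element[self.in_rejected],
        ]
        input_keys = {self.in_prompt, self.in_chosen, self.in_rejected}
        # Keys that are not inputs (for example an ID field) always pass through.
        out = {
            key: value
            for key, value in element.items()
            if not (self.drop_inputs and key in input_keys)
        }
        out[self.out_tokens] = pair
        return out

element = {'id': 7, 'prompt': 'Q', 'chosen': 'good', 'rejected': 'bad'}
keys = dict(in_prompt='prompt', in_chosen='chosen', in_rejected='rejected',
            out_tokens='tokens')

dropped = ToyContrastiveMap(**keys).map(element)
kept = ToyContrastiveMap(**keys, drop_inputs=False).map(element)
```

With drop_inputs=True only `id` and `tokens` remain; with drop_inputs=False the original `prompt`, `chosen`, and `rejected` entries are kept alongside the new key.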