gm.data.ContrastiveTask

class gemma.gm.data.ContrastiveTask(*, in_prompt: Any, in_chosen: Any, in_rejected: Any, out_tokens: Any, out_targets: Any, out_mask: Any, tokenizer: gemma.gm.text._tokenizer.Tokenizer, max_length: int, truncate: bool = False, drop_inputs: bool = True)

Bases: grain._src.core.transforms.MapTransform

Creates contrastive model inputs for a DPO-like loss.

Input:

{
    'prompt': 'How much are 2+2 ?',
    'chosen': 'Yes, this is 4.',
    'rejected': 'Of course, 2+2 is 42.',
}

Output:

{
    'tokens': i32['2 max_length'],
    'mask': bool['2 max_length'],
}

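In the output, the chosen and rejected sequences are stacked in the order [chosen, rejected] along the leading axis of size 2, so index 0 holds the chosen sequence and index 1 the rejected one.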
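The stacking order matters because a DPO-style loss compares the chosen and rejected sequences. The snippet below is a minimal sketch of the standard DPO (Direct Preference Optimization) objective for one example, computed from summed per-sequence log-probabilities where index 0 is chosen and index 1 is rejected. It is not part of `gm.data`: `dpo_loss`, `log_sigmoid`, and the default `beta=0.1` are illustrative assumptions, and deriving the log-probabilities from model logits over the masked positions is omitted.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    # For x >= 0: log(sigmoid(x)) = -log(1 + exp(-x)).
    # For x < 0:  log(sigmoid(x)) = x - log(1 + exp(x)), which avoids overflow.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(policy_logps, ref_logps, beta=0.1):
    """Hypothetical DPO loss for one example.

    policy_logps, ref_logps: pairs of summed sequence log-probabilities,
    with index 0 = chosen and index 1 = rejected (the stacking order above).
    """
    chosen_ratio = policy_logps[0] - ref_logps[0]      # log pi/pi_ref on chosen
    rejected_ratio = policy_logps[1] - ref_logps[1]    # log pi/pi_ref on rejected
    return -log_sigmoid(beta * (chosen_ratio - rejected_ratio))
```

When the policy and the reference agree, the loss is log 2 (about 0.693); it drops below that as the policy raises the chosen sequence's likelihood relative to the rejected one, which is why swapping the two rows would invert the training signal.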

in_prompt

Input key of the prompt.

Type:

Any

in_chosen

Input key of the chosen (preferred) response.

Type:

Any

in_rejected

Input key of the rejected response.

Type:

Any

out_tokens

Output key for the token IDs (added to the example dict).

Type:

Any

out_mask

Output key for the mask (added to the example dict).

Type:

Any

tokenizer

The tokenizer to use.

Type:

gemma.gm.text._tokenizer.Tokenizer

max_length

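Maximum sequence length. Shorter examples are padded to this length; longer ones are truncated to it when truncate is True.

Type: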

int

truncate

Whether to truncate sequences to max_length. If False, sequences longer than max_length raise an error.

Type:

bool
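The interaction of max_length and truncate can be sketched on a single list of token IDs. This is a hypothetical helper, not the library's code: `pad_or_truncate` and `pad_id=0` are illustrative, and it assumes right-padding and that truncation keeps the first max_length tokens, neither of which this page specifies.

```python
def pad_or_truncate(tokens, max_length, *, truncate=False, pad_id=0):
    """Hypothetical helper mirroring the documented max_length/truncate behavior."""
    if len(tokens) > max_length:
        if not truncate:
            # truncate=False: overlong sequences are an error.
            raise ValueError(
                f"Sequence length {len(tokens)} exceeds max_length={max_length}."
            )
        # truncate=True: keep the first max_length tokens (assumed).
        return list(tokens[:max_length])
    # Right-pad with pad_id up to max_length (assumed padding side and ID).
    return list(tokens) + [pad_id] * (max_length - len(tokens))
```

For example, `pad_or_truncate([1, 2, 3], 5)` gives `[1, 2, 3, 0, 0]`, while `pad_or_truncate([1, 2, 3, 4], 2)` raises `ValueError` unless `truncate=True` is passed.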

drop_inputs

If True, the input keys are removed from the output.

Type:

bool
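On a plain dict, the effect of drop_inputs can be sketched as below. `drop_input_keys` is a hypothetical helper, and the sketch assumes the keys to drop are exactly those named by in_prompt, in_chosen and in_rejected; the placeholder `tokens` value stands in for real task output.

```python
def drop_input_keys(element, input_keys, drop_inputs=True):
    """Hypothetical helper: return a new dict, optionally without `input_keys`."""
    if not drop_inputs:
        return dict(element)  # keep every key, inputs included
    return {k: v for k, v in element.items() if k not in input_keys}


example = {
    "prompt": "How much are 2+2 ?",
    "chosen": "Yes, this is 4.",
    "rejected": "Of course, 2+2 is 42.",
    "tokens": [[1, 2], [1, 3]],  # placeholder output values
}
print(sorted(drop_input_keys(example, {"prompt", "chosen", "rejected"})))  # ['tokens']
```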

in_prompt: Any
in_chosen: Any
in_rejected: Any
out_tokens: Any
out_targets: Any
out_mask: Any
tokenizer: gemma.gm.text._tokenizer.Tokenizer
max_length: int
truncate: bool = False
drop_inputs: bool = True
map(element)

Maps a single element.
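To make the output shapes concrete, here is a self-contained approximation of what map() produces for one element, in plain Python. It is a sketch under stated assumptions, not the library's implementation: `toy_encode` is a hypothetical character-level tokenizer standing in for the Gemma tokenizer (no BOS/EOS or chat-template tokens), sequences are right-padded with ID 0, overlong sequences raise (as with truncate=False), the mask is assumed to be True only on response tokens, and out_targets is not modeled.

```python
def toy_encode(text):
    """Hypothetical character-level tokenizer: one token ID per character."""
    return [ord(c) for c in text]


def contrastive_inputs(prompt, chosen, rejected, max_length, pad_id=0):
    """Sketch: build stacked [chosen, rejected] token IDs and a response mask."""
    prompt_ids = toy_encode(prompt)
    tokens, mask = [], []
    # Row 0 = chosen, row 1 = rejected, matching the documented stacking order.
    for response in (chosen, rejected):
        response_ids = toy_encode(response)
        seq = prompt_ids + response_ids
        if len(seq) > max_length:
            raise ValueError(
                f"Sequence length {len(seq)} exceeds max_length={max_length}."
            )
        # Assumed: mask is True only on response tokens (where the loss applies).
        seq_mask = [False] * len(prompt_ids) + [True] * len(response_ids)
        pad = max_length - len(seq)
        tokens.append(seq + [pad_id] * pad)
        mask.append(seq_mask + [False] * pad)
    return {"tokens": tokens, "mask": mask}  # each is 2 x max_length


out = contrastive_inputs("2+2=", "4", "42", max_length=8)
print(out["mask"][1])  # [False, False, False, False, True, True, False, False]
```

Both rows share the same prompt prefix and differ only in the response, so the mask lets a loss compare the two sequences on the response tokens alone.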