gm.data.make_seq2seq_fields

gm.data.make_seq2seq_fields#

gemma.gm.data.make_seq2seq_fields(
prompt: jaxtyping.Int[Array, 'prompt_len'] | jaxtyping.Int[ndarray, 'prompt_len'],
response: jaxtyping.Int[Array, 'response_len'] | jaxtyping.Int[ndarray, 'response_len'],
) gemma.gm.data._functional.Seq2SeqFields[source]

创建模型 inputtargetloss_mask

从 prompt 和 response 令牌 ID 生成模型 inputtargetloss_mask

示例

# Input:
prompt = [10, 11, 12, 13],
response = [20, 21, 1],  # Here, response ends with EOS token.

# Ouptut:
out.input =       [10, 11, 12, 13, 20, 21],
out.target =      [11, 12, 13, 20, 21,  1],
out.target_mask = [ 0,  0,  0,  1,  1,  1],

注意

  • Input 和 target 是相同的序列,但 target 序列向后移动一个令牌。

  • target 序列的最后一个令牌会从 input 中截断(因为没有对应的 target)。

参数:
  • prompt – Prompt 令牌。

  • response – Response 令牌。

返回值:

input、target 和 mask,所有长度均为 prompt_len + response_len - 1