gm.data.make_seq2seq_fields#
- gemma.gm.data.make_seq2seq_fields(
- prompt: jaxtyping.Int[Array, 'prompt_len'] | jaxtyping.Int[ndarray, 'prompt_len'],
- response: jaxtyping.Int[Array, 'response_len'] | jaxtyping.Int[ndarray, 'response_len'],
创建模型 input、target 和 loss_mask。
从 prompt 和 response 令牌 ID 生成模型 input、target 和 loss_mask。
示例
# Input: prompt = [10, 11, 12, 13], response = [20, 21, 1], # Here, response ends with EOS token. # Ouptut: out.input = [10, 11, 12, 13, 20, 21], out.target = [11, 12, 13, 20, 21, 1], out.target_mask = [ 0, 0, 0, 1, 1, 1],
注意
Input 和 target 是相同的序列,但 target 序列向后移动一个令牌。
target 序列的最后一个令牌会从 input 中截断(因为没有对应的 target)。
- 参数:
prompt – Prompt 令牌。
response – Response 令牌。
- 返回值:
input、target 和 mask,所有长度均为 prompt_len + response_len - 1。