gm.losses.DpoLoss#
- class gemma.gm.losses.DpoLoss(*, step: kontext.Key = 'step', mask: Optional[kontext.Key] = None, weight: int | float | Schedule = 1.0, normalize_by: Literal['mask', 'values'] = 'mask', tau: float | typing.Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.1, label_smoothing: float | typing.Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.0, tokens: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__', sequence_mask: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__', policy_logits: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__', anchor_logits: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__')[source]
基类:
kauldron.losses.base.Loss
DPO 损失。
- tau
损失的温度。
- 类型:
float | Callable[[int], float | jaxtyping.Float[Array, ‘’] | jaxtyping.Float[ndarray, ‘’]]
- label_smoothing
应用于损失的标签平滑。
- 类型:
float | Callable[[int], float | jaxtyping.Float[Array, ‘’] | jaxtyping.Float[ndarray, ‘’]]
- tokens
要预测的 tokens 的键。
- 类型:
Any
- sequence_mask
序列掩码的键。
- 类型:
Any
- policy_logits
策略 logits 的键。
- 类型:
Any
- anchor_logits
锚点 logits 的键。
- 类型:
Any
- tau: float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.1
- label_smoothing: float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.0
- tokens: Annotated[Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__'
- sequence_mask: Annotated[Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__'
- policy_logits: Annotated[Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__'
- anchor_logits: Annotated[Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__'
- get_values(
- *,
- tokens: jaxtyping.Int[Array, '*B N L'] | jaxtyping.Int[ndarray, '*B N L'],
- sequence_mask: jaxtyping.Bool[Array, '*B N L'] | jaxtyping.Bool[ndarray, '*B N L'],
- policy_logits: jaxtyping.Float[Array, '*B N L V'] | jaxtyping.Float[ndarray, '*B N L V'],
- anchor_logits: jaxtyping.Float[Array, '*B N L V'] | jaxtyping.Float[ndarray, '*B N L V'],
计算 DPO 损失。
- empty() kauldron.metrics.base.Metric.State [source]
- get_state(
- *args,
- mask: jaxtyping.Shaped[Array, '...'] | jaxtyping.Shaped[ndarray, '...'] | None = None,
- step: int | None = None,
- **kwargs,
计算损失状态,并处理掩码和损失权重。
默认情况下,Loss.State 是 AllReduceMean,它跟踪单个标量损失值,但即使在使用掩码时也能确保正确平均。
- 参数:
*args – 要传递给 get_values 的位置参数。
mask – 一个可选的掩码,用于从总数中排除一些损失值。此掩码的形状需要可广播到从 get_values 返回的值的形状。值为 1 表示应包含值(值为 0 表示排除)。
step – 当前步数,用于在 self.weight 设置为 schedule 时计算损失权重。否则,step 将被忽略。
**kwargs – 要传递给 get_values 的关键字参数。
- 返回:
Loss.State 的实例(默认为 AllReduceMean),它跟踪单个标量损失值,但即使在使用掩码时也能确保正确平均。此最终损失值可以通过调用 state.compute() 从此状态计算得出。可以选择先减少状态(以在 pmap 后移除设备维度)或与其他(先前的)损失状态合并。