gm.losses.DpoLoss

gm.losses.DpoLoss#

class gemma.gm.losses.DpoLoss(*, step: kontext.Key = 'step', mask: Optional[kontext.Key] = None, weight: int | float | Schedule = 1.0, normalize_by: Literal['mask', 'values'] = 'mask', tau: float | typing.Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.1, label_smoothing: float | typing.Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.0, tokens: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__', sequence_mask: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__', policy_logits: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__', anchor_logits: typing.Annotated[typing.Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__')[source]

基类:kauldron.losses.base.Loss

DPO 损失。

tau

损失的温度。

类型:

float | Callable[[int], float | jaxtyping.Float[Array, ‘’] | jaxtyping.Float[ndarray, ‘’]]

label_smoothing

应用于损失的标签平滑。

类型:

float | Callable[[int], float | jaxtyping.Float[Array, ‘’] | jaxtyping.Float[ndarray, ‘’]]

tokens

要预测的 tokens 的键。

类型:

Any

sequence_mask

序列掩码的键。

类型:

Any

policy_logits

策略 logits 的键。

类型:

Any

anchor_logits

锚点 logits 的键。

类型:

Any

tau: float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.1
label_smoothing: float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.0
tokens: Annotated[Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__'
sequence_mask: Annotated[Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__'
policy_logits: Annotated[Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__'
anchor_logits: Annotated[Any, <object object at 0x75a909cb7ae0>] = '__KEY_REQUIRED__'
get_values(
*,
tokens: jaxtyping.Int[Array, '*B N L'] | jaxtyping.Int[ndarray, '*B N L'],
sequence_mask: jaxtyping.Bool[Array, '*B N L'] | jaxtyping.Bool[ndarray, '*B N L'],
policy_logits: jaxtyping.Float[Array, '*B N L V'] | jaxtyping.Float[ndarray, '*B N L V'],
anchor_logits: jaxtyping.Float[Array, '*B N L V'] | jaxtyping.Float[ndarray, '*B N L V'],
) jaxtyping.Float[Array, '*B 1'] | jaxtyping.Float[ndarray, '*B 1'][source]

计算 DPO 损失。

empty() kauldron.metrics.base.Metric.State[source]
get_state(
*args,
mask: jaxtyping.Shaped[Array, '...'] | jaxtyping.Shaped[ndarray, '...'] | None = None,
step: int | None = None,
**kwargs,
) kauldron.losses.base.AllReduceMean[source]

计算损失状态,并处理掩码和损失权重。

默认情况下,Loss.State 是 AllReduceMean,它跟踪单个标量损失值,但即使在使用掩码时也能确保正确平均。

参数:
  • *args – 要传递给 get_values 的位置参数。

  • mask – 一个可选的掩码,用于从总数中排除一些损失值。此掩码的形状需要可广播到从 get_values 返回的值的形状。值为 1 表示应包含值(值为 0 表示排除)。

  • step – 当前步数,用于在 self.weight 设置为 schedule 时计算损失权重。否则,step 将被忽略。

  • **kwargs – 要传递给 get_values 的关键字参数。

返回:

Loss.State 的实例(默认为 AllReduceMean),它跟踪单个标量损失值,但即使在使用掩码时也能确保正确平均。此最终损失值可以通过调用 state.compute() 从此状态计算得出。可以选择先减少状态(以在 pmap 后移除设备维度)或与其他(先前的)损失状态合并。