gm.ckpts.AnchoredPolicyLoader

gm.ckpts.AnchoredPolicyLoader#

class gemma.gm.ckpts.AnchoredPolicyLoader(
*,
policy: kauldron.checkpoints.partial_loader.AbstractPartialLoader,
anchor: kauldron.checkpoints.partial_loader.AbstractPartialLoader | None = None,
)[source]

Bases: kauldron.checkpoints.partial_loader.AbstractPartialLoader

用于 gm.nn.AnchoredPolicy 模型的加载器。

通过提供子转换,分别加载策略和锚点。

这假设子加载器仅覆盖 state.params,而不修改状态的其余部分。

policy: kauldron.checkpoints.partial_loader.AbstractPartialLoader
anchor: kauldron.checkpoints.partial_loader.AbstractPartialLoader | None = None
transform(
state: kauldron.train.train_step.TrainState,
) kauldron.train.train_step.TrainState[source]

通过使用预训练的值更新状态来转换状态。

注意

  • transform 函数可以修改 state 值,但不应修改其结构、形状或数据类型。

  • transform 应该正确地传播给定状态的分片信息。

参数:

state – 要转换的 state 对象

返回:

更新后的 state