微调#

Open in Colab

这是一个关于微调 Gemma 的示例。有关如何运行预训练 Gemma 模型的示例,请参阅采样教程。

为了微调 Gemma,我们使用 kauldron 库,它抽象了大部分样板代码(检查点管理、训练循环、评估、指标报告、分片等)。

!pip install -q gemma
# Common imports
import os
import optax
import treescope

# Gemma imports
from kauldron import kd
from gemma import gm

默认情况下,Jax 不会充分利用 GPU 内存,但这可以被覆盖。请参阅 GPU 内存分配

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

数据管道#

首先创建分词器,因为数据管道中需要它。

tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)
[<_Gemma2SpecialTokens.BOS: 2>, 1596, 603, 671, 3287, 13060]

首先我们需要一个数据管道。支持多种管道,包括

添加您自己的数据或从多个来源创建混合数据非常简单。请参阅管道文档

我们使用 transforms 来自定义数据管道,这包括

请注意,在实践中,您可以将多个 transforms 组合成一个更高级别的 transform。有关示例,请参阅 gm.data.ContrastiveTask() transform 在 DPO 示例 中的应用。

在这里,我们尝试 mtnt,这是一个小型翻译数据集。数据集结构为 {'src': ..., 'dst': ...}

ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://tensorflowcn.cn/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]

treescope.show(ex)
Disabling pygrain multi-processing (unsupported in colab).
{
    'input': i64[8 200],
    'loss_mask': bool_[8 200 1],
    'target': i64[8 200 1],
}

我们可以从批次中解码一个示例来检查模型输入。我们看到 <start_of_turn> / <end_of_turn> 被正确添加以遵循 Gemma 对话格式。

text = tokenizer.decode(ex['input'][0])

print(text)
<start_of_turn>user
Would love any other tips from anyone, but specially from someone who’s been where I’m at.<end_of_turn>
<start_of_turn>model
J'apprécierais vraiment d'autres astuces, mais particulièrement par quelqu'un qui était était déjà là où je me trouve.

训练器#

kauldron 训练器允许通过简单地提供数据集、模型、损失和优化器来训练 Gemma。

数据集、模型和损失通过键字符串系统连接在一起。有关更多信息,请参阅键文档

每个键都以注册的前缀开头。常见的前缀包括

  • batch: 数据集的输出(在所有转换之后)。这里我们的批次是 {'input': ..., 'loss_mask': ..., 'target': ...}

  • preds: 模型的输出。对于 Gemma 模型,这是 gm.nn.Output(logits=..., cache=...)

  • params: 模型参数(可用于添加权重衰减损失,或在指标中监控参数范数)

model = gm.nn.Gemma3_4B(
    tokens="batch.input",
)
loss = kd.losses.SoftmaxCrossEntropyWithIntLabels(
    logits="preds.logits",
    labels="batch.target",
    mask="batch.loss_mask",
)

然后我们创建训练器

trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=gm.ckpts.LoadCheckpoint(  # Load the weights from the pretrained checkpoint
        path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
    ),
    # Training parameters
    num_train_steps=300,
    train_losses={"loss": loss},
    optimizer=optax.adafactor(learning_rate=1e-3),
)

可以使用 .train() 方法启动训练。

请注意,训练器和模型都是不可变的,因此它不存储状态或参数。而是返回包含训练参数的状态。

state, aux = trainer.train()
Configuring ...
Initializing ...
Starting training loop at step 0

检查点保存#

要保存模型参数,您可以选择

  • 通过添加以下内容在训练器中激活检查点保存

    trainer = kd.train.Trainer(
        workdir='/tmp/my_experiment/',
        checkpointer=kd.ckpts.Checkpointer(
            save_interval_steps=500,
        ),
        ...
    )
    

    这也将保存优化器、步数、数据集状态等。

  • 手动保存训练的参数

    gm.ckpts.save_params(state.params, '/tmp/my_ckpt/')
    

评估#

在这里,我们仅通过采样提示执行定性评估。

有关评估的更多信息

  • 有关运行推理的更多信息,请参阅采样教程。

  • 要在训练期间添加评估,请参阅 Kauldron 评估器文档。

sampler = gm.text.ChatSampler(
    model=model,
    params=state.params,
    tokenizer=tokenizer,
)

我们测试一个句子,使用与微调期间相同的格式

sampler.chat('Hello! My next holidays are in Paris.')
'Salut ! Mes vacances suivantes seront à Paris.'

模型正确地将我们的提示翻译成了法语!

下一步#

要在 Colab 之外进行微调,您可以查看我们的 examples/ 文件夹,了解更复杂的训练器配置,包括 LoRA、DPO 和分片。