微调#
这是一个关于微调 Gemma 的示例。有关如何运行预训练 Gemma 模型的示例,请参阅采样教程。
为了微调 Gemma,我们使用 kauldron 库,它抽象了大部分样板代码(检查点管理、训练循环、评估、指标报告、分片等)。
!pip install -q gemma
# Common imports
import os
import optax
import treescope
# Gemma imports
from kauldron import kd
from gemma import gm
默认情况下,Jax 不会充分利用 GPU 内存,但这可以被覆盖。请参阅 GPU 内存分配
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
数据管道#
首先创建分词器,因为数据管道中需要它。
tokenizer = gm.text.Gemma3Tokenizer()
tokenizer.encode('This is an example sentence', add_bos=True)
[<_Gemma2SpecialTokens.BOS: 2>, 1596, 603, 671, 3287, 13060]
首先我们需要一个数据管道。支持多种管道,包括
添加您自己的数据或从多个来源创建混合数据非常简单。请参阅管道文档。
我们使用 transforms
来自定义数据管道,这包括
对输入进行分词 (使用
gm.data.Tokenize
)创建模型输入 (使用
gm.data.Tokenize
))添加填充 (使用
gm.data.Pad
) (需要对不同长度的输入进行批处理)
请注意,在实践中,您可以将多个 transforms 组合成一个更高级别的 transform。有关示例,请参阅 gm.data.ContrastiveTask()
transform 在 DPO 示例 中的应用。
在这里,我们尝试 mtnt,这是一个小型翻译数据集。数据集结构为 {'src': ..., 'dst': ...}
。
ds = kd.data.py.Tfds(
name='mtnt/en-fr',
split='train',
shuffle=True,
batch_size=8,
transforms=[
# Create the model inputs/targets/loss_mask.
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
# https://tensorflowcn.cn/datasets/catalog/mtnt
in_prompt='src',
in_response='dst',
# Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
out_input='input',
out_target='target',
out_target_mask='loss_mask',
tokenizer=tokenizer,
# Padding parameters
max_length=200,
truncate=True,
),
],
)
ex = ds[0]
treescope.show(ex)
Disabling pygrain multi-processing (unsupported in colab).
{
'input': i64[8 200],
'loss_mask': bool_[8 200 1],
'target': i64[8 200 1],
}
我们可以从批次中解码一个示例来检查模型输入。我们看到 <start_of_turn>
/ <end_of_turn>
被正确添加以遵循 Gemma 对话格式。
text = tokenizer.decode(ex['input'][0])
print(text)
<start_of_turn>user
Would love any other tips from anyone, but specially from someone who’s been where I’m at.<end_of_turn>
<start_of_turn>model
J'apprécierais vraiment d'autres astuces, mais particulièrement par quelqu'un qui était était déjà là où je me trouve.
训练器#
kauldron 训练器允许通过简单地提供数据集、模型、损失和优化器来训练 Gemma。
数据集、模型和损失通过键字符串系统连接在一起。有关更多信息,请参阅键文档。
每个键都以注册的前缀开头。常见的前缀包括
batch
: 数据集的输出(在所有转换之后)。这里我们的批次是{'input': ..., 'loss_mask': ..., 'target': ...}
preds
: 模型的输出。对于 Gemma 模型,这是gm.nn.Output(logits=..., cache=...)
params
: 模型参数(可用于添加权重衰减损失,或在指标中监控参数范数)
model = gm.nn.Gemma3_4B(
tokens="batch.input",
)
loss = kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.target",
mask="batch.loss_mask",
)
然后我们创建训练器
trainer = kd.train.Trainer(
seed=42, # The seed of enlightenment
workdir='/tmp/ckpts', # TODO(epot): Make the workdir optional by default
# Dataset
train_ds=ds,
# Model
model=model,
init_transform=gm.ckpts.LoadCheckpoint( # Load the weights from the pretrained checkpoint
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
# Training parameters
num_train_steps=300,
train_losses={"loss": loss},
optimizer=optax.adafactor(learning_rate=1e-3),
)
可以使用 .train()
方法启动训练。
请注意,训练器和模型都是不可变的,因此它不存储状态或参数。而是返回包含训练参数的状态。
state, aux = trainer.train()
Configuring ...
Initializing ...
Starting training loop at step 0
检查点保存#
要保存模型参数,您可以选择
通过添加以下内容在训练器中激活检查点保存
trainer = kd.train.Trainer( workdir='/tmp/my_experiment/', checkpointer=kd.ckpts.Checkpointer( save_interval_steps=500, ), ... )
这也将保存优化器、步数、数据集状态等。
手动保存训练的参数
gm.ckpts.save_params(state.params, '/tmp/my_ckpt/')
评估#
在这里,我们仅通过采样提示执行定性评估。
有关评估的更多信息
sampler = gm.text.ChatSampler(
model=model,
params=state.params,
tokenizer=tokenizer,
)
我们测试一个句子,使用与微调期间相同的格式
sampler.chat('Hello! My next holidays are in Paris.')
'Salut ! Mes vacances suivantes seront à Paris.'
模型正确地将我们的提示翻译成了法语!
下一步#
要在 Colab 之外进行微调,您可以查看我们的 examples/ 文件夹,了解更复杂的训练器配置,包括 LoRA、DPO 和分片。