LoRA (微调)#
这是一个关于使用 LoRA 微调 Gemma 的示例。最好先阅读微调 colab,以理解这一个。
如果您只想使用 LoRA 进行推理,请参阅 LoRA 采样 教程。
!pip install -q gemma
# Common imports
import os
import optax
import treescope
# Gemma imports
from kauldron import kd
from gemma import gm
默认情况下,Jax 不会充分利用完整的 GPU 内存,但这可以被覆盖。请参阅 GPU 内存分配
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
配置更新#
如果您熟悉微调教程,切换到 LoRA 只需要对训练器进行 3 处更改。
有关端到端示例,请参阅 lora.py 配置。
1. 模型#
将模型包装在 gm.nn.LoRA
中。这将应用模型手术来替换所有线性层和兼容层为 LoRA 层。
model = gm.nn.LoRA(
rank=4,
model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True),
)
在内部,这使用 gemma.peft
微型库来执行模型手术。
2. 检查点#
将 init 转换包装在 gm.ckpts.SkipLoRA
中。之所以需要包装器,是因为使用和不使用 LoRA 的参数结构是不同的。
仅加载初始预训练权重,但 LoRA 权重保持其随机初始化。
init_transform = gm.ckpts.SkipLoRA(
wrapped=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
)
注意:如果您直接使用 gm.ckpts.load_params
加载权重,则可以使用 peft.split_params
和 peft.merge_params
代替。有关示例,请参阅 LoRA 采样。
3. 优化器#
最后,我们向优化器添加一个掩码(使用 kd.optim.partial_updates
),以便仅训练 LoRA 权重。
optimizer = kd.optim.partial_updates(
optax.adafactor(learning_rate=0.005),
# We only optimize the LoRA weights. The rest of the model is frozen.
mask=kd.optim.select("lora"),
)
训练#
数据管道#
与 微调 示例一样,我们重新创建分词器
tokenizer = gm.text.Gemma3Tokenizer()
tokenizer.encode('This is an example sentence', add_bos=True)
[<_Gemma2SpecialTokens.BOS: 2>, 1596, 603, 671, 3287, 13060]
以及数据管道
ds = kd.data.py.Tfds(
name='mtnt/en-fr',
split='train',
shuffle=True,
batch_size=8,
transforms=[
# Create the model inputs/targets/loss_mask.
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
# https://tensorflowcn.cn/datasets/catalog/mtnt
in_prompt='src',
in_response='dst',
# Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
out_input='input',
out_target='target',
out_target_mask='loss_mask',
tokenizer=tokenizer,
# Padding parameters
max_length=200,
truncate=True,
),
],
)
ex = ds[0]
treescope.show(ex)
Disabling pygrain multi-processing (unsupported in colab).
我们可以解码批次中的一个示例,以检查模型输入并检查其格式是否正确
text = tokenizer.decode(ex['input'][0])
print(text)
<start_of_turn>user
As far as battle mode, 64 is the best.<end_of_turn>
<start_of_turn>model
En ce qui concerne le mode bataille, 64 est le meilleur.
训练器#
然后,我们创建训练器,重用上面创建的 model
、init_transform
和 optimizer
trainer = kd.train.Trainer(
seed=42, # The seed of enlightenment
workdir='/tmp/ckpts', # TODO(epot): Make the workdir optional by default
# Dataset
train_ds=ds,
# Model
model=model,
init_transform=init_transform,
# Training parameters
num_train_steps=500,
train_losses={
"loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.target",
mask="batch.loss_mask",
),
},
optimizer=optimizer,
)
可以使用 .train()
方法启动训练。
请注意,训练器就像模型一样是不可变的,因此它不存储状态或参数。相反,返回包含训练参数的状态。
state, aux = trainer.train()
Configuring ...
Initializing ...
Starting training loop at step 0
检查点保存#
# TODO(epot): Doc on:
# * saving only LoRA params
# * Fuse params
评估#
在这里,我们仅通过采样提示执行定性评估。
有关评估的更多信息
sampler = gm.text.ChatSampler(
model=model,
params=state.params,
tokenizer=tokenizer,
)
我们测试一个句子,使用与微调期间相同的格式
sampler.chat('I\'m feeling happy today!')
"Je me sens heureux aujourd'hui !"
模型正确地将我们的提示翻译成了法语!