LoRA (微调)#

Open in Colab

这是一个关于使用 LoRA 微调 Gemma 的示例。最好先阅读微调 colab,以理解这一个。

如果您只想使用 LoRA 进行推理,请参阅 LoRA 采样 教程。

!pip install -q gemma
# Common imports
import os
import optax
import treescope

# Gemma imports
from kauldron import kd
from gemma import gm

默认情况下,Jax 不会充分利用完整的 GPU 内存,但这可以被覆盖。请参阅 GPU 内存分配

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

配置更新#

如果您熟悉微调教程,切换到 LoRA 只需要对训练器进行 3 处更改。

有关端到端示例,请参阅 lora.py 配置。

1. 模型#

将模型包装在 gm.nn.LoRA 中。这将应用模型手术来替换所有线性层和兼容层为 LoRA 层。

model = gm.nn.LoRA(
    rank=4,
    model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True),
)

在内部,这使用 gemma.peft 微型库来执行模型手术。

2. 检查点#

将 init 转换包装在 gm.ckpts.SkipLoRA 中。之所以需要包装器,是因为使用和不使用 LoRA 的参数结构是不同的。

仅加载初始预训练权重,但 LoRA 权重保持其随机初始化。

init_transform = gm.ckpts.SkipLoRA(
    wrapped=gm.ckpts.LoadCheckpoint(
        path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
    ),
)

注意:如果您直接使用 gm.ckpts.load_params 加载权重,则可以使用 peft.split_paramspeft.merge_params 代替。有关示例,请参阅 LoRA 采样

3. 优化器#

最后,我们向优化器添加一个掩码(使用 kd.optim.partial_updates),以便仅训练 LoRA 权重。

optimizer = kd.optim.partial_updates(
    optax.adafactor(learning_rate=0.005),
    # We only optimize the LoRA weights. The rest of the model is frozen.
    mask=kd.optim.select("lora"),
)

训练#

数据管道#

微调 示例一样,我们重新创建分词器

tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)
[<_Gemma2SpecialTokens.BOS: 2>, 1596, 603, 671, 3287, 13060]

以及数据管道

ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://tensorflowcn.cn/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]

treescope.show(ex)
Disabling pygrain multi-processing (unsupported in colab).

我们可以解码批次中的一个示例,以检查模型输入并检查其格式是否正确

text = tokenizer.decode(ex['input'][0])

print(text)
<start_of_turn>user
As far as battle mode, 64 is the best.<end_of_turn>
<start_of_turn>model
En ce qui concerne le mode bataille, 64 est le meilleur.

训练器#

然后,我们创建训练器,重用上面创建的 modelinit_transformoptimizer

trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=init_transform,
    # Training parameters
    num_train_steps=500,
    train_losses={
        "loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=optimizer,
)

可以使用 .train() 方法启动训练。

请注意,训练器就像模型一样是不可变的,因此它不存储状态或参数。相反,返回包含训练参数的状态。

state, aux = trainer.train()
Configuring ...
Initializing ...
Starting training loop at step 0

检查点保存#

# TODO(epot): Doc on:
# * saving only LoRA params
# * Fuse params

评估#

在这里,我们仅通过采样提示执行定性评估。

有关评估的更多信息

  • 有关运行推理的更多信息,请参阅 采样 教程。

  • 要在训练期间添加评估,请参阅 Kauldron 评估器 文档。

sampler = gm.text.ChatSampler(
    model=model,
    params=state.params,
    tokenizer=tokenizer,
)

我们测试一个句子,使用与微调期间相同的格式

sampler.chat('I\'m feeling happy today!')
"Je me sens heureux aujourd'hui !"

模型正确地将我们的提示翻译成了法语!