LoRA (采样)

LoRA (采样)#

Open in Colab

使用 LoRA 和 Gemma 进行推理的示例。有关使用 LoRA 进行微调的示例,请参阅 LoRA 微调 示例。

!pip install -q gemma
# Common imports
import os
import jax
import jax.numpy as jnp
import treescope

# Gemma imports
from gemma import gm
from gemma import peft  # Parameter fine-tuning module

默认情况下,Jax 不会使用完整的 GPU 内存,但这可以被覆盖。请参阅 GPU 内存分配

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

初始化模型#

要将 Gemma 与 LoRA 一起使用,只需将任何 Gemma 模型包装在 gm.nn.LoRA 中即可

model = gm.nn.LoRA(
    rank=4,
    model=gm.nn.Gemma3_4B(text_only=True),
)

初始化权重

token_ids = jnp.zeros((1, 256,), dtype=jnp.int32)  # Create the (batch_size, seq_length)

params = model.init(
    jax.random.key(0),
    token_ids,
)

params = params['params']

检查参数的形状/结构。我们可以看到已添加了 LoRA 权重。

treescope.show(params)

恢复预训练的参数。我们使用 peft.split_paramspeft.merge_params 来替换随机初始化的参数为预训练的参数。

当使用 gm.ckpts.load_params 时,请确保传递 params=original kwarg。这确保了

  • 旧参数的内存被释放(因此内存中仅保留权重的单个副本)

  • 恢复的参数重用与输入相同的分片(这里没有分片,因此不是必需的)

# Splits the params into non-LoRA and LoRA weights
original, lora = peft.split_params(params)

# Load the params from the checkpoint
original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT, params=original)

# Merge the pretrained params back with LoRA
params = peft.merge_params(original, lora)

微调#

请参阅我们的 微调指南 以获取更多信息。

有关端到端微调示例,请参阅我们的 lora.py 配置。

推理#

这是一个运行单个模型调用的示例

tokenizer = gm.text.Gemma3Tokenizer()

prompt = tokenizer.encode('The capital of France is')
prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)


# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)


# Show the token distribution
tokenizer.plot_logits(out.logits)

采样整个句子的示例

sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

sampler.chat('The capital of France is?')