LoRA (采样)#
使用 LoRA 和 Gemma 进行推理的示例。有关使用 LoRA 进行微调的示例,请参阅 LoRA 微调 示例。
!pip install -q gemma
# Common imports
import os
import jax
import jax.numpy as jnp
import treescope
# Gemma imports
from gemma import gm
from gemma import peft # Parameter fine-tuning module
默认情况下,Jax 不会使用完整的 GPU 内存,但这可以被覆盖。请参阅 GPU 内存分配
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
初始化模型#
要将 Gemma 与 LoRA 一起使用,只需将任何 Gemma 模型包装在 gm.nn.LoRA
中即可
model = gm.nn.LoRA(
rank=4,
model=gm.nn.Gemma3_4B(text_only=True),
)
初始化权重
token_ids = jnp.zeros((1, 256,), dtype=jnp.int32) # Create the (batch_size, seq_length)
params = model.init(
jax.random.key(0),
token_ids,
)
params = params['params']
检查参数的形状/结构。我们可以看到已添加了 LoRA 权重。
treescope.show(params)
恢复预训练的参数。我们使用 peft.split_params
和 peft.merge_params
来替换随机初始化的参数为预训练的参数。
当使用 gm.ckpts.load_params
时,请确保传递 params=original
kwarg。这确保了
旧参数的内存被释放(因此内存中仅保留权重的单个副本)
恢复的参数重用与输入相同的分片(这里没有分片,因此不是必需的)
# Splits the params into non-LoRA and LoRA weights
original, lora = peft.split_params(params)
# Load the params from the checkpoint
original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT, params=original)
# Merge the pretrained params back with LoRA
params = peft.merge_params(original, lora)
微调#
请参阅我们的 微调指南 以获取更多信息。
有关端到端微调示例,请参阅我们的 lora.py 配置。
推理#
这是一个运行单个模型调用的示例
tokenizer = gm.text.Gemma3Tokenizer()
prompt = tokenizer.encode('The capital of France is')
prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)
# Run the model
out = model.apply(
{'params': params},
tokens=prompt,
return_last_only=True, # Only predict the last token
)
# Show the token distribution
tokenizer.plot_logits(out.logits)
采样整个句子的示例
sampler = gm.text.ChatSampler(
model=model,
params=params,
tokenizer=tokenizer,
)
sampler.chat('The capital of France is?')