# LoRA (Sampling)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/lora_sampling.ipynb)

An example of using LoRA with Gemma for inference. For fine-tuning with LoRA, see the [LoRA finetuning](https://gemma-llm.readthedocs.io/en/latest/lora_finetuning.html) example.

In [None]:
!pip install -q gemma

In [None]:
# Common imports
import os
import jax
import jax.numpy as jnp
import treescope

# Gemma imports
from gemma import gm
from gemma import peft  # Parameter-efficient fine-tuning module

By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [None]:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

## Initializing the model

To use Gemma with LoRA, simply wrap any Gemma model in `gm.nn.LoRA`:

In [None]:
model = gm.nn.LoRA(
    rank=4,
    model=gm.nn.Gemma3_4B(text_only=True),
)
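
To build intuition for what the `rank` argument controls, here is a minimal sketch of a LoRA-style dense layer in plain JAX. It is not the `gm.nn.LoRA` implementation: the function `lora_dense`, the `scale` parameter, the toy dimensions, and the zero initialization of `b` (the convention from the original LoRA formulation) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def lora_dense(x, w, a, b, scale=1.0):
  """Frozen dense layer plus a low-rank update: x @ w + scale * (x @ a) @ b."""
  return x @ w + scale * (x @ a) @ b


d_in, d_out, rank = 64, 32, 4  # Toy sizes, chosen only for this sketch
k_w, k_a, k_x = jax.random.split(jax.random.key(0), 3)

w = jax.random.normal(k_w, (d_in, d_out))        # Stand-in for a frozen pre-trained weight
a = jax.random.normal(k_a, (d_in, rank)) * 0.01  # Trainable, randomly initialized
b = jnp.zeros((rank, d_out))                     # Trainable, zero-initialized (assumption)
x = jax.random.normal(k_x, (2, d_in))

y_base = x @ w
y_lora = lora_dense(x, w, a, b)

full_params = w.size           # 64 * 32 parameters in the frozen weight
lora_params = a.size + b.size  # 64 * 4 + 4 * 32 parameters in the adapter
```

With `b` set to zero, the adapted layer starts out identical to the base layer, and the adapter adds far fewer parameters than the weight it modifies. A larger `rank` gives the update more capacity at the cost of more trainable parameters.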

Initialize the weights:

In [None]:
token_ids = jnp.zeros((1, 256), dtype=jnp.int32)  # Dummy input of shape (batch_size, seq_length)

params = model.init(
    jax.random.key(0),
    token_ids,
)

params = params['params']

Inspect the params shape/structure. We can see LoRA weights have been added.

In [None]:
treescope.show(params)

Restore the pre-trained params. We use `peft.split_params` and `peft.merge_params` to replace the randomly initialized params with the pre-trained ones.

When using `gm.ckpts.load_params`, make sure to pass the `params=original` kwarg. This ensures that:

* The memory from the old params is released (so only a single copy of the weights stays in memory)
* The restored params reuse the same sharding as the input (here there is no sharding, so this part is not required)

In [None]:
# Splits the params into non-LoRA and LoRA weights
original, lora = peft.split_params(params)

# Load the params from the checkpoint
original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT, params=original)

# Merge the pretrained params back with LoRA
params = peft.merge_params(original, lora)
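
The split, restore, merge pattern above can be illustrated on a toy nested dict. This is a sketch under stated assumptions, not the `peft` implementation: the helpers `split_lora` and `merge_trees` are hypothetical, and the assumption that LoRA weights live under a `'lora'` key is made only for this example.

```python
def split_lora(tree, lora_key='lora'):
  """Splits a nested dict into (non-LoRA, LoRA) trees based on a key name."""
  original, lora = {}, {}
  for name, value in tree.items():
    if name == lora_key:
      lora[name] = value
    elif isinstance(value, dict):
      sub_original, sub_lora = split_lora(value, lora_key)
      if sub_original:
        original[name] = sub_original
      if sub_lora:
        lora[name] = sub_lora
    else:
      original[name] = value
  return original, lora


def merge_trees(left, right):
  """Recursively merges two nested dicts whose leaves do not overlap."""
  merged = dict(left)
  for name, value in right.items():
    if isinstance(value, dict) and isinstance(merged.get(name), dict):
      merged[name] = merge_trees(merged[name], value)
    else:
      merged[name] = value
  return merged


# Toy parameter tree; strings stand in for weight arrays.
toy_params = {
    'layer_0': {
        'attn': {'w': 'random_w', 'lora': {'a': 'lora_a', 'b': 'lora_b'}},
        'mlp': {'w': 'random_mlp'},
    },
}

toy_original, toy_lora = split_lora(toy_params)
toy_original['layer_0']['attn']['w'] = 'restored_w'  # Stand-in for loading a checkpoint
toy_merged = merge_trees(toy_original, toy_lora)
```

Only the non-LoRA half is overwritten by the checkpoint, so the freshly initialized LoRA weights survive the restore and end up back in the merged tree.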

## Fine-tuning

See our [finetuning guide](https://gemma-llm.readthedocs.io/en/latest/lora_finetuning.html) for more info.

For an end-to-end finetuning example, see our [lora.py](https://github.com/google-deepmind/gemma/tree/main/examples/lora.py) config.

## Inference

Here's an example of running a single model call:

In [None]:
tokenizer = gm.text.Gemma3Tokenizer()

prompt = tokenizer.encode('The capital of France is')
prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)


# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)


# Show the token distribution
tokenizer.plot_logits(out.logits)
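
To read the most likely next tokens numerically rather than from the plot, the logits can be turned into probabilities and ranked. The sketch below uses a toy 1D logits array over a 5-token vocabulary; that the model's last-token logits can be handled the same way is an assumption.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 0.5, -1.0, 3.0, 0.0])  # Toy logits, one entry per token id

probs = jax.nn.softmax(logits)                  # Normalize logits into probabilities
top_probs, top_ids = jax.lax.top_k(probs, k=2)  # Two most likely token ids, most likely first
```

Each id in `top_ids` could then be mapped back to text with the tokenizer.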

To sample an entire sentence:

In [None]:
sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

sampler.chat('What is the capital of France?')
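
Conceptually, sampling a sentence repeats the single model call shown earlier: predict the next token, append it, and stop at an end-of-sequence token. The sketch below shows greedy decoding with a hypothetical stand-in for the model (`toy_next_token_logits`) and made-up values for `VOCAB_SIZE` and `EOS_ID`. It is not how `gm.text.ChatSampler` is implemented, which also handles prompt formatting and may use a different sampling strategy.

```python
import jax.numpy as jnp

VOCAB_SIZE = 6  # Hypothetical toy vocabulary size
EOS_ID = 5      # Hypothetical end-of-sequence token id


def toy_next_token_logits(tokens):
  """Hypothetical stand-in for a model call: favors (last_token + 1)."""
  next_id = (tokens[-1] + 1) % VOCAB_SIZE
  return jnp.zeros(VOCAB_SIZE).at[next_id].set(1.0)


def greedy_decode(prompt, max_new_tokens=10):
  """Appends the most likely token until EOS or the token budget is reached."""
  tokens = list(prompt)
  for _ in range(max_new_tokens):
    logits = toy_next_token_logits(tokens)
    next_token = int(jnp.argmax(logits))
    tokens.append(next_token)
    if next_token == EOS_ID:
      break
  return tokens


decoded = greedy_decode([1, 2])
```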