# Sharding

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/sharding.ipynb)

This example shows how to shard Gemma models. It runs inference on Gemma 3 27B on a TPU v2 with 8 devices.

In [None]:
!pip install -q gemma

In [1]:
# Common imports
import jax
import jax.numpy as jnp

# Gemma imports
from gemma import gm
from kauldron import kd

For this colab, connect to the TPU kernel by selecting `Change runtime type` > `v2-8 TPU` so that multiple accelerators are available. JAX should then report multiple devices.

In [2]:
jax.device_count()

8

Load the model and its params. Here we load the 27B model.

In [3]:
model = gm.nn.Gemma3_27B()

When restoring the weights, you can pass a `sharding=` argument to `gm.ckpts.load_params`. Here we use the naive `kd.sharding.FSDPSharding` heuristic.

In [4]:
params = gm.ckpts.load_params(
    gm.ckpts.CheckpointPath.GEMMA3_27B_IT,
    sharding=kd.sharding.FSDPSharding(),
)
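To make the idea behind FSDP-style sharding concrete, here is a minimal sketch in plain JAX. It is not the implementation of `kd.sharding.FSDPSharding`: the helper `fsdp_like_sharding` and its rule (shard the largest axis that is divisible by the device count, otherwise replicate) are assumptions made for illustration, and the toy parameter tree stands in for real model weights.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over every available device (8 on a v2-8 TPU, 1 on CPU).
mesh = Mesh(np.asarray(jax.devices()), axis_names=('fsdp',))


def fsdp_like_sharding(x):
  """Hypothetical rule: shard the largest axis divisible by the device count."""
  n = mesh.devices.size
  # Visit the axes from largest to smallest.
  for axis in sorted(range(x.ndim), key=lambda i: -x.shape[i]):
    if x.shape[axis] % n == 0:
      spec = [None] * x.ndim
      spec[axis] = 'fsdp'
      return NamedSharding(mesh, P(*spec))
  return NamedSharding(mesh, P())  # No suitable axis: replicate the array.


# Toy parameter tree standing in for real model weights.
toy_params = {
    'embed': jnp.zeros((1024, 64)),
    'scale': jnp.zeros((7,)),
}
toy_shardings = jax.tree_util.tree_map(fsdp_like_sharding, toy_params)
toy_params = jax.tree_util.tree_map(jax.device_put, toy_params, toy_shardings)

# Each device only stores a slice of the large parameter.
print(toy_params['embed'].addressable_shards[0].data.shape)
```

With this layout, each device stores only `1 / device_count` of every large parameter, which is what lets a model that does not fit on a single accelerator be loaded across several.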

## Sampling

Test the sharding using the `gm.text.ChatSampler`.

In [6]:
sampler = gm.text.ChatSampler(model=model, params=params)

out = sampler.chat('Tell me an unknown interesting biology fact about the brain.')
print(out)

Okay, here's a fascinating and relatively unknown fact about the brain:

**Your brain actively "cleans itself" during sleep with a system called the glymphatic system, and this cleaning process is *much* more efficient when you sleep on your side.**

Here's the breakdown:

* **The Glymphatic System:** For a long time, it was thought the brain didn't have a traditional lymphatic system (which clears waste from the body).  But in 2012, researchers discovered the glymphatic system. It's essentially a brain-wide waste clearance system that uses cerebrospinal fluid (CSF) to flush out metabolic waste products that build up during waking hours – things like amyloid-beta, a protein associated with Alzheimer's disease.
* **How it Works:** CSF flows *along* arteries into the brain tissue and then drains out along veins. This flow is significantly enhanced during sleep.
* **Side Sleeping is Key:**  Research (particularly studies using MRI scans) has shown that sleeping on your side – *especially*

Even though we've learned something new, the model still hallucinated URLs, which shows the limitations of the current system.

## Calling the model directly

It's also possible to call the model directly. For this, the input first has to be encoded manually.

In [None]:
tokenizer = gm.text.Gemma3Tokenizer()

In [None]:
prompt = tokenizer.encode('My name is', add_bos=True)  # /!\ Don't forget to add the BOS token
prompt = jnp.asarray(prompt)

When using sharding, the input also has to be sharded. Here we have a single prompt, so we use `kd.sharding.REPLICATED` so that each device gets a copy of the prompt.

During training, the prompts are usually batched and padded, then sharded with `kd.sharding.FIRST_DIM`, so that the prompts are distributed across the devices.

In [None]:
prompt = kd.sharding.with_sharding_constraint(prompt, kd.sharding.REPLICATED)
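As a hedged illustration in plain JAX (not Kauldron's implementation), the two layouts correspond roughly to an empty `PartitionSpec` and to one that names the mesh axis on the first dimension. The mesh axis name `'data'` and the toy arrays are assumptions made for this sketch.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), axis_names=('data',))
n = jax.device_count()

# Single prompt: every device holds the full array.
single = jnp.arange(5)
single = jax.device_put(single, NamedSharding(mesh, P()))

# Batch of padded prompts: the batch axis is split across the devices.
# The batch size has to be a multiple of the device count.
batch = jnp.zeros((2 * n, 16), dtype=jnp.int32)
batch = jax.device_put(batch, NamedSharding(mesh, P('data')))

print(single.sharding.is_fully_replicated)  # True
print(batch.addressable_shards[0].data.shape)  # (2, 16)
```

Replication costs memory on every device but needs no particular batch size, which suits a single prompt. Splitting the first dimension scales with the number of devices but requires the batch to divide evenly.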

In [None]:
# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)


# Sample a token from the predicted logits
next_token = jax.random.categorical(
    jax.random.key(1),
    out.logits
)
tokenizer.decode(next_token)

' Mary'

You can also display the next token probability.

In [None]:
tokenizer.plot_logits(out.logits)
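If you prefer numbers over a plot, the sketch below shows one way to turn logits into next-token probabilities and compare greedy decoding with sampling. It uses a toy logits vector in place of `out.logits`, so the vocabulary size and values are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Toy logits over a 6-token vocabulary, standing in for `out.logits`.
toy_logits = jnp.array([2.0, 0.5, -1.0, 3.0, 0.0, 1.0])

# Softmax turns the logits into a probability distribution.
probs = jax.nn.softmax(toy_logits)

# Keep the 3 most likely token ids and their probabilities.
top_probs, top_ids = jax.lax.top_k(probs, k=3)

# Greedy decoding always picks the most likely token...
greedy_id = jnp.argmax(toy_logits)
# ...while categorical sampling draws a token according to `probs`.
sampled_id = jax.random.categorical(jax.random.key(0), toy_logits)

print(top_ids, top_probs)
print(greedy_id, sampled_id)
```

With real model outputs, the ids in `top_ids` could then be passed to `tokenizer.decode` to get readable tokens.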

## Training

To use sharding during training, set the `sharding=` attribute of the trainer, for example:

```python
trainer = kd.train.Trainer(
    ...,
    sharding=kd.sharding.ShardingStrategy(
        params=kd.sharding.FSDPSharding(),
    ),
    ...,
)
```

See a full example at: https://github.com/google-deepmind/gemma/tree/main/examples/sharding.py