Sharding
Sharding Gemma models. This example runs inference of Gemma 27B on a TPU v2 using 8 devices.
!pip install -q gemma
# Common imports
import jax
import jax.numpy as jnp
# Gemma imports
from gemma import gm
from kauldron import kd
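For this colab, make sure you are connected to a TPU kernel by selecting Change runtime type > v2-8 TPU, so that you have access to multiple accelerators. JAX should then report multiple devices.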
jax.device_count()
8
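The count alone does not tell you whether the runtime actually gave you TPUs. Checking the platform before loading a 27B model can help. Below is a minimal sketch that uses only core JAX calls (jax.devices, jax.device_count, jax.local_device_count). The helper name describe_devices is hypothetical.

```python
import jax


def describe_devices():
    """Return (id, platform, device_kind) for every device JAX can see."""
    return [(d.id, d.platform, d.device_kind) for d in jax.devices()]


# On a v2-8 TPU runtime this should list 8 TPU devices;
# on a CPU-only machine it is typically a single CPU device.
for info in describe_devices():
    print(info)

print("global device count:", jax.device_count())
print("local device count:", jax.local_device_count())
```

On a single host, the global and local counts are equal. They only differ in multi-host setups.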
Load the model and the parameters. Here we load the 27B model.
model = gm.nn.Gemma3_27B()
When restoring the weights, you can pass a sharding= argument to gm.ckpts.load_params. Here we use a simple kd.sharding.FSDPSharding heuristic.
params = gm.ckpts.load_params(
gm.ckpts.CheckpointPath.GEMMA3_27B_IT,
sharding=kd.sharding.FSDPSharding(),
)
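To give an intuition for what an FSDP-style heuristic does, here is a sketch in plain JAX. It splits each parameter along one dimension across all devices, so no single device holds a full copy of the weights. The specific rule used here is an assumption for illustration and may not match what kd.sharding.FSDPSharding actually does: shard the largest dimension that divides evenly by the device count, otherwise replicate. The function fsdp_spec and the toy parameter shapes are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def fsdp_spec(shape, num_devices, axis_name="devices"):
    """Hypothetical FSDP rule: shard the largest evenly divisible dim, else replicate."""
    # Visit dimensions from largest to smallest (stable for ties).
    for dim in sorted(range(len(shape)), key=lambda i: shape[i], reverse=True):
        if shape[dim] % num_devices == 0:
            spec = [None] * len(shape)
            spec[dim] = axis_name
            return P(*spec)
    return P()  # Scalars or awkwardly sized tensors stay fully replicated.


# A 1D mesh over every available device.
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
num_devices = mesh.devices.size

# Toy parameter tree standing in for real model weights (hypothetical shapes).
toy_params = {
    "embedding": jnp.zeros((1024, 64)),
    "bias": jnp.zeros((3,)),
}

# Place each leaf on the mesh according to the heuristic.
sharded_params = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, NamedSharding(mesh, fsdp_spec(x.shape, num_devices))),
    toy_params,
)
```

On 8 devices, the embedding would be split along its first dimension, while the 3-element bias would be replicated because 3 is not divisible by 8.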
Sampling
Test the sharding with gm.text.ChatSampler.
sampler = gm.text.ChatSampler(model=model, params=params)
out = sampler.chat('Tell me an unknown interesting biology fact about the brain.')
print(out)
Okay, here's a fascinating and relatively unknown fact about the brain:
**Your brain actively "cleans itself" during sleep with a system called the glymphatic system, and this cleaning process is *much* more efficient when you sleep on your side.**
Here's the breakdown:
* **The Glymphatic System:** For a long time, it was thought the brain didn't have a traditional lymphatic system (which clears waste from the body). But in 2012, researchers discovered the glymphatic system. It's essentially a brain-wide waste clearance system that uses cerebrospinal fluid (CSF) to flush out metabolic waste products that build up during waking hours – things like amyloid-beta, a protein associated with Alzheimer's disease.
* **How it Works:** CSF flows *along* arteries into the brain tissue and then drains out along veins. This flow is significantly enhanced during sleep.
* **Side Sleeping is Key:** Research (particularly studies using MRI scans) has shown that sleeping on your side – *especially* the left side – is the most effective position for clearing waste from the brain. This is because the lateral position allows gravity to assist the flow of CSF and facilitates the clearance of interstitial fluid (the fluid between brain cells) and waste products. Sleeping on your back is *less* effective, and sleeping on your stomach is the *least* effective.
**Why is this relatively unknown?** The glymphatic system is a relatively recent discovery, and research is still ongoing. It's also a complex system to study.
**Source/Further Reading:**
* **Oregon Health & Science University (OHSU) - The Brain's Cleaning System:** [https://www.ohsu.edu/news/2013/08/brain-cleanses-itself-during-sleep](https://www.ohsu.edu/news/2013/08/brain-cleanses-itself-during-sleep)
* **ScienceAlert - Scientists Discover How You Can Optimize Your Brain's Cleaning System While You Sleep:** [https://www.sciencealert.com/scientists-discover-how-you-can-optimize-your-brain-s-cleaning-system-while-you-sleep](https://www.sciencealert.com/scientists-discover-how-you-can
Even though we learned something new, the model still hallucinates URLs, which shows the limitations of current systems.
Calling the model directly
The model can also be called directly. To do so, the input must first be encoded manually.
tokenizer = gm.text.Gemma3Tokenizer()
prompt = tokenizer.encode('My name is', add_bos=True) # /!\ Don't forget to add the BOS token
prompt = jnp.asarray(prompt)
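When using sharding, the inputs must be sharded too. Here we have a single prompt, so we use kd.sharding.REPLICATED sharding so that every device gets a copy of the prompt.
During training, prompts are usually batched and padded, then sharded with kd.sharding.FIRST_DIM so that the batch is split across the devices.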
prompt = kd.sharding.with_sharding_constraint(prompt, kd.sharding.REPLICATED)
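The difference between the two input layouts can be illustrated in plain JAX with NamedSharding. Treating kd.sharding.REPLICATED as roughly PartitionSpec() and kd.sharding.FIRST_DIM as roughly PartitionSpec on the first axis is an assumption made for this sketch. The token ids below are placeholders, not real Gemma tokens.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
num_devices = mesh.devices.size

# Single prompt: every device holds a full copy (akin to REPLICATED).
toy_prompt = jnp.arange(4, dtype=jnp.int32)  # placeholder token ids
replicated = jax.device_put(toy_prompt, NamedSharding(mesh, P()))

# Padded batch: split along the batch (first) dimension (akin to FIRST_DIM).
# The batch size must be divisible by the number of devices.
batch_size = 2 * num_devices
toy_batch = jnp.zeros((batch_size, 16), dtype=jnp.int32)
batched = jax.device_put(toy_batch, NamedSharding(mesh, P("devices")))

# Each device ends up with its own slice of 2 sequences.
for shard in batched.addressable_shards:
    print(shard.device, shard.data.shape)
```

With 8 devices, each shard would have shape (2, 16). The replicated prompt keeps shape (4,) on every device.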
# Run the model
out = model.apply(
{'params': params},
tokens=prompt,
return_last_only=True, # Only predict the last token
)
# Sample a token from the predicted logits
next_token = jax.random.categorical(
jax.random.key(1),
out.logits
)
tokenizer.decode(next_token)
' Mary'
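The cell above draws one token with jax.random.categorical, which samples in proportion to softmax(logits). A common variation is to scale the logits by a temperature, or to take the argmax for greedy decoding. The sketch below shows this on toy logits. The helper sample_next_token and its temperature handling are illustrative assumptions, not part of the gemma API.

```python
import jax
import jax.numpy as jnp


def sample_next_token(key, logits, temperature=1.0):
    """Hypothetical helper: greedy if temperature == 0, else temperature sampling."""
    if temperature == 0.0:
        return jnp.argmax(logits, axis=-1)
    # Lower temperature sharpens the distribution; higher flattens it.
    return jax.random.categorical(key, logits / temperature, axis=-1)


# Toy logits over a 5-token vocabulary (hypothetical values).
logits = jnp.array([0.1, 2.0, -1.0, 0.5, 3.0])
print(sample_next_token(jax.random.key(0), logits, temperature=0.0))  # → 4
print(sample_next_token(jax.random.key(0), logits, temperature=0.7))
```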
You can also display the next-token probabilities.
tokenizer.plot_logits(out.logits)
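If you need the numbers rather than a plot, the probabilities are simply a softmax over the logits. The sketch below extracts the k most likely token ids with jax.nn.softmax and jax.lax.top_k. The helper top_k_probs is hypothetical. With the real model, the returned ids could presumably be passed to tokenizer.decode to inspect them as text.

```python
import jax
import jax.numpy as jnp


def top_k_probs(logits, k=3):
    """Hypothetical helper: the k most likely token ids and their probabilities."""
    probs = jax.nn.softmax(logits, axis=-1)
    values, indices = jax.lax.top_k(probs, k)  # sorted in descending order
    return indices, values


# Toy logits over a 5-token vocabulary (hypothetical values).
logits = jnp.array([0.1, 2.0, -1.0, 0.5, 3.0])
ids, probs = top_k_probs(logits, k=3)
print([int(i) for i in ids])  # → [4, 1, 3]
print(probs)
```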
Training
To use sharding during training, simply set the sharding= attribute of the trainer, for example:
trainer = kd.train.Trainer(
...,
sharding=kd.sharding.ShardingStrategy(
params=kd.sharding.FSDPSharding(),
),
...,
)
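For intuition about what such a configuration amounts to in core JAX, consider the pattern below: the parameters are placed with an FSDP-like layout, the batch is split along its first dimension, and jax.jit propagates those layouts through the training step. This is a toy linear-regression sketch, not Kauldron's Trainer. The correspondence to ShardingStrategy(params=FSDPSharding()) is an assumption, and loss_fn, train_step and LEARNING_RATE are hypothetical names.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
n = mesh.devices.size
LEARNING_RATE = 0.1


def loss_fn(params, x, y):
    """Mean squared error of a linear model."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    """One SGD step; returns updated params and the loss before the update."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


# Parameters: weight split along its first dim (FSDP-like), bias replicated.
params = {
    "w": jax.device_put(jnp.zeros((8 * n, 4)), NamedSharding(mesh, P("devices", None))),
    "b": jax.device_put(jnp.zeros((4,)), NamedSharding(mesh, P())),
}

# Batch: split along the first (batch) dimension across devices.
batch_sharding = NamedSharding(mesh, P("devices"))
x = jax.device_put(jnp.full((4 * n, 8 * n), 1.0 / (8 * n)), batch_sharding)
y = jax.device_put(jnp.ones((4 * n, 4)), batch_sharding)

losses = []
for _ in range(5):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
print(losses)
```

The step function itself contains no sharding logic. Collectives are inserted by the compiler from the input layouts, which is what makes it possible to configure sharding as a single trainer attribute.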
See the full example at: https://github.com/google-deepmind/gemma/tree/main/examples/sharding.py