Multimodal#
An example of how to use multimodal (image + text) inputs with Gemma models.
!pip install -q gemma
# Common imports
import os
import jax.numpy as jnp
import tensorflow_datasets as tfds
# Gemma imports
from gemma import gm
By default, JAX does not use the full GPU memory, but this can be overridden. See the JAX documentation on GPU memory allocation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
First, let's load an image:
ds = tfds.data_source('oxford_flowers102', split='train')
image = ds[0]['image']
image
ndarray (500, 667, 3)
array([[[1, 1, 0], [1, 1, 0], [1, 1, 0], ..., [1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 0], [1, 1, 0], [1, 1, 0], ..., [1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 0], [1, 1, 0], [1, 1, 0], ..., [1, 1, 1], [1, 1, 1], [1, 1, 1]], ..., [[1, 1, 1], [1, 1, 1], [1, 1, 1], ..., [7, 7, 5], [6, 6, 4], [5, 5, 3]], [[1, 1, 1], [1, 1, 1], [1, 1, 1], ..., [7, 7, 5], [6, 6, 4], [5, 5, 3]], [[1, 1, 1], [1, 1, 1], [1, 1, 1], ..., [7, 7, 5], [6, 6, 4], [5, 5, 3]]], shape=(500, 667, 3), dtype=uint8)
Load the model and the params.
model = gm.nn.Gemma3_4B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)
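Full prompt sampling#
To use multimodal inputs, simply:
* In the prompt: add the <start_of_image> special token wherever an image should be inserted.
* In the sampler call: pass the images through the images= argument.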
sampler = gm.text.ChatSampler(
model=model,
params=params,
)
out = sampler.chat(
'What can you say about this image: <start_of_image>',
images=image,
)
print(out)
Here's a breakdown of what I can say about the image:
**Overall Impression:**
The image is a stunning, close-up photograph of a water lily in full bloom. It’s dramatically lit, creating a strong contrast between light and shadow, which really emphasizes the flower's form and texture.
**Specific Details:**
* **Flower Type:** It appears to be a Nymphaea (water lily). The shape of the petals and the prominent stamens are characteristic of this type of flower.
* **Color:** The petals are primarily white with a subtle pinkish hue at the base. The stamens are a bright, vibrant yellow.
* **Lighting:** The lighting is key. There's a strong light source coming from the upper left, casting dramatic shadows and highlighting the edges of the petals. This creates a sense of depth and makes the flower appear almost sculptural.
* **Texture:** You can see the delicate texture of the petals – they appear smooth but with subtle ridges and folds.
* **Composition:** The flower is centered in the frame, drawing the viewer's eye directly to it. The dark background isolates the flower and makes it the focal point.
* **Water Droplets:** There are a few water droplets on the petals, adding a touch of freshness and realism.
**Mood/Feeling:**
The image evokes a feeling of tranquility, beauty, and perhaps a touch of mystery due to the dramatic lighting. It feels serene and peaceful.
**Do you want me to focus on a specific aspect of the image, such as:**
* The lighting technique?
* The flower's anatomy?
* The overall mood it creates?<end_of_turn>
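Since each <start_of_image> token marks the position of one image, a small check before sampling can catch a mismatch between the prompt and the images passed in. The sketch below is a hypothetical helper, not part of the gemma library; the names count_image_placeholders and check_prompt_matches_images are illustrative assumptions.

```python
# Hypothetical helpers (not part of the gemma library): verify that the number
# of <start_of_image> placeholders in a prompt matches the number of images.
START_OF_IMAGE = "<start_of_image>"


def count_image_placeholders(prompt: str) -> int:
    """Returns how many image placeholders appear in the prompt."""
    return prompt.count(START_OF_IMAGE)


def check_prompt_matches_images(prompt: str, num_images: int) -> None:
    """Raises ValueError when placeholders and images do not line up."""
    num_placeholders = count_image_placeholders(prompt)
    if num_placeholders != num_images:
        raise ValueError(
            f"Prompt has {num_placeholders} {START_OF_IMAGE} token(s) "
            f"but {num_images} image(s) were provided."
        )


# One placeholder, one image: passes silently.
prompt = "What can you say about this image: <start_of_image>"
check_prompt_matches_images(prompt, num_images=1)
```

Running such a check before calling the sampler keeps the failure close to its cause, rather than surfacing later as a shape error.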
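Notes
* The model was trained on JPEG images. If you have PNG images, encode/decode them as JPEG first to avoid skew between training and inference inputs.
* You can pass multiple images. Just add a <start_of_image> token wherever each image should be inserted. All images should be resized to the same shape. The input shape is then batch, num_images, h, w, c (rather than batch, h, w, c).
* If the prompts within a batch have different numbers of images, pad the unused image slots in the tensor with 0 (or any other) values.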
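The multi-image layout and padding rule described in the notes above can be sketched with plain NumPy. This is a minimal illustration under stated assumptions: stack_and_pad_images is a hypothetical helper (not a gemma API), every image is assumed to be already resized to the same (h, w, c) shape, and at least one prompt is assumed to contain an image.

```python
import numpy as np


def stack_and_pad_images(images_per_prompt, pad_value=0):
    """Stacks per-prompt image lists into one (batch, num_images, h, w, c) array.

    Hypothetical helper. `images_per_prompt` is a list with one entry per
    prompt; each entry is a list of (h, w, c) arrays of identical shape.
    Prompts with fewer images get their unused slots filled with `pad_value`.
    """
    max_images = max(len(imgs) for imgs in images_per_prompt)
    # Use the first available image as the reference shape and dtype.
    first = next(img for imgs in images_per_prompt for img in imgs)
    h, w, c = first.shape
    out = np.full(
        (len(images_per_prompt), max_images, h, w, c),
        pad_value,
        dtype=first.dtype,
    )
    for i, imgs in enumerate(images_per_prompt):
        for j, img in enumerate(imgs):
            if img.shape != (h, w, c):
                raise ValueError("All images must be resized to the same shape.")
            out[i, j] = img
    return out


# Two prompts: the first has two images, the second only one.
img = np.ones((4, 6, 3), dtype=np.uint8)
batch = stack_and_pad_images([[img, img], [img]])
print(batch.shape)  # (2, 2, 4, 6, 3)
```

The second prompt's unused slot, batch[1, 1], is all zeros here; according to the note above, the padding value itself does not matter.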
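Calling the model directly#
Adding images to the model only requires:
* In the prompt: add the <start_of_image> special token where the image should be inserted.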
tokenizer = gm.text.Gemma3Tokenizer()
prompt = """<start_of_turn>user
Describe this image in a single word.
<start_of_image>
<end_of_turn>
<start_of_turn>model
"""
prompt = jnp.asarray(tokenizer.encode(prompt, add_bos=True))
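When calling the model directly, the turn markers have to be written by hand, as in the prompt above. The sketch below wraps that layout in a small function; format_user_turn is a hypothetical helper (not a gemma API) that simply mirrors the string used in this example.

```python
def format_user_turn(text: str) -> str:
    """Wraps user text in the turn markers used by the prompt above.

    Hypothetical helper: it reproduces the layout of the hand-written prompt
    and leaves the model turn open so the model predicts the next tokens.
    """
    return (
        "<start_of_turn>user\n"
        f"{text}\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


# Reproduces the prompt string from the example above.
prompt_text = format_user_turn(
    "Describe this image in a single word.\n<start_of_image>"
)
print(prompt_text)
```

The resulting string can then be encoded with tokenizer.encode(prompt_text, add_bos=True) exactly as shown above.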
* In the model: pass images= to model.apply.
# Run the model
out = model.apply(
{'params': params},
tokens=prompt,
images=image,
return_last_only=True, # Only predict the last token
)
# Plot the probability distribution
tokenizer.plot_logits(out.logits)
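As a textual alternative to the plot, the logits can be turned into probabilities with a softmax and the most likely token ids listed. This is a minimal sketch using NumPy on a made-up logits vector; it assumes the logits form a 1-D vector over the vocabulary (as one would expect for an unbatched prompt with return_last_only=True), and top_k_probabilities is a hypothetical helper name.

```python
import numpy as np


def top_k_probabilities(logits, k=5):
    """Returns the k most likely (token_id, probability) pairs.

    Hypothetical helper; assumes `logits` is a 1-D vector over the vocabulary.
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the max before exponentiating for numerical stability.
    shifted = logits - logits.max()
    exp = np.exp(shifted)
    probs = exp / exp.sum()
    # Sort token ids by descending probability and keep the first k.
    top_ids = np.argsort(probs)[::-1][:k]
    return [(int(i), float(probs[i])) for i in top_ids]


# Made-up logits over a 4-token vocabulary, for illustration only.
fake_logits = np.array([0.0, 2.0, 1.0, -1.0])
for token_id, prob in top_k_probabilities(fake_logits, k=2):
    print(token_id, round(prob, 3))
```

With the real model, the same function could be applied to out.logits, under the shape assumption stated above.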
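Fine-tuning#
Fine-tuning with multimodal inputs is just as simple. Starting from the original fine-tuning example, switching to multimodal requires only 2 changes:
* Use a dataset that also returns images (b h w c) or multiple images (b n h w c).
* Specify which field of the batch corresponds to the images in the model inputs: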
model = gm.nn.Gemma3_4B(
    tokens='batch.tokens',
    images='batch.image',
)
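To make the two changes concrete, the sketch below builds a dummy batch whose field names match the 'batch.tokens' and 'batch.image' paths above and checks that the image field uses one of the two accepted layouts. It is only an illustration with NumPy: make_dummy_batch and check_multimodal_batch are hypothetical helpers, and the sizes are arbitrary placeholders rather than values required by the model.

```python
import numpy as np


def make_dummy_batch(batch_size=2, seq_len=8, h=4, w=4, c=3, num_images=None):
    """Builds a placeholder batch with `tokens` and `image` fields.

    Hypothetical helper. With `num_images=None` the images use the
    `b h w c` layout, otherwise `b n h w c`.
    """
    if num_images is None:
        image_shape = (batch_size, h, w, c)
    else:
        image_shape = (batch_size, num_images, h, w, c)
    return {
        "tokens": np.zeros((batch_size, seq_len), dtype=np.int32),
        "image": np.zeros(image_shape, dtype=np.uint8),
    }


def check_multimodal_batch(batch):
    """Returns the image layout name, or raises if the batch is inconsistent."""
    tokens, images = batch["tokens"], batch["image"]
    if images.ndim not in (4, 5):
        raise ValueError("Images must be `b h w c` or `b n h w c`.")
    if images.shape[0] != tokens.shape[0]:
        raise ValueError("Tokens and images must share the batch dimension.")
    return "b h w c" if images.ndim == 4 else "b n h w c"


print(check_multimodal_batch(make_dummy_batch()))
print(check_multimodal_batch(make_dummy_batch(num_images=3)))
```

In a real pipeline the dataset would produce these fields itself; the check only documents the shapes the list above describes.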
See the multimodal.py example.