Multimodal

An example of how to use multimodal (image and text) input with Gemma models.

!pip install -q gemma
# Common imports
import os
import jax.numpy as jnp
import tensorflow_datasets as tfds

# Gemma imports
from gemma import gm

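By default, JAX does not use all of the GPU memory (it preallocates 75% of it), but this can be overridden. See the JAX documentation on GPU memory allocation.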

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
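The setting above can also be wrapped in a small helper. This is a minimal sketch, not part of the Gemma library: `configure_gpu_memory` is a hypothetical name, and it only sets the documented XLA client environment variables `XLA_PYTHON_CLIENT_MEM_FRACTION` and `XLA_PYTHON_CLIENT_PREALLOCATE`. These variables should be set before JAX runs its first computation.

```python
import os


def configure_gpu_memory(fraction=None):
    """Set XLA client memory options (hypothetical helper, not a Gemma API).

    Call this before running any JAX computation.
    """
    if fraction is None:
        # Allocate GPU memory on demand instead of preallocating a fixed share.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
        return
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in the interval (0, 1]")
    # Share of total GPU memory that JAX preallocates.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"


configure_gpu_memory(0.90)
```

Passing no argument switches to on-demand allocation, which can help when the GPU is shared with other processes.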

First, let's load an image.

ds = tfds.data_source('oxford_flowers102', split='train')
image = ds[0]['image']
image
ndarray, shape=(500, 667, 3), dtype=uint8

Load the model and the parameters.

model = gm.nn.Gemma3_4B()

params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)

Sampling with a full prompt

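Using multimodal input only requires two things:

  • In the prompt: add the <start_of_image> special token at each position where an image should be inserted.

  • Pass the image(s) to the sampler through the images= argument.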

sampler = gm.text.ChatSampler(
    model=model,
    params=params,
)

out = sampler.chat(
    'What can you say about this image: <start_of_image>',
    images=image,
)
print(out)
Here's a breakdown of what I can say about the image:

**Overall Impression:**

The image is a stunning, close-up photograph of a water lily in full bloom. It’s dramatically lit, creating a strong contrast between light and shadow, which really emphasizes the flower's form and texture.

**Specific Details:**

*   **Flower Type:** It appears to be a Nymphaea (water lily). The shape of the petals and the prominent stamens are characteristic of this type of flower.
*   **Color:** The petals are primarily white with a subtle pinkish hue at the base. The stamens are a bright, vibrant yellow.
*   **Lighting:** The lighting is key. There's a strong light source coming from the upper left, casting dramatic shadows and highlighting the edges of the petals. This creates a sense of depth and makes the flower appear almost sculptural.
*   **Texture:** You can see the delicate texture of the petals – they appear smooth but with subtle ridges and folds.
*   **Composition:** The flower is centered in the frame, drawing the viewer's eye directly to it. The dark background isolates the flower and makes it the focal point.
*   **Water Droplets:** There are a few water droplets on the petals, adding a touch of freshness and realism.

**Mood/Feeling:**

The image evokes a feeling of tranquility, beauty, and perhaps a touch of mystery due to the dramatic lighting. It feels serene and peaceful.

**Do you want me to focus on a specific aspect of the image, such as:**

*   The lighting technique?
*   The flower's anatomy?
*   The overall mood it creates?<end_of_turn>
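Because each <start_of_image> token marks the position of one image, it can help to check that the number of placeholders matches the number of images before sampling. The following is a minimal sketch: `check_image_placeholders` is a hypothetical helper, not part of the Gemma library, and the source does not say whether the library performs a similar check itself.

```python
IMAGE_TOKEN = '<start_of_image>'


def check_image_placeholders(prompt: str, num_images: int) -> None:
    """Raise if the prompt's image placeholders do not match the image count.

    Hypothetical helper for illustration only.
    """
    found = prompt.count(IMAGE_TOKEN)
    if found != num_images:
        raise ValueError(
            f'Prompt contains {found} {IMAGE_TOKEN} token(s) '
            f'but {num_images} image(s) were provided.'
        )


# Passes silently: one placeholder, one image.
check_image_placeholders(
    'What can you say about this image: <start_of_image>',
    num_images=1,
)
```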

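Notes

  • The model was trained on JPEG images. If you have PNG images, encode and decode them as JPEG first to avoid skew between training and inference inputs.

  • You can pass multiple images. Add a <start_of_image> token at each position where an image should be inserted. All images should be resized to the same shape. The input shape is then batch, num_images, h, w, c (rather than batch, h, w, c).

  • If prompts within a batch have different numbers of images, pad the unused image slots in the tensor with 0 (or any other) values.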
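The multi-image layout and the padding rule described in the notes can be sketched with NumPy. This is an illustration under stated assumptions: `stack_and_pad_images` is a hypothetical helper, not a Gemma API, and it assumes all images have already been resized to a common (h, w, c) shape.

```python
import numpy as np


def stack_and_pad_images(images_per_prompt, pad_value=0):
    """Stack per-prompt image lists into one (batch, num_images, h, w, c) array.

    Prompts with fewer images get their unused slots filled with `pad_value`.
    Hypothetical helper for illustration only.
    """
    max_images = max(len(imgs) for imgs in images_per_prompt)
    # Take the target shape and dtype from the first available image.
    first = next(img for imgs in images_per_prompt for img in imgs)
    h, w, c = first.shape
    batch = np.full(
        (len(images_per_prompt), max_images, h, w, c),
        pad_value,
        dtype=first.dtype,
    )
    for i, imgs in enumerate(images_per_prompt):
        for j, img in enumerate(imgs):
            if img.shape != (h, w, c):
                raise ValueError('All images must be resized to the same shape.')
            batch[i, j] = img
    return batch


# Two prompts: the first has two images, the second only one.
img = np.ones((4, 6, 3), dtype=np.uint8)
batch = stack_and_pad_images([[img, img], [img]])
print(batch.shape)  # (2, 2, 4, 6, 3)
```

Here the second prompt's unused image slot, `batch[1, 1]`, is filled with zeros.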

Calling the model directly

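Adding images to the model only requires two things:

  • In the prompt: add the <start_of_image> special token at each position where an image should be inserted.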

tokenizer = gm.text.Gemma3Tokenizer()


prompt = """<start_of_turn>user
Describe this image in a single word.

<start_of_image>

<end_of_turn>
<start_of_turn>model
"""
prompt = jnp.asarray(tokenizer.encode(prompt, add_bos=True))

  • In the model: pass images= to model.apply.

# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    images=image,
    return_last_only=True,  # Only predict the last token
)


# Plot the probability distribution
tokenizer.plot_logits(out.logits)
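To inspect the distribution numerically rather than as a plot, the logits can be converted to probabilities with a softmax. The following is a minimal sketch on toy values: `top_k_probs` is a hypothetical helper, and it assumes the logits form a 1-D array over the vocabulary, which the source does not state explicitly for `out.logits`.

```python
import numpy as np


def top_k_probs(logits, k=5):
    """Return (token_ids, probabilities) of the k most likely tokens.

    Hypothetical helper for illustration only.
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the max before exponentiating.
    shifted = logits - logits.max()
    probs = np.exp(shifted) / np.exp(shifted).sum()
    # Indices of the k largest probabilities, most likely first.
    top_ids = np.argsort(probs)[::-1][:k]
    return top_ids, probs[top_ids]


# Toy logits standing in for the model output (made-up values).
toy_logits = np.array([0.5, 2.0, -1.0, 3.0, 0.0])
ids, probs = top_k_probs(toy_logits, k=2)
print(ids, probs)
```

The returned token ids could then be mapped back to text with the tokenizer.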

Fine-tuning

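Fine-tuning with multimodal input is also straightforward. Starting from the original fine-tuning example, only two changes are needed:

  • Use a dataset that also returns images (b h w c), or multiple images per example (b n h w c).

  • Specify in the model inputs which field of the batch holds the images: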

    model = gm.nn.Gemma3_4B(
        tokens='batch.tokens',
        images='batch.image',
    )
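The batch layout this configuration expects can be sketched as follows. This is an illustration under assumptions: the field names `tokens` and `image` mirror the `batch.tokens` and `batch.image` paths above, while `validate_multimodal_batch`, the array sizes, and the dtypes are hypothetical and not taken from the Gemma library.

```python
import numpy as np


def validate_multimodal_batch(batch, tokens_key='tokens', image_key='image'):
    """Check that a batch has a supported image layout and return its name.

    Hypothetical helper for illustration only.
    """
    tokens = np.asarray(batch[tokens_key])
    images = np.asarray(batch[image_key])
    if images.ndim not in (4, 5):
        raise ValueError('Images must have shape (b h w c) or (b n h w c).')
    if images.shape[0] != tokens.shape[0]:
        raise ValueError('Tokens and images must share the same batch size.')
    return 'b h w c' if images.ndim == 4 else 'b n h w c'


# Dummy batch: 2 examples, 8 tokens each, one 16x16 RGB image per example.
batch = {
    'tokens': np.zeros((2, 8), dtype=np.int32),
    'image': np.zeros((2, 16, 16, 3), dtype=np.uint8),
}
print(validate_multimodal_batch(batch))  # b h w c
```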
    

See the multimodal.py example.