341 changes: 341 additions & 0 deletions demos/Gemma3_Multimodal.ipynb
@@ -0,0 +1,341 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Gemma3_Multimodal.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gemma 3 Multimodal Demo with TransformerBridge\n",
"\n",
"This notebook demonstrates how to use TransformerBridge with `Gemma3ForConditionalGeneration`,\n",
"the vision-language variant of Gemma 3. The model pairs a SigLIP vision encoder with the\n",
"Gemma 3 language model and is the same architecture used by MedGemma.\n",
"\n",
"We demonstrate:\n",
"1. Loading Gemma 3 (4B-it) through TransformerBridge\n",
"2. Multimodal generation from an image + text prompt\n",
"3. Capturing vision-language activations with `run_with_cache()`\n",
"\n",
"> **Gated model.** The `google/gemma-3-*` checkpoints are gated on Hugging Face. Accept\n",
"> the license at https://huggingface.co/google/gemma-3-4b-it and run `huggingface-cli login`\n",
"> (or set `HF_TOKEN`) before executing this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Detect Colab and install dependencies if needed\n",
"DEVELOPMENT_MODE = False\n",
"try:\n",
" import google.colab\n",
" IN_COLAB = True\n",
" print(\"Running as a Colab notebook\")\n",
" %pip install transformer_lens\n",
" %pip install circuitsvis\n",
"except:\n",
" IN_COLAB = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"import torch\n",
"from PIL import Image\n",
"import requests\n",
"from io import BytesIO\n",
"\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"from transformer_lens.model_bridge import TransformerBridge\n",
"\n",
"try:\n",
" import circuitsvis as cv\n",
"except ImportError:\n",
" print('circuitsvis not installed, attention visualization will not work')\n",
" cv = None"
]
},
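{
"cell_type": "markdown",
"metadata": {},
"source": [
"Gemma 3 checkpoints are gated, so the download in the next section fails with a 401/403 unless you are authenticated. The optional helper below is a small sketch: it assumes `huggingface_hub.get_token()` (available in recent `huggingface_hub` releases) and only prompts for a token when none is found in the environment or the cached login."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Optional: authenticate for the gated google/gemma-3-* checkpoints.\n",
"# Prompts only when no token is available from HF_TOKEN or a cached login.\n",
"from huggingface_hub import get_token, notebook_login\n",
"\n",
"if get_token() is None:\n",
"    notebook_login()"
]
},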
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Gemma 3 through TransformerBridge\n",
"\n",
"TransformerBridge maps `Gemma3ForConditionalGeneration` to its multimodal adapter, which\n",
"wraps the SigLIP vision tower, the multimodal projector, and the Gemma 3 language model\n",
"into a single hooked model.\n",
"\n",
"We use **bfloat16** here \u2014 Gemma 3 is trained in bf16 and fp16 can produce unstable activations.\n",
"The 4B-it variant is the smallest multimodal Gemma 3 (the 270m and 1B checkpoints are text-only)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"model = TransformerBridge.boot_transformers(\n",
" \"google/gemma-3-4b-it\",\n",
" device=device,\n",
" dtype=torch.bfloat16,\n",
")\n",
"\n",
"for param in model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"print(f\"Model loaded on {device}\")\n",
"print(f\"Multimodal: {getattr(model.cfg, 'is_multimodal', False)}\")\n",
"print(f\"Layers: {model.cfg.n_layers}, Heads: {model.cfg.n_heads}\")\n",
"print(f\"Vision tokens per image: {getattr(model.cfg, 'mm_tokens_per_image', None)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load a test image\n",
"\n",
"We'll use a stop-sign photo from Australia to test the model's visual understanding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"image_url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
"response = requests.get(image_url)\n",
"image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
"plt.imshow(image)\n",
"plt.axis('off')\n",
"plt.title('Test Image')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multimodal Generation\n",
"\n",
"Gemma 3 instruction-tuned models expect the chat template format with\n",
"`<start_of_turn>` / `<end_of_turn>` markers, and use `<start_of_image>` as the image\n",
"placeholder (rather than LLaVA's `<image>`). The processor expands `<start_of_image>`\n",
"into the appropriate number of vision tokens (256 by default for Gemma 3 4B).\n",
"\n",
"We call `prepare_multimodal_inputs()` to run the processor on text + image, then pass\n",
"`pixel_values` to `generate()`. The bridge's `generate()` keeps a KV cache\n",
"(`use_past_kv_cache=True` by default) for efficient autoregressive decoding while\n",
"preserving full hook access."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"question = \"What do you see in this photo?\"\n",
"prompt = (\n",
" \"<start_of_turn>user\\n\"\n",
" f\"<start_of_image>{question}<end_of_turn>\\n\"\n",
" \"<start_of_turn>model\\n\"\n",
")\n",
"\n",
"# Prepare multimodal inputs (handles image processing + tokenization)\n",
"inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
"input_ids = inputs['input_ids']\n",
"pixel_values = inputs['pixel_values']\n",
"\n",
"# Pass any extra processor outputs (e.g. token_type_ids for Gemma 3)\n",
"extra_kwargs = {k: v for k, v in inputs.items()\n",
" if k not in ('input_ids', 'pixel_values')}\n",
"\n",
"generated_text = model.generate(\n",
" input_ids,\n",
" pixel_values=pixel_values,\n",
" max_new_tokens=80,\n",
" do_sample=False,\n",
" use_past_kv_cache=True,\n",
" return_type=\"str\",\n",
" **extra_kwargs,\n",
")\n",
"\n",
"print('Generated text:', generated_text)"
]
},
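{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before moving on, we can check the prompt expansion. This sketch assumes the processor returns `token_type_ids` with image positions marked as `1` (as the Hugging Face Gemma 3 processor does); the count should match `mm_tokens_per_image`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Count image placeholder tokens in the expanded prompt (assumes\n",
"# token_type_ids == 1 marks image positions, per the HF Gemma 3 processor)\n",
"tt = inputs.get('token_type_ids')\n",
"if tt is not None:\n",
"    n_image_tokens = int((tt == 1).sum())\n",
"    expected = getattr(model.cfg, 'mm_tokens_per_image', None)\n",
"    print(f'Image tokens in prompt: {n_image_tokens} (expected {expected})')\n",
"    print(f'Total prompt tokens: {input_ids.shape[-1]}')\n",
"else:\n",
"    print('token_type_ids not returned by the processor')"
]
},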
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try a second image to confirm the model adapts its description:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"image_url_2 = \"https://github.com/zazamrykh/PicFinder/blob/main/images/doge.jpg?raw=true\"\n",
"response = requests.get(image_url_2)\n",
"image_2 = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
"plt.imshow(image_2)\n",
"plt.axis('off')\n",
"plt.show()\n",
"\n",
"inputs = model.prepare_multimodal_inputs(text=prompt, images=image_2)\n",
"input_ids = inputs['input_ids']\n",
"pixel_values = inputs['pixel_values']\n",
"extra_kwargs = {k: v for k, v in inputs.items()\n",
" if k not in ('input_ids', 'pixel_values')}\n",
"\n",
"generated_text = model.generate(\n",
" input_ids,\n",
" pixel_values=pixel_values,\n",
" max_new_tokens=80,\n",
" do_sample=False,\n",
" use_past_kv_cache=True,\n",
" return_type=\"str\",\n",
" **extra_kwargs,\n",
")\n",
"print('Generated text:', generated_text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inspecting Vision-Language Activations\n",
"\n",
"`run_with_cache()` accepts the same `pixel_values` argument and captures activations from\n",
"the vision encoder, the multimodal projector, and every transformer block in the language\n",
"model. This lets us inspect how the language tokens attend to image tokens during\n",
"multimodal processing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
"extra_kwargs = {k: v for k, v in inputs.items()\n",
" if k not in ('input_ids', 'pixel_values')}\n",
"\n",
"with torch.no_grad():\n",
" logits, cache = model.run_with_cache(\n",
" inputs['input_ids'],\n",
" pixel_values=inputs['pixel_values'],\n",
" **extra_kwargs,\n",
" )\n",
"\n",
"print(f'Logits shape: {logits.shape}')\n",
"print(f'Cache entries: {len(cache)}')\n",
"vision_keys = [k for k in cache.keys() if 'vision' in k.lower()]\n",
"print(f'Vision-related cache entries: {len(vision_keys)}')\n",
"print(f'Sample vision keys: {vision_keys[:5]}')"
]
},
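{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough grounding check, we can measure how much attention mass the final prompt token places on image positions at a few layers. This is a sketch: it reuses the `pattern` cache entries from above and again assumes `token_type_ids == 1` marks image tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Share of the last prompt token's attention that lands on image positions\n",
"# (assumes token_type_ids == 1 marks image tokens, per the HF Gemma 3 processor)\n",
"tt = inputs.get('token_type_ids')\n",
"if tt is not None:\n",
"    image_mask = (tt[0] == 1).cpu()\n",
"    for layer in (0, 8, 16, 24):\n",
"        keys = [k for k in cache.keys() if f'blocks.{layer}.' in k and 'pattern' in k]\n",
"        if not keys:\n",
"            continue\n",
"        pattern = cache[keys[0]]\n",
"        if pattern.ndim == 4:\n",
"            pattern = pattern[0]  # drop batch dim -> [head, query, key]\n",
"        last_attn = pattern[:, -1, :].float().mean(0).cpu()  # average over heads\n",
"        share = last_attn[image_mask].sum().item()\n",
"        print(f'Layer {layer:2d}: {share:.3f} of attention on image tokens')\n",
"else:\n",
"    print('token_type_ids not returned; skipping the check')"
]
},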
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"if cv is not None:\n",
" layer_to_visualize = 16\n",
" tokens_to_show = 30\n",
"\n",
" pattern_keys = [k for k in cache.keys() if f'blocks.{layer_to_visualize}' in k and 'pattern' in k]\n",
" if pattern_keys:\n",
" attention_pattern = cache[pattern_keys[0]]\n",
" if attention_pattern.ndim == 4:\n",
" attention_pattern = attention_pattern[0]\n",
"\n",
" token_ids = inputs['input_ids'][0].cpu()\n",
" str_tokens = model.tokenizer.convert_ids_to_tokens(token_ids)\n",
"\n",
" print(f'Layer {layer_to_visualize} Head Attention Patterns (last {tokens_to_show} tokens):')\n",
" display(cv.attention.attention_patterns(\n",
" tokens=str_tokens[-tokens_to_show:],\n",
" attention=attention_pattern[:, -tokens_to_show:, -tokens_to_show:].float().cpu(),\n",
" ))\n",
" else:\n",
" print(f'No attention pattern found for layer {layer_to_visualize}')\n",
" print(f'Available attention-related keys: {[k for k in cache.keys() if \"attn\" in k][:10]}')\n",
"else:\n",
" print('circuitsvis not available \u2014 skipping visualization')"
]
},
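{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because MedGemma ships the same `Gemma3ForConditionalGeneration` architecture, swapping in its checkpoint name should be all that is needed (its license must be accepted separately on Hugging Face). The cell below is a sketch and is disabled by default to avoid a second multi-gigabyte download:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Sketch: the same boot call with a MedGemma checkpoint (license required).\n",
"# Disabled by default; flip the flag to download the ~4B model.\n",
"LOAD_MEDGEMMA = False\n",
"\n",
"if LOAD_MEDGEMMA:\n",
"    medgemma = TransformerBridge.boot_transformers(\n",
"        'google/medgemma-4b-it',\n",
"        device=device,\n",
"        dtype=torch.bfloat16,\n",
"    )\n",
"    print(f'MedGemma layers: {medgemma.cfg.n_layers}')"
]
},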
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"TransformerBridge provides native multimodal support for `Gemma3ForConditionalGeneration`:\n",
"\n",
"- **`boot_transformers(\"google/gemma-3-4b-it\")`** loads the full vision + projector + language pipeline\n",
"- **`prepare_multimodal_inputs(text=..., images=...)`** handles image processing and tokenization\n",
"- **`generate(input_ids, pixel_values=...)`** runs multimodal generation with KV cache and hooks\n",
"- **`run_with_cache(input_ids, pixel_values=...)`** captures activations including SigLIP vision tokens\n",
"\n",
"A few Gemma 3 specifics worth noting:\n",
"\n",
"- Use the chat-template format (`<start_of_turn>user ... <end_of_turn><start_of_turn>model`) for instruction-tuned variants\n",
"- The image placeholder is `<start_of_image>`, not `<image>`\n",
"- Gemma 3 is trained in bf16 \u2014 prefer `torch.bfloat16` over `torch.float16`\n",
"- The same code path works for MedGemma (`google/medgemma-4b-it`, `google/medgemma-27b-it`) and any other `Gemma3ForConditionalGeneration` checkpoint"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "transformer-lens",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}