diff --git a/demos/Gemma3_Multimodal.ipynb b/demos/Gemma3_Multimodal.ipynb
new file mode 100644
index 000000000..015232523
--- /dev/null
+++ b/demos/Gemma3_Multimodal.ipynb
@@ -0,0 +1,341 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Gemma 3 Multimodal Demo with TransformerBridge\n",
+ "\n",
+ "This notebook demonstrates how to use TransformerBridge with `Gemma3ForConditionalGeneration`,\n",
+ "the vision-language variant of Gemma 3. The model pairs a SigLIP vision encoder with the\n",
+ "Gemma 3 language model and is the same architecture used by MedGemma.\n",
+ "\n",
+ "We demonstrate:\n",
+ "1. Loading Gemma 3 (4B-it) through TransformerBridge\n",
+ "2. Multimodal generation from an image + text prompt\n",
+ "3. Capturing vision-language activations with `run_with_cache()`\n",
+ "\n",
+ "> **Gated model.** The `google/gemma-3-*` checkpoints are gated on Hugging Face. Accept\n",
+ "> the license at https://huggingface.co/google/gemma-3-4b-it and run `huggingface-cli login`\n",
+ "> (or set `HF_TOKEN`) before executing this notebook."
+ ]
+ },
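+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If the download fails with a 401/403 error, you can also authenticate from inside\n",
+ "the notebook. A minimal sketch using the standard `huggingface_hub` API:\n",
+ "\n",
+ "```python\n",
+ "# Interactive login; equivalent to running `huggingface-cli login` in a shell.\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "login()  # or login(token='hf_...') with a token from your HF account settings\n",
+ "```"
+ ]
+ },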
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Detect Colab and install dependencies if needed\n",
+ "DEVELOPMENT_MODE = False\n",
+ "try:\n",
+ " import google.colab\n",
+ " IN_COLAB = True\n",
+ " print(\"Running as a Colab notebook\")\n",
+ " %pip install transformer_lens\n",
+ " %pip install circuitsvis\n",
+ "except:\n",
+ " IN_COLAB = False"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "import torch\n",
+ "from PIL import Image\n",
+ "import requests\n",
+ "from io import BytesIO\n",
+ "\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "%matplotlib inline\n",
+ "\n",
+ "from transformer_lens.model_bridge import TransformerBridge\n",
+ "\n",
+ "try:\n",
+ " import circuitsvis as cv\n",
+ "except ImportError:\n",
+ " print('circuitsvis not installed, attention visualization will not work')\n",
+ " cv = None"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load Gemma 3 through TransformerBridge\n",
+ "\n",
+ "TransformerBridge maps `Gemma3ForConditionalGeneration` to its multimodal adapter, which\n",
+ "wraps the SigLIP vision tower, the multimodal projector, and the Gemma 3 language model\n",
+ "into a single hooked model.\n",
+ "\n",
+ "We use **bfloat16** here \u2014 Gemma 3 is trained in bf16 and fp16 can produce unstable activations.\n",
+ "The 4B-it variant is the smallest multimodal Gemma 3 (the 270m and 1B checkpoints are text-only)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "model = TransformerBridge.boot_transformers(\n",
+ " \"google/gemma-3-4b-it\",\n",
+ " device=device,\n",
+ " dtype=torch.bfloat16,\n",
+ ")\n",
+ "\n",
+ "for param in model.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ "print(f\"Model loaded on {device}\")\n",
+ "print(f\"Multimodal: {getattr(model.cfg, 'is_multimodal', False)}\")\n",
+ "print(f\"Layers: {model.cfg.n_layers}, Heads: {model.cfg.n_heads}\")\n",
+ "print(f\"Vision tokens per image: {getattr(model.cfg, 'mm_tokens_per_image', None)}\")"
+ ]
+ },
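+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To see the pieces the bridge wraps, you can walk the module tree with the plain\n",
+ "`nn.Module` API (no bridge-specific methods assumed). A quick sketch:\n",
+ "\n",
+ "```python\n",
+ "# The top-level children should include the vision tower, the multimodal\n",
+ "# projector, and the language-model blocks.\n",
+ "for name, _ in model.named_children():\n",
+ "    print(name)\n",
+ "```"
+ ]
+ },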
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load a test image\n",
+ "\n",
+ "We'll use a stop-sign photo from Australia to test the model's visual understanding."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "image_url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
+ "response = requests.get(image_url)\n",
+ "image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
+ "plt.imshow(image)\n",
+ "plt.axis('off')\n",
+ "plt.title('Test Image')\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Multimodal Generation\n",
+ "\n",
+ "Gemma 3 instruction-tuned models expect the chat template format with\n",
+ "`` / `` markers, and use `` as the image\n",
+ "placeholder (rather than LLaVA's ``). The processor expands ``\n",
+ "into the appropriate number of vision tokens (256 by default for Gemma 3 4B).\n",
+ "\n",
+ "We call `prepare_multimodal_inputs()` to run the processor on text + image, then pass\n",
+ "`pixel_values` to `generate()`. The bridge's `generate()` keeps a KV cache\n",
+ "(`use_past_kv_cache=True` by default) for efficient autoregressive decoding while\n",
+ "preserving full hook access."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "question = \"What do you see in this photo?\"\n",
+ "prompt = (\n",
+ " \"user\\n\"\n",
+ " f\"{question}\\n\"\n",
+ " \"model\\n\"\n",
+ ")\n",
+ "\n",
+ "# Prepare multimodal inputs (handles image processing + tokenization)\n",
+ "inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
+ "input_ids = inputs['input_ids']\n",
+ "pixel_values = inputs['pixel_values']\n",
+ "\n",
+ "# Pass any extra processor outputs (e.g. token_type_ids for Gemma 3)\n",
+ "extra_kwargs = {k: v for k, v in inputs.items()\n",
+ " if k not in ('input_ids', 'pixel_values')}\n",
+ "\n",
+ "generated_text = model.generate(\n",
+ " input_ids,\n",
+ " pixel_values=pixel_values,\n",
+ " max_new_tokens=80,\n",
+ " do_sample=False,\n",
+ " use_past_kv_cache=True,\n",
+ " return_type=\"str\",\n",
+ " **extra_kwargs,\n",
+ ")\n",
+ "\n",
+ "print('Generated text:', generated_text)"
+ ]
+ },
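+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Greedy decoding (`do_sample=False`) keeps the output deterministic. For more varied\n",
+ "captions, `generate()` also accepts sampling parameters; a sketch, assuming the\n",
+ "bridge mirrors HookedTransformer's sampling arguments:\n",
+ "\n",
+ "```python\n",
+ "# Sampled decoding: temperature rescales the logits and top_k restricts\n",
+ "# sampling to the k most likely tokens at each step.\n",
+ "sampled_text = model.generate(\n",
+ "    input_ids,\n",
+ "    pixel_values=pixel_values,\n",
+ "    max_new_tokens=80,\n",
+ "    do_sample=True,\n",
+ "    temperature=0.7,\n",
+ "    top_k=50,\n",
+ "    return_type='str',\n",
+ "    **extra_kwargs,\n",
+ ")\n",
+ "```"
+ ]
+ },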
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's try a second image to confirm the model adapts its description:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "image_url_2 = \"https://github.com/zazamrykh/PicFinder/blob/main/images/doge.jpg?raw=true\"\n",
+ "response = requests.get(image_url_2)\n",
+ "image_2 = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
+ "plt.imshow(image_2)\n",
+ "plt.axis('off')\n",
+ "plt.show()\n",
+ "\n",
+ "inputs = model.prepare_multimodal_inputs(text=prompt, images=image_2)\n",
+ "input_ids = inputs['input_ids']\n",
+ "pixel_values = inputs['pixel_values']\n",
+ "extra_kwargs = {k: v for k, v in inputs.items()\n",
+ " if k not in ('input_ids', 'pixel_values')}\n",
+ "\n",
+ "generated_text = model.generate(\n",
+ " input_ids,\n",
+ " pixel_values=pixel_values,\n",
+ " max_new_tokens=80,\n",
+ " do_sample=False,\n",
+ " use_past_kv_cache=True,\n",
+ " return_type=\"str\",\n",
+ " **extra_kwargs,\n",
+ ")\n",
+ "print('Generated text:', generated_text)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Inspecting Vision-Language Activations\n",
+ "\n",
+ "`run_with_cache()` accepts the same `pixel_values` argument and captures activations from\n",
+ "the vision encoder, the multimodal projector, and every transformer block in the language\n",
+ "model. This lets us inspect how the language tokens attend to image tokens during\n",
+ "multimodal processing."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
+ "extra_kwargs = {k: v for k, v in inputs.items()\n",
+ " if k not in ('input_ids', 'pixel_values')}\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " logits, cache = model.run_with_cache(\n",
+ " inputs['input_ids'],\n",
+ " pixel_values=inputs['pixel_values'],\n",
+ " **extra_kwargs,\n",
+ " )\n",
+ "\n",
+ "print(f'Logits shape: {logits.shape}')\n",
+ "print(f'Cache entries: {len(cache)}')\n",
+ "vision_keys = [k for k in cache.keys() if 'vision' in k.lower()]\n",
+ "print(f'Vision-related cache entries: {len(vision_keys)}')\n",
+ "print(f'Sample vision keys: {vision_keys[:5]}')"
+ ]
+ },
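+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "With the cache in hand you can, for example, compare residual-stream norms at image\n",
+ "vs. text positions. A sketch that looks hook names up from the cache rather than\n",
+ "hard-coding them, and assumes Gemma 3's `<image_soft_token>` placeholder:\n",
+ "\n",
+ "```python\n",
+ "# Boolean mask over prompt positions that hold expanded image tokens.\n",
+ "image_token_id = model.tokenizer.convert_tokens_to_ids('<image_soft_token>')\n",
+ "is_image = inputs['input_ids'][0] == image_token_id\n",
+ "\n",
+ "resid_keys = [k for k in cache.keys() if k.endswith('hook_resid_post')]\n",
+ "if resid_keys and is_image.any():\n",
+ "    # Pick a roughly mid-depth entry; shape is [batch, pos, d_model].\n",
+ "    resid = cache[resid_keys[len(resid_keys) // 2]][0].float()\n",
+ "    print('image-token norm:', resid[is_image].norm(dim=-1).mean().item())\n",
+ "    print('text-token norm: ', resid[~is_image].norm(dim=-1).mean().item())\n",
+ "```"
+ ]
+ },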
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "if cv is not None:\n",
+ " layer_to_visualize = 16\n",
+ " tokens_to_show = 30\n",
+ "\n",
+ " pattern_keys = [k for k in cache.keys() if f'blocks.{layer_to_visualize}' in k and 'pattern' in k]\n",
+ " if pattern_keys:\n",
+ " attention_pattern = cache[pattern_keys[0]]\n",
+ " if attention_pattern.ndim == 4:\n",
+ " attention_pattern = attention_pattern[0]\n",
+ "\n",
+ " token_ids = inputs['input_ids'][0].cpu()\n",
+ " str_tokens = model.tokenizer.convert_ids_to_tokens(token_ids)\n",
+ "\n",
+ " print(f'Layer {layer_to_visualize} Head Attention Patterns (last {tokens_to_show} tokens):')\n",
+ " display(cv.attention.attention_patterns(\n",
+ " tokens=str_tokens[-tokens_to_show:],\n",
+ " attention=attention_pattern[:, -tokens_to_show:, -tokens_to_show:].float().cpu(),\n",
+ " ))\n",
+ " else:\n",
+ " print(f'No attention pattern found for layer {layer_to_visualize}')\n",
+ " print(f'Available attention-related keys: {[k for k in cache.keys() if \"attn\" in k][:10]}')\n",
+ "else:\n",
+ " print('circuitsvis not available \u2014 skipping visualization')"
+ ]
+ },
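+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Because hooks survive the multimodal path, you can intervene as well as observe.\n",
+ "A hypothetical ablation sketch (building on `is_image` and `resid_keys` from the\n",
+ "norm comparison above, and assuming `run_with_hooks` forwards `pixel_values` the\n",
+ "same way `run_with_cache` does): zero the image-token residuals at the first\n",
+ "cached layer and watch the final-position logits shift.\n",
+ "\n",
+ "```python\n",
+ "def zero_image_positions(value, hook):\n",
+ "    # value: [batch, pos, d_model]; zero only the image-token positions.\n",
+ "    value[:, is_image] = 0.0\n",
+ "    return value\n",
+ "\n",
+ "if resid_keys and is_image.any():\n",
+ "    ablated_logits = model.run_with_hooks(\n",
+ "        inputs['input_ids'],\n",
+ "        pixel_values=inputs['pixel_values'],\n",
+ "        fwd_hooks=[(resid_keys[0], zero_image_positions)],\n",
+ "        **extra_kwargs,\n",
+ "    )\n",
+ "    shift = (logits[0, -1] - ablated_logits[0, -1]).abs().max().item()\n",
+ "    print('max abs logit shift at final position:', shift)\n",
+ "```"
+ ]
+ },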
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "TransformerBridge provides native multimodal support for `Gemma3ForConditionalGeneration`:\n",
+ "\n",
+ "- **`boot_transformers(\"google/gemma-3-4b-it\")`** loads the full vision + projector + language pipeline\n",
+ "- **`prepare_multimodal_inputs(text=..., images=...)`** handles image processing and tokenization\n",
+ "- **`generate(input_ids, pixel_values=...)`** runs multimodal generation with KV cache and hooks\n",
+ "- **`run_with_cache(input_ids, pixel_values=...)`** captures activations including SigLIP vision tokens\n",
+ "\n",
+ "A few Gemma 3 specifics worth noting:\n",
+ "\n",
+ "- Use the chat-template format (`user ... model`) for instruction-tuned variants\n",
+ "- The image placeholder is ``, not ``\n",
+ "- Gemma 3 is trained in bf16 \u2014 prefer `torch.bfloat16` over `torch.float16`\n",
+ "- The same code path works for MedGemma (`google/medgemma-4b-it`, `google/medgemma-27b-it`) and any other `Gemma3ForConditionalGeneration` checkpoint"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "transformer-lens",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/tests/unit/model_bridge/generalized_components/test_normalization_bridge.py b/tests/unit/model_bridge/generalized_components/test_normalization_bridge.py
new file mode 100644
index 000000000..c93caf03f
--- /dev/null
+++ b/tests/unit/model_bridge/generalized_components/test_normalization_bridge.py
@@ -0,0 +1,118 @@
+"""RMSNorm-vs-LayerNorm dispatch in NormalizationBridge.
+
+Regression coverage for Gemma 3 multimodal: SigLIP's post_layernorm was wrapped
+under the LM's uses_rms_norm=True config, silently dropping mean-centering and
+bias and producing gibberish completions.
+"""
+
+import torch
+import torch.nn as nn
+
+from transformer_lens.model_bridge.generalized_components.normalization import (
+ NormalizationBridge,
+)
+
+
+class _Cfg:
+ def __init__(self, uses_rms_norm: bool, eps: float = 1e-5):
+ self.uses_rms_norm = uses_rms_norm
+ self.eps = eps
+
+
+def _make_bridge(layer: nn.Module, cfg: _Cfg, **kwargs) -> NormalizationBridge:
+ bridge = NormalizationBridge(name="ln", config=cfg, **kwargs)
+ bridge.set_original_component(layer)
+ return bridge
+
+
+def _layernorm(d: int) -> nn.LayerNorm:
+ layer = nn.LayerNorm(d, eps=1e-5)
+ nn.init.normal_(layer.weight, std=0.1)
+ nn.init.normal_(layer.bias, std=0.1)
+ layer.eval()
+ return layer
+
+
+def test_override_forces_layernorm_when_config_says_rmsnorm():
+ d = 16
+ layer = _layernorm(d)
+ bridge = _make_bridge(layer, _Cfg(uses_rms_norm=True), uses_rms_norm=False)
+ x = torch.randn(2, 5, d)
+ torch.testing.assert_close(bridge(x), layer(x), rtol=1e-5, atol=1e-5)
+
+
+def test_introspects_layernorm_when_config_says_rmsnorm():
+ d = 16
+ layer = _layernorm(d)
+ bridge = _make_bridge(layer, _Cfg(uses_rms_norm=True))
+ assert bridge.uses_rms_norm is False
+ x = torch.randn(2, 5, d)
+ torch.testing.assert_close(bridge(x), layer(x), rtol=1e-5, atol=1e-5)
+
+
+def test_introspects_rmsnorm_class_by_name():
+ class FakeRMSNorm(nn.Module):
+ def __init__(self, d, eps=1e-5):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(d))
+ self.variance_epsilon = eps
+
+ def forward(self, x):
+ rms = (x.float().pow(2).mean(-1, keepdim=True) + self.variance_epsilon).sqrt()
+ return (x.float() / rms * self.weight).to(x.dtype)
+
+ rms = FakeRMSNorm(16)
+ nn.init.normal_(rms.weight, std=0.1)
+ rms.eval()
+ bridge = _make_bridge(rms, _Cfg(uses_rms_norm=False))
+ assert bridge.uses_rms_norm is True
+
+
+def test_falls_back_to_config_when_component_unset():
+ bridge = NormalizationBridge(name="ln", config=_Cfg(uses_rms_norm=True))
+ assert bridge.original_component is None
+ assert bridge.uses_rms_norm is True
+
+ bridge2 = NormalizationBridge(name="ln", config=_Cfg(uses_rms_norm=False))
+ assert bridge2.uses_rms_norm is False
+
+
+def test_override_takes_precedence_over_config():
+ layer = nn.LayerNorm(8)
+ assert _make_bridge(layer, _Cfg(uses_rms_norm=True), uses_rms_norm=False).uses_rms_norm is False
+ assert _make_bridge(layer, _Cfg(uses_rms_norm=False), uses_rms_norm=True).uses_rms_norm is True
+
+
+def test_siglip_post_layernorm_resolves_to_layernorm_under_rmsnorm_config():
+ from transformer_lens.model_bridge.generalized_components.siglip_vision_encoder import (
+ SiglipVisionEncoderBridge,
+ )
+
+ encoder = SiglipVisionEncoderBridge(name="vision_tower", config=_Cfg(uses_rms_norm=True))
+ post_ln = encoder.submodules["post_layernorm"]
+ post_ln.set_original_component(nn.LayerNorm(8))
+ assert post_ln.uses_rms_norm is False
+
+
+def test_clip_layernorms_resolve_to_layernorm_under_rmsnorm_config():
+ from transformer_lens.model_bridge.generalized_components.clip_vision_encoder import (
+ CLIPVisionEncoderBridge,
+ )
+
+ encoder = CLIPVisionEncoderBridge(name="vision_tower", config=_Cfg(uses_rms_norm=True))
+ pre = encoder.submodules["pre_layernorm"]
+ post = encoder.submodules["post_layernorm"]
+ pre.set_original_component(nn.LayerNorm(8))
+ post.set_original_component(nn.LayerNorm(8))
+ assert pre.uses_rms_norm is False
+ assert post.uses_rms_norm is False
+
+
+def test_native_autograd_path_also_respects_override():
+ d = 16
+ layer = _layernorm(d)
+ bridge = _make_bridge(
+ layer, _Cfg(uses_rms_norm=True), uses_rms_norm=False, use_native_layernorm_autograd=True
+ )
+ x = torch.randn(2, 5, d)
+ torch.testing.assert_close(bridge(x), layer(x), rtol=1e-5, atol=1e-5)
diff --git a/transformer_lens/model_bridge/generalized_components/normalization.py b/transformer_lens/model_bridge/generalized_components/normalization.py
index 1dca9b0ce..bed44dafc 100644
--- a/transformer_lens/model_bridge/generalized_components/normalization.py
+++ b/transformer_lens/model_bridge/generalized_components/normalization.py
@@ -23,6 +23,7 @@ def __init__(
config: Any,
submodules: Optional[Dict[str, GeneralizedComponent]] = {},
use_native_layernorm_autograd: bool = False,
+ uses_rms_norm: Optional[bool] = None,
):
"""Initialize the normalization bridge.
@@ -33,11 +34,32 @@ def __init__(
use_native_layernorm_autograd: If True, use HuggingFace's native LayerNorm
autograd for exact gradient matching. If False,
use custom implementation. Defaults to False.
+            uses_rms_norm: If True, treat the wrapped module as RMSNorm; if False,
+                           as LayerNorm. If None, defer to module introspection,
+                           then ``config.uses_rms_norm``. Defaults to None.
"""
super().__init__(name, config, submodules=submodules)
self.hook_normalized = HookPoint()
self.hook_scale = HookPoint()
self.use_native_layernorm_autograd = use_native_layernorm_autograd
+ self._uses_rms_norm_override = uses_rms_norm
+
+ @property
+ def uses_rms_norm(self) -> bool:
+ """Whether this bridge treats the wrapped module as RMSNorm.
+
+ Override > module introspection > config. Introspection guards against
+ a shared config (RMSNorm LM + LayerNorm vision tower) misclassifying
+ a real ``nn.LayerNorm``.
+ """
+ if self._uses_rms_norm_override is not None:
+ return self._uses_rms_norm_override
+ component = self.original_component
+ if component is not None:
+ if isinstance(component, torch.nn.LayerNorm):
+ return False
+ if "RMSNorm" in type(component).__name__:
+ return True
+        return bool(getattr(self.config, "uses_rms_norm", False))
+
def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Forward pass through the normalization bridge.
@@ -61,7 +83,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
elif hasattr(self.config, "layer_norm_folding") and self.config.layer_norm_folding:
result = self._hf_autograd_forward_with_hooks(hidden_states)
else:
- uses_rms_norm = getattr(self.config, "uses_rms_norm", False)
+ uses_rms_norm = self.uses_rms_norm
# Upcast to float32 for normalization precision (matches HT's RMSNorm behavior)
input_dtype = hidden_states.dtype
if input_dtype not in (torch.float32, torch.float64):
@@ -105,7 +127,7 @@ def _hf_autograd_forward_with_hooks(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
# Upcast to float32 for hook precision (matches HT's RMSNorm/LayerNorm behavior)
x_float = x.float() if x.dtype not in (torch.float32, torch.float64) else x
- if not getattr(self.config, "uses_rms_norm", False):
+ if not self.uses_rms_norm:
x_centered = x_float - x_float.mean(-1, keepdim=True)
else:
x_centered = x_float
diff --git a/transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py b/transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py
index 569a09643..f02667a0f 100644
--- a/transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py
+++ b/transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py
@@ -117,6 +117,8 @@ def __init__(
# original_component (a SiglipVisionModel) by setup_submodules().
# SiglipVisionModel wraps SiglipVisionTransformer as .vision_model,
# so all paths go through vision_model.*.
+ # post_layernorm is nn.LayerNorm; NormalizationBridge introspects the
+ # wrapped module so the RMSNorm-LM config (Gemma 3, LLaVA) doesn't leak.
default_submodules = {
"embeddings": GeneralizedComponent(name="vision_model.embeddings"),
"encoder_layers": SiglipVisionEncoderLayerBridge(name="vision_model.encoder.layers"),