diff --git a/demos/Gemma3_Multimodal.ipynb b/demos/Gemma3_Multimodal.ipynb
new file mode 100644
index 000000000..015232523
--- /dev/null
+++ b/demos/Gemma3_Multimodal.ipynb
@@ -0,0 +1,341 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Gemma3_Multimodal.ipynb\">\n",
+    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
+    "</a>"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Gemma 3 Multimodal Demo with TransformerBridge\n",
+    "\n",
+    "This notebook demonstrates how to use TransformerBridge with `Gemma3ForConditionalGeneration`,\n",
+    "the vision-language variant of Gemma 3. The model pairs a SigLIP vision encoder with the\n",
+    "Gemma 3 language model and is the same architecture used by MedGemma.\n",
+    "\n",
+    "We demonstrate:\n",
+    "1. Loading Gemma 3 (4B-it) through TransformerBridge\n",
+    "2. Multimodal generation from an image + text prompt\n",
+    "3. Capturing vision-language activations with `run_with_cache()`\n",
+    "\n",
+    "> **Gated model.** The `google/gemma-3-*` checkpoints are gated on Hugging Face. Accept\n",
+    "> the license at https://huggingface.co/google/gemma-3-4b-it and run `huggingface-cli login`\n",
+    "> (or set `HF_TOKEN`) before executing this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Detect Colab and install dependencies if needed\n",
+    "DEVELOPMENT_MODE = False\n",
+    "try:\n",
+    "    import google.colab\n",
+    "    IN_COLAB = True\n",
+    "    print(\"Running as a Colab notebook\")\n",
+    "    %pip install transformer_lens\n",
+    "    %pip install circuitsvis\n",
+    "except ImportError:\n",
+    "    IN_COLAB = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NBVAL_IGNORE_OUTPUT\n",
+    "import torch\n",
+    "from PIL import Image\n",
+    "import requests\n",
+    "from io import BytesIO\n",
+    "\n",
+    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "%matplotlib inline\n",
+    "\n",
+    "from transformer_lens.model_bridge import TransformerBridge\n",
+    "\n",
+    "try:\n",
+    "    import circuitsvis as cv\n",
+    "except ImportError:\n",
+    "    print('circuitsvis not installed, attention visualization will not work')\n",
+    "    cv = None"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load Gemma 3 through TransformerBridge\n",
+    "\n",
+    "TransformerBridge maps `Gemma3ForConditionalGeneration` to its multimodal adapter, which\n",
+    "wraps the SigLIP vision tower, the multimodal projector, and the Gemma 3 language model\n",
+    "into a single hooked model.\n",
+    "\n",
+    "We use **bfloat16** here \u2014 Gemma 3 is trained in bf16 and fp16 can produce unstable activations.\n",
+    "The 4B-it variant is the smallest multimodal Gemma 3 (the 270m and 1B checkpoints are text-only)."
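+    ,
+    "\n",
+    "If your accelerator lacks bf16 support, you can guard the dtype choice before loading.\n",
+    "A minimal sketch using only standard PyTorch checks (pass the result as `dtype=` below):\n",
+    "\n",
+    "```python\n",
+    "# Prefer bf16 where the CUDA device supports it; otherwise fall back to fp32.\n",
+    "# fp16 is deliberately avoided for Gemma 3, as noted above.\n",
+    "dtype = (\n",
+    "    torch.bfloat16\n",
+    "    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()\n",
+    "    else torch.float32\n",
+    ")\n",
+    "```"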
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NBVAL_IGNORE_OUTPUT\n",
+    "model = TransformerBridge.boot_transformers(\n",
+    "    \"google/gemma-3-4b-it\",\n",
+    "    device=device,\n",
+    "    dtype=torch.bfloat16,\n",
+    ")\n",
+    "\n",
+    "for param in model.parameters():\n",
+    "    param.requires_grad = False\n",
+    "\n",
+    "print(f\"Model loaded on {device}\")\n",
+    "print(f\"Multimodal: {getattr(model.cfg, 'is_multimodal', False)}\")\n",
+    "print(f\"Layers: {model.cfg.n_layers}, Heads: {model.cfg.n_heads}\")\n",
+    "print(f\"Vision tokens per image: {getattr(model.cfg, 'mm_tokens_per_image', None)}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load a test image\n",
+    "\n",
+    "We'll use a stop-sign photo from Australia to test the model's visual understanding."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NBVAL_IGNORE_OUTPUT\n",
+    "image_url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
+    "response = requests.get(image_url)\n",
+    "image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
+    "plt.imshow(image)\n",
+    "plt.axis('off')\n",
+    "plt.title('Test Image')\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Multimodal Generation\n",
+    "\n",
+    "Gemma 3 instruction-tuned models expect the chat template format with\n",
+    "`<start_of_turn>` / `<end_of_turn>` markers, and use `<start_of_image>` as the image\n",
+    "placeholder (rather than LLaVA's `<image>`). The processor expands `<start_of_image>`\n",
+    "into the appropriate number of vision tokens (256 by default for Gemma 3 4B).\n",
+    "\n",
+    "We call `prepare_multimodal_inputs()` to run the processor on text + image, then pass\n",
+    "`pixel_values` to `generate()`. The bridge's `generate()` keeps a KV cache\n",
+    "(`use_past_kv_cache=True` by default) for efficient autoregressive decoding while\n",
+    "preserving full hook access."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NBVAL_IGNORE_OUTPUT\n",
+    "question = \"What do you see in this photo?\"\n",
+    "prompt = (\n",
+    "    \"<start_of_turn>user\\n\"\n",
+    "    \"<start_of_image>\\n\"\n",
+    "    f\"{question}<end_of_turn>\\n\"\n",
+    "    \"<start_of_turn>model\\n\"\n",
+    ")\n",
+    "\n",
+    "# Prepare multimodal inputs (handles image processing + tokenization)\n",
+    "inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
+    "input_ids = inputs['input_ids']\n",
+    "pixel_values = inputs['pixel_values']\n",
+    "\n",
+    "# Pass any extra processor outputs (e.g. token_type_ids for Gemma 3)\n",
+    "extra_kwargs = {k: v for k, v in inputs.items()\n",
+    "                if k not in ('input_ids', 'pixel_values')}\n",
+    "\n",
+    "generated_text = model.generate(\n",
+    "    input_ids,\n",
+    "    pixel_values=pixel_values,\n",
+    "    max_new_tokens=80,\n",
+    "    do_sample=False,\n",
+    "    use_past_kv_cache=True,\n",
+    "    return_type=\"str\",\n",
+    "    **extra_kwargs,\n",
+    ")\n",
+    "\n",
+    "print('Generated text:', generated_text)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's try a second image to confirm the model adapts its description:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NBVAL_IGNORE_OUTPUT\n",
+    "image_url_2 = \"https://github.com/zazamrykh/PicFinder/blob/main/images/doge.jpg?raw=true\"\n",
+    "response = requests.get(image_url_2)\n",
+    "image_2 = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
+    "plt.imshow(image_2)\n",
+    "plt.axis('off')\n",
+    "plt.show()\n",
+    "\n",
+    "inputs = model.prepare_multimodal_inputs(text=prompt, images=image_2)\n",
+    "input_ids = inputs['input_ids']\n",
+    "pixel_values = inputs['pixel_values']\n",
+    "extra_kwargs = {k: v for k, v in inputs.items()\n",
+    "                if k not in ('input_ids', 'pixel_values')}\n",
+    "\n",
+    "generated_text = model.generate(\n",
+    "    input_ids,\n",
+    "    pixel_values=pixel_values,\n",
+    "    max_new_tokens=80,\n",
+    "    do_sample=False,\n",
+    "    use_past_kv_cache=True,\n",
+    "    return_type=\"str\",\n",
+    "    **extra_kwargs,\n",
+    ")\n",
+    "print('Generated text:', generated_text)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Inspecting Vision-Language Activations\n",
+    "\n",
+    "`run_with_cache()` accepts the same `pixel_values` argument and captures activations from\n",
+    "the vision encoder, the multimodal projector, and every transformer block in the language\n",
+    "model. This lets us inspect how the language tokens attend to image tokens during\n",
+    "multimodal processing."
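+    ,
+    "\n",
+    "Once the cell below has populated `cache`, one rough way to quantify text-to-image\n",
+    "attention is to mask the image positions and sum attention mass over them. A minimal\n",
+    "sketch (it assumes the processor's `token_type_ids` marks image tokens with 1, as\n",
+    "Gemma 3's processor does, and discovers the pattern hook key defensively):\n",
+    "\n",
+    "```python\n",
+    "# Per-head attention mass from the final text position onto image tokens.\n",
+    "image_mask = (inputs['token_type_ids'][0] == 1).cpu()\n",
+    "keys = [k for k in cache.keys() if 'blocks.16' in k and 'pattern' in k]\n",
+    "if keys:\n",
+    "    pat = cache[keys[0]].float().cpu()\n",
+    "    if pat.ndim == 4:\n",
+    "        pat = pat[0]  # drop batch -> [head, query, key]\n",
+    "    mass = pat[:, -1, :][:, image_mask].sum(-1)\n",
+    "    print('Attention mass on image tokens, per head:', mass)\n",
+    "```"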
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NBVAL_IGNORE_OUTPUT\n",
+    "inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
+    "extra_kwargs = {k: v for k, v in inputs.items()\n",
+    "                if k not in ('input_ids', 'pixel_values')}\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    logits, cache = model.run_with_cache(\n",
+    "        inputs['input_ids'],\n",
+    "        pixel_values=inputs['pixel_values'],\n",
+    "        **extra_kwargs,\n",
+    "    )\n",
+    "\n",
+    "print(f'Logits shape: {logits.shape}')\n",
+    "print(f'Cache entries: {len(cache)}')\n",
+    "vision_keys = [k for k in cache.keys() if 'vision' in k.lower()]\n",
+    "print(f'Vision-related cache entries: {len(vision_keys)}')\n",
+    "print(f'Sample vision keys: {vision_keys[:5]}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NBVAL_IGNORE_OUTPUT\n",
+    "if cv is not None:\n",
+    "    layer_to_visualize = 16\n",
+    "    tokens_to_show = 30\n",
+    "\n",
+    "    pattern_keys = [k for k in cache.keys() if f'blocks.{layer_to_visualize}' in k and 'pattern' in k]\n",
+    "    if pattern_keys:\n",
+    "        attention_pattern = cache[pattern_keys[0]]\n",
+    "        if attention_pattern.ndim == 4:\n",
+    "            attention_pattern = attention_pattern[0]\n",
+    "\n",
+    "        token_ids = inputs['input_ids'][0].cpu()\n",
+    "        str_tokens = model.tokenizer.convert_ids_to_tokens(token_ids)\n",
+    "\n",
+    "        print(f'Layer {layer_to_visualize} Head Attention Patterns (last {tokens_to_show} tokens):')\n",
+    "        display(cv.attention.attention_patterns(\n",
+    "            tokens=str_tokens[-tokens_to_show:],\n",
+    "            attention=attention_pattern[:, -tokens_to_show:, -tokens_to_show:].float().cpu(),\n",
+    "        ))\n",
+    "    else:\n",
+    "        print(f'No attention pattern found for layer {layer_to_visualize}')\n",
+    "        print(f'Available attention-related keys: {[k for k in cache.keys() if \"attn\" in k][:10]}')\n",
+    "else:\n",
+    "    print('circuitsvis not available \u2014 skipping visualization')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Summary\n",
+    "\n",
+    "TransformerBridge provides native multimodal support for `Gemma3ForConditionalGeneration`:\n",
+    "\n",
+    "- **`boot_transformers(\"google/gemma-3-4b-it\")`** loads the full vision + projector + language pipeline\n",
+    "- **`prepare_multimodal_inputs(text=..., images=...)`** handles image processing and tokenization\n",
+    "- **`generate(input_ids, pixel_values=...)`** runs multimodal generation with KV cache and hooks\n",
+    "- **`run_with_cache(input_ids, pixel_values=...)`** captures activations including SigLIP vision tokens\n",
+    "\n",
+    "A few Gemma 3 specifics worth noting:\n",
+    "\n",
+    "- Use the chat-template format (`<start_of_turn>user ... <end_of_turn><start_of_turn>model`) for instruction-tuned variants\n",
+    "- The image placeholder is `<start_of_image>`, not `<image>`\n",
+    "- Gemma 3 is trained in bf16 \u2014 prefer `torch.bfloat16` over `torch.float16`\n",
+    "- The same code path works for MedGemma (`google/medgemma-4b-it`, `google/medgemma-27b-it`) and any other `Gemma3ForConditionalGeneration` checkpoint (see the sketch below)"
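+    ,
+    "\n",
+    "For example, swapping in MedGemma should only require changing the checkpoint name.\n",
+    "A sketch (MedGemma is likewise gated on Hugging Face, and the weights are several GB):\n",
+    "\n",
+    "```python\n",
+    "# Reuse the exact pipeline above with a MedGemma checkpoint.\n",
+    "med_model = TransformerBridge.boot_transformers(\n",
+    "    'google/medgemma-4b-it', device=device, dtype=torch.bfloat16,\n",
+    ")\n",
+    "med_inputs = med_model.prepare_multimodal_inputs(text=prompt, images=image)\n",
+    "```"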
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "transformer-lens",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/tests/unit/model_bridge/generalized_components/test_normalization_bridge.py b/tests/unit/model_bridge/generalized_components/test_normalization_bridge.py
new file mode 100644
index 000000000..c93caf03f
--- /dev/null
+++ b/tests/unit/model_bridge/generalized_components/test_normalization_bridge.py
@@ -0,0 +1,118 @@
+"""RMSNorm-vs-LayerNorm dispatch in NormalizationBridge.
+
+Regression coverage for Gemma 3 multimodal: SigLIP's post_layernorm was wrapped
+under the LM's uses_rms_norm=True config, silently dropping mean-centering and
+bias and producing gibberish completions.
+"""
+
+import torch
+import torch.nn as nn
+
+from transformer_lens.model_bridge.generalized_components.normalization import (
+    NormalizationBridge,
+)
+
+
+class _Cfg:
+    def __init__(self, uses_rms_norm: bool, eps: float = 1e-5):
+        self.uses_rms_norm = uses_rms_norm
+        self.eps = eps
+
+
+def _make_bridge(layer: nn.Module, cfg: _Cfg, **kwargs) -> NormalizationBridge:
+    bridge = NormalizationBridge(name="ln", config=cfg, **kwargs)
+    bridge.set_original_component(layer)
+    return bridge
+
+
+def _layernorm(d: int) -> nn.LayerNorm:
+    layer = nn.LayerNorm(d, eps=1e-5)
+    nn.init.normal_(layer.weight, std=0.1)
+    nn.init.normal_(layer.bias, std=0.1)
+    layer.eval()
+    return layer
+
+
+def test_override_forces_layernorm_when_config_says_rmsnorm():
+    d = 16
+    layer = _layernorm(d)
+    bridge = _make_bridge(layer, _Cfg(uses_rms_norm=True), uses_rms_norm=False)
+    x = torch.randn(2, 5, d)
+    torch.testing.assert_close(bridge(x), layer(x), rtol=1e-5, atol=1e-5)
+
+
+def test_introspects_layernorm_when_config_says_rmsnorm():
+    d = 16
+    layer = _layernorm(d)
+    bridge = _make_bridge(layer, _Cfg(uses_rms_norm=True))
+    assert bridge.uses_rms_norm is False
+    x = torch.randn(2, 5, d)
+    torch.testing.assert_close(bridge(x), layer(x), rtol=1e-5, atol=1e-5)
+
+
+def test_introspects_rmsnorm_class_by_name():
+    class FakeRMSNorm(nn.Module):
+        def __init__(self, d, eps=1e-5):
+            super().__init__()
+            self.weight = nn.Parameter(torch.ones(d))
+            self.variance_epsilon = eps
+
+        def forward(self, x):
+            rms = (x.float().pow(2).mean(-1, keepdim=True) + self.variance_epsilon).sqrt()
+            return (x.float() / rms * self.weight).to(x.dtype)
+
+    rms = FakeRMSNorm(16)
+    nn.init.normal_(rms.weight, std=0.1)
+    rms.eval()
+    bridge = _make_bridge(rms, _Cfg(uses_rms_norm=False))
+    assert bridge.uses_rms_norm is True
+
+
+def test_falls_back_to_config_when_component_unset():
+    bridge = NormalizationBridge(name="ln", config=_Cfg(uses_rms_norm=True))
+    assert bridge.original_component is None
+    assert bridge.uses_rms_norm is True
+
+    bridge2 = NormalizationBridge(name="ln", config=_Cfg(uses_rms_norm=False))
+    assert bridge2.uses_rms_norm is False
+
+
+def test_override_takes_precedence_over_config():
+    layer = nn.LayerNorm(8)
+    assert _make_bridge(layer, _Cfg(uses_rms_norm=True), uses_rms_norm=False).uses_rms_norm is False
+    assert _make_bridge(layer, _Cfg(uses_rms_norm=False), uses_rms_norm=True).uses_rms_norm is True
+
+
+def test_siglip_post_layernorm_resolves_to_layernorm_under_rmsnorm_config():
+    from transformer_lens.model_bridge.generalized_components.siglip_vision_encoder import (
+        SiglipVisionEncoderBridge,
+    )
+
+    encoder = SiglipVisionEncoderBridge(name="vision_tower", config=_Cfg(uses_rms_norm=True))
+    post_ln = encoder.submodules["post_layernorm"]
+    post_ln.set_original_component(nn.LayerNorm(8))
+    assert post_ln.uses_rms_norm is False
+
+
+def test_clip_layernorms_resolve_to_layernorm_under_rmsnorm_config():
+    from transformer_lens.model_bridge.generalized_components.clip_vision_encoder import (
+        CLIPVisionEncoderBridge,
+    )
+
+    encoder = CLIPVisionEncoderBridge(name="vision_tower", config=_Cfg(uses_rms_norm=True))
+    pre = encoder.submodules["pre_layernorm"]
+    post = encoder.submodules["post_layernorm"]
+    pre.set_original_component(nn.LayerNorm(8))
+    post.set_original_component(nn.LayerNorm(8))
+    assert pre.uses_rms_norm is False
+    assert post.uses_rms_norm is False
+
+
+def test_native_autograd_path_also_respects_override():
+    d = 16
+    layer = _layernorm(d)
+    bridge = _make_bridge(
+        layer, _Cfg(uses_rms_norm=True), uses_rms_norm=False, use_native_layernorm_autograd=True
+    )
+    x = torch.randn(2, 5, d)
+    torch.testing.assert_close(bridge(x), layer(x), rtol=1e-5, atol=1e-5)
diff --git a/transformer_lens/model_bridge/generalized_components/normalization.py b/transformer_lens/model_bridge/generalized_components/normalization.py
index 1dca9b0ce..bed44dafc 100644
--- a/transformer_lens/model_bridge/generalized_components/normalization.py
+++ b/transformer_lens/model_bridge/generalized_components/normalization.py
@@ -23,6 +23,7 @@ def __init__(
         config: Any,
         submodules: Optional[Dict[str, GeneralizedComponent]] = {},
         use_native_layernorm_autograd: bool = False,
+        uses_rms_norm: Optional[bool] = None,
     ):
         """Initialize the normalization bridge.

@@ -33,11 +34,32 @@ def __init__(
             use_native_layernorm_autograd: If True, use HuggingFace's native LayerNorm
                 autograd for exact gradient matching. If False, use custom implementation.
                 Defaults to False.
+            uses_rms_norm: Force RMSNorm vs LayerNorm; None defers to introspection
+                then ``config.uses_rms_norm``.
         """
         super().__init__(name, config, submodules=submodules)
         self.hook_normalized = HookPoint()
         self.hook_scale = HookPoint()
         self.use_native_layernorm_autograd = use_native_layernorm_autograd
+        self._uses_rms_norm_override = uses_rms_norm
+
+    @property
+    def uses_rms_norm(self) -> bool:
+        """Whether this bridge treats the wrapped module as RMSNorm.
+
+        Override > module introspection > config. Introspection guards against
+        a shared config (RMSNorm LM + LayerNorm vision tower) misclassifying
+        a real ``nn.LayerNorm``.
+        """
+        if self._uses_rms_norm_override is not None:
+            return self._uses_rms_norm_override
+        component = self.original_component
+        if component is not None:
+            if isinstance(component, torch.nn.LayerNorm):
+                return False
+            if "RMSNorm" in type(component).__name__:
+                return True
+        return bool(getattr(self.config, "uses_rms_norm", False))

     def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Forward pass through the normalization bridge.
@@ -61,7 +83,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         elif hasattr(self.config, "layer_norm_folding") and self.config.layer_norm_folding:
             result = self._hf_autograd_forward_with_hooks(hidden_states)
         else:
-            uses_rms_norm = getattr(self.config, "uses_rms_norm", False)
+            uses_rms_norm = self.uses_rms_norm
             # Upcast to float32 for normalization precision (matches HT's RMSNorm behavior)
             input_dtype = hidden_states.dtype
             if input_dtype not in (torch.float32, torch.float64):
@@ -105,7 +127,7 @@ def _hf_autograd_forward_with_hooks(self, x: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
             # Upcast to float32 for hook precision (matches HT's RMSNorm/LayerNorm behavior)
             x_float = x.float() if x.dtype not in (torch.float32, torch.float64) else x
-            if not getattr(self.config, "uses_rms_norm", False):
+            if not self.uses_rms_norm:
                 x_centered = x_float - x_float.mean(-1, keepdim=True)
             else:
                 x_centered = x_float
diff --git a/transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py b/transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py
index 569a09643..f02667a0f 100644
--- a/transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py
+++ b/transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py
@@ -117,6 +117,8 @@ def __init__(
         # original_component (a SiglipVisionModel) by setup_submodules().
         # SiglipVisionModel wraps SiglipVisionTransformer as .vision_model,
         # so all paths go through vision_model.*.
+        # post_layernorm is nn.LayerNorm; NormalizationBridge introspects the
+        # wrapped module so the RMSNorm-LM config (Gemma 3, LLaVA) doesn't leak.
         default_submodules = {
             "embeddings": GeneralizedComponent(name="vision_model.embeddings"),
             "encoder_layers": SiglipVisionEncoderLayerBridge(name="vision_model.encoder.layers"),