diff --git a/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py b/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py new file mode 100644 index 000000000..180bd9fd0 --- /dev/null +++ b/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py @@ -0,0 +1,125 @@ +"""Asserts ``TransformerBridge`` reproduces ``AutoModelForCausalLM`` eager-attention logits. + +Issue #385 reported drift between bridge and HF for rotary models like Pythia. The drift +was an attention-implementation mismatch — bridge always uses eager, default HF loads use +SDPA, which reorders ops in a fused kernel. Bridge vs HF *eager* matches to fp32-noise. +""" + +from typing import Callable + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from transformer_lens.model_bridge import TransformerBridge + +MODEL_NAME = "EleutherAI/pythia-70m" + +# Op-reorder noise floor for fp32 transformer forward passes. We currently +# measure 0.0 on this model, but allow a small epsilon so harmless refactors +# (intermediate allocations, equivalent op reorderings) don't break the test. +FP32_NOISE_TOL = 1e-5 + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def bridge(): + return TransformerBridge.boot_transformers(MODEL_NAME, device="cpu", dtype=torch.float32) + + +@pytest.fixture(scope="module") +def hf_eager(): + """HF model loaded independently of the bridge's wrapped instance.""" + return AutoModelForCausalLM.from_pretrained( + MODEL_NAME, torch_dtype=torch.float32, attn_implementation="eager" + ).eval() + + +@pytest.fixture +def tokenize(tokenizer) -> Callable[[str], torch.Tensor]: + def _tok(prompt: str) -> torch.Tensor: + return tokenizer(prompt, return_tensors="pt").input_ids + + return _tok + + +@pytest.mark.parametrize("prompt", ["Hello, world!", "The quick brown fox jumps"]) +def test_bridge_logits_match_hf_eager(bridge, hf_eager, tokenize, prompt): + tokens = tokenize(prompt) + with torch.inference_mode(): + bridge_logits = bridge(tokens) + hf_logits = hf_eager(tokens).logits + max_diff = (bridge_logits - hf_logits).abs().max().item() + assert max_diff < FP32_NOISE_TOL, ( + f"{MODEL_NAME!r} bridge vs HF eager drift={max_diff:.2e} on {prompt!r} " + f"exceeds fp32-noise tolerance {FP32_NOISE_TOL:.0e} — bridge's " + f"_reconstruct_attention may have regressed (see issue #385)." + ) + + +def test_bridge_residual_stream_matches_hf_eager(bridge, hf_eager, tokenize): + """Per-layer parity catches compensating errors that wash out at the final logits.""" + tokens = tokenize("Hello, world!") + n_layers = len(hf_eager.gpt_neox.layers) + + hf_layer_out: dict[int, torch.Tensor] = {} + + def _make_hf_hook(idx): + def _h(_m, _i, o): + hf_layer_out[idx] = (o[0] if isinstance(o, tuple) else o).detach() + + return _h + + handles = [ + layer.register_forward_hook(_make_hf_hook(i)) + for i, layer in enumerate(hf_eager.gpt_neox.layers) + ] + try: + with torch.inference_mode(): + hf_eager(tokens) + finally: + for h in handles: + h.remove() + + bridge_layer_out: dict[int, torch.Tensor] = {} + fwd_hooks = [ + ( + f"blocks.{i}.hook_resid_post", + lambda v, hook, idx=i: bridge_layer_out.__setitem__(idx, v.detach()), + ) + for i in range(n_layers) + ] + with torch.inference_mode(): + bridge.run_with_hooks(tokens, fwd_hooks=fwd_hooks) + + for i in range(n_layers): + d = (hf_layer_out[i] - bridge_layer_out[i]).abs().max().item() + assert d < FP32_NOISE_TOL, ( + f"layer {i} residual drift={d:.2e} exceeds fp32-noise tolerance " + f"{FP32_NOISE_TOL:.0e} — bridge layer output diverges from HF eager." + ) + + +def test_bridge_attention_reconstruction_actually_runs(bridge, tokenize): + """Guard against tautology: prove bridge's custom attention path executes. + + If a future refactor made the bridge delegate to HF directly, the previous + parity tests would pass trivially. This one fails fast in that case by + asserting bridge-specific hooks fire during forward. + """ + tokens = tokenize("Hello, world!") + attn_scores_fired: list[bool] = [] + bridge.run_with_hooks( + tokens, + fwd_hooks=[ + ("blocks.0.attn.hook_attn_scores", lambda v, hook: attn_scores_fired.append(True)), + ], + ) + assert attn_scores_fired, ( + "blocks.0.attn.hook_attn_scores did not fire — bridge no longer runs its " + "own attention reconstruction, making the parity tests tautological." + ) diff --git a/tests/unit/factored_matrix/test_properties.py b/tests/unit/factored_matrix/test_properties.py index 091db15e8..d8f13d450 100644 --- a/tests/unit/factored_matrix/test_properties.py +++ b/tests/unit/factored_matrix/test_properties.py @@ -78,19 +78,52 @@ def test_transpose_property(self, factored_matrices): def test_svd_property(self, factored_matrices): for factored_matrix in factored_matrices: - U, S, Vh = factored_matrix.svd() - assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.T, atol=1e-5) - # test that U and Vh are unitary + U, S, V = factored_matrix.svd() + assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ V.T, atol=1e-5) + # test that U and V are unitary assert torch.allclose(U.T @ U, torch.eye(U.shape[-1]), atol=1e-5) - assert torch.allclose(Vh.T @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5) + assert torch.allclose(V.T @ V, torch.eye(V.shape[-1]), atol=1e-5) def test_svd_property_leading_ones(self, factored_matrices_leading_ones): for factored_matrix in factored_matrices_leading_ones: - U, S, Vh = factored_matrix.svd() - assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.mT, atol=1e-5) - # test that U and Vh are unitary + U, S, V = factored_matrix.svd() + assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ V.mT, atol=1e-5) + # test that U and V are unitary assert torch.allclose(U.mT @ U, torch.eye(U.shape[-1]), atol=1e-5) - assert torch.allclose(Vh.mT @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5) + assert torch.allclose(V.mT @ V, torch.eye(V.shape[-1]), atol=1e-5) + + def test_V_and_Vh_alias_match(self, factored_matrices): + import warnings + + for factored_matrix in factored_matrices: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + vh_value = factored_matrix.Vh + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert torch.equal(vh_value, factored_matrix.V) + + def test_svd_caches_per_instance(self): + """svd() should cache its result on the instance — repeated calls return the same tensors.""" + m = FactoredMatrix(randn(4, 3), randn(3, 4)) + first_U, first_S, first_V = m.svd() + second_U, second_S, second_V = m.svd() + assert first_U is second_U + assert first_S is second_S + assert first_V is second_V + + def test_svd_does_not_prevent_gc(self): + """svd's cache must not hold a strong reference that prevents the instance from being GC'd""" + import gc + import weakref + + m = FactoredMatrix(randn(4, 3), randn(3, 4)) + _ = m.svd() # populate the cache + ref = weakref.ref(m) + del m + gc.collect() + assert ( + ref() is None + ), "FactoredMatrix instance survived deletion — svd cache is leaking references." def test_eigenvalues_property(self, factored_matrices): for factored_matrix in factored_matrices: @@ -141,16 +174,6 @@ def test_collapse_l(self, factored_matrices): expected = factored_matrix.S[..., :, None] * utils.transpose(factored_matrix.V) assert torch.allclose(result, expected) - def test_V_and_Vh_alias_match(self, factored_matrices): - import warnings - - for factored_matrix in factored_matrices: - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - vh_value = factored_matrix.Vh - assert any(issubclass(w.category, DeprecationWarning) for w in caught) - assert torch.equal(vh_value, factored_matrix.V) - def test_collapse_r(self, factored_matrices): for factored_matrix in factored_matrices: result = factored_matrix.collapse_r() diff --git a/tests/unit/model_bridge/test_checkpoint_revision.py b/tests/unit/model_bridge/test_checkpoint_revision.py new file mode 100644 index 000000000..a3a4919dc --- /dev/null +++ b/tests/unit/model_bridge/test_checkpoint_revision.py @@ -0,0 +1,175 @@ +"""Unit tests for the bridge revision/checkpoint API (issue #453).""" + +from unittest.mock import patch + +import pytest + +from transformer_lens.model_bridge.sources.transformers import ( + _CHECKPOINT_REVISION_FORMATS, + _resolve_checkpoint_to_revision, +) + + +class TestResolveCheckpointToRevision: + def test_pythia_index_resolves_to_step_revision(self): + labels = [0, 1000, 3000, 10000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + revision = _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=2, checkpoint_value=None + ) + assert revision == "step3000" + + def test_pythia_value_resolves_to_step_revision(self): + labels = [0, 1000, 3000, 10000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + revision = _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=None, checkpoint_value=10000 + ) + assert revision == "step10000" + + def test_stanford_crfm_uses_checkpoint_prefix(self): + labels = [100, 200, 400] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + revision = _resolve_checkpoint_to_revision( + "stanford-crfm/alias-gpt2-small-x21", checkpoint_index=1, checkpoint_value=None + ) + assert revision == "checkpoint-200" + + def test_unknown_family_raises(self): + with pytest.raises(ValueError, match="known checkpoint revision convention"): + _resolve_checkpoint_to_revision("gpt2", checkpoint_index=0, checkpoint_value=None) + + def test_index_out_of_range_raises(self): + labels = [0, 1000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + with pytest.raises(ValueError, match="out of range"): + _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=5, checkpoint_value=None + ) + + def test_unknown_value_raises(self): + labels = [0, 1000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + with pytest.raises(ValueError, match="not in available checkpoints"): + _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=None, checkpoint_value=99999 + ) + + def test_neither_provided_raises(self): + with pytest.raises(ValueError, match="Must specify"): + _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=None, checkpoint_value=None + ) + + def test_known_families_registered(self): + assert "EleutherAI/pythia" in _CHECKPOINT_REVISION_FORMATS + assert "stanford-crfm" in _CHECKPOINT_REVISION_FORMATS + + +class _AbortBoot(Exception): + """Raised by the model-load patch to short-circuit ``boot()`` before any real load.""" + + +class TestBootRevisionPlumbing: + """Verify that ``revision`` and ``checkpoint_*`` reach HF's from_pretrained calls. + + Uses pythia-70m's real cached config (avoids MagicMock fragility through the + adapter/config-mapping path) and aborts at the model-load step. + """ + + def _patched_boot(self, **boot_kwargs): + from transformer_lens.model_bridge.sources import transformers as bridge_src + + captured: dict = {} + real_autoconfig = bridge_src.AutoConfig.from_pretrained + + def capture_autoconfig(name, **kwargs): + captured["autoconfig_kwargs"] = dict(kwargs) + # Strip the (possibly fake) revision so the real call hits the CI cache. + kwargs.pop("revision", None) + return real_autoconfig(name, **kwargs) + + def capture_model_load(*args, **kwargs): + captured["model_kwargs"] = kwargs + raise _AbortBoot() + + with patch.object( + bridge_src.AutoConfig, "from_pretrained", side_effect=capture_autoconfig + ), patch( + "transformers.AutoModelForCausalLM.from_pretrained", + side_effect=capture_model_load, + ): + with pytest.raises(_AbortBoot): + bridge_src.boot(model_name="EleutherAI/pythia-70m", device="cpu", **boot_kwargs) + + return captured + + def test_revision_forwarded_to_autoconfig(self): + captured = self._patched_boot(revision="step3000") + assert captured["autoconfig_kwargs"].get("revision") == "step3000" + + def test_revision_forwarded_to_model_load(self): + captured = self._patched_boot(revision="step3000") + assert captured.get("model_kwargs", {}).get("revision") == "step3000" + + def test_checkpoint_index_resolves_to_revision(self): + labels = [0, 1000, 3000, 10000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + captured = self._patched_boot(checkpoint_index=2) + assert captured["autoconfig_kwargs"].get("revision") == "step3000" + assert captured.get("model_kwargs", {}).get("revision") == "step3000" + + def test_conflicting_revision_and_checkpoint_raises(self): + from transformer_lens.model_bridge.sources import transformers as bridge_src + + with pytest.raises(ValueError, match="not both"): + bridge_src.boot( + model_name="EleutherAI/pythia-70m", + revision="step1000", + checkpoint_index=2, + ) + + def test_default_revision_is_none(self): + """With no revision/checkpoint args, revision is not added to model_kwargs.""" + captured = self._patched_boot() + assert captured["autoconfig_kwargs"].get("revision") is None + assert "revision" not in captured.get("model_kwargs", {}) + + +class TestHookedTransformerCheckpointLabelAlias: + def test_checkpoint_label_routes_to_checkpoint_value(self): + from transformer_lens import HookedTransformer + + with patch("transformer_lens.loading.get_pretrained_model_config") as mock_get_cfg: + mock_get_cfg.side_effect = RuntimeError("stop after config call") + with pytest.raises(RuntimeError, match="stop after config call"): + HookedTransformer.from_pretrained("EleutherAI/pythia-70m", checkpoint_label=3000) + + _, kwargs = mock_get_cfg.call_args + assert kwargs["checkpoint_value"] == 3000 + + def test_checkpoint_label_and_value_together_raises(self): + from transformer_lens import HookedTransformer + + with pytest.raises(ValueError, match="aliases"): + HookedTransformer.from_pretrained( + "EleutherAI/pythia-70m", checkpoint_label=3000, checkpoint_value=1000 + ) diff --git a/tests/unit/test_hook_introspection.py b/tests/unit/test_hook_introspection.py new file mode 100644 index 000000000..620195094 --- /dev/null +++ b/tests/unit/test_hook_introspection.py @@ -0,0 +1,156 @@ +"""Tests for the hook-introspection API added for issue #297.""" + +from unittest import mock + +from transformer_lens.hook_points import ( + HookedRootModule, + HookIntrospectionMixin, + HookPoint, +) + + +class _ToyModel(HookedRootModule): + """Minimal HookedRootModule with two hook points for testing.""" + + def __init__(self): + super().__init__() + self.hook_a = HookPoint() + self.hook_b = HookPoint() + self.setup() + + +def _my_named_hook(activation, hook): + return activation + + +def _other_hook(activation, hook): + return activation + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_lens_handle_stores_user_hook(mock_handle): + mock_handle.return_value.id = 0 + hp = HookPoint() + hp.add_hook(_my_named_hook, dir="fwd") + assert hp.fwd_hooks[0].user_hook is _my_named_hook + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_hookpoint_repr_includes_hook_count(mock_handle): + mock_handle.return_value.id = 0 + hp = HookPoint() + hp.name = "blocks.0.hook_resid_post" + assert "blocks.0.hook_resid_post" in repr(hp) + hp.add_hook(_my_named_hook, dir="fwd") + hp.add_hook(_other_hook, dir="fwd") + rep = repr(hp) + assert "2 fwd" in rep + assert "bwd" not in rep + + +def test_hookpoint_repr_with_no_name_and_no_hooks(): + assert repr(HookPoint()) == "HookPoint()" + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_empty_model_returns_empty_dict(mock_handle): + model = _ToyModel() + assert model.list_hooks() == {} + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_returns_user_callable(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + result = model.list_hooks() + assert set(result.keys()) == {"hook_a"} + handles = result["hook_a"] + assert len(handles) == 1 + assert handles[0].user_hook is _my_named_hook + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_name_filter_string(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + model.hook_b.add_hook(_other_hook, dir="fwd") + assert set(model.list_hooks(name_filter="hook_a").keys()) == {"hook_a"} + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_name_filter_list(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + model.hook_b.add_hook(_other_hook, dir="fwd") + assert set(model.list_hooks(name_filter=["hook_a", "hook_b"]).keys()) == {"hook_a", "hook_b"} + assert set(model.list_hooks(name_filter=["hook_b"]).keys()) == {"hook_b"} + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_name_filter_callable(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + model.hook_b.add_hook(_other_hook, dir="fwd") + result = model.list_hooks(name_filter=lambda n: n.endswith("a")) + assert set(result.keys()) == {"hook_a"} + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_direction_filter(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + model.hook_a.add_hook(_other_hook, dir="bwd") + assert len(model.list_hooks(dir="fwd")["hook_a"]) == 1 + assert len(model.list_hooks(dir="bwd")["hook_a"]) == 1 + assert len(model.list_hooks(dir="both")["hook_a"]) == 2 + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_excludes_permanent_when_requested(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd", is_permanent=True) + model.hook_a.add_hook(_other_hook, dir="fwd", is_permanent=False) + assert len(model.list_hooks(including_permanent=True)["hook_a"]) == 2 + handles = model.list_hooks(including_permanent=False)["hook_a"] + assert len(handles) == 1 + assert handles[0].user_hook is _other_hook + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_mixin_works_on_class_with_hook_dict_attribute(mock_handle): + """Pin the duck-typed contract: mixin reads ``hook_dict`` off any class that exposes it.""" + mock_handle.return_value.id = 0 + + class Bag(HookIntrospectionMixin): + def __init__(self): + hp = HookPoint() + hp.add_hook(_my_named_hook, dir="fwd") + self.hook_dict = {"only_hook": hp} + + result = Bag().list_hooks() + assert set(result.keys()) == {"only_hook"} + assert result["only_hook"][0].user_hook is _my_named_hook + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_mixin_works_on_class_with_hook_dict_property(mock_handle): + """``getattr`` indirection must accept a ``@property`` provider too (bridge case).""" + mock_handle.return_value.id = 0 + + class PropertyBag(HookIntrospectionMixin): + def __init__(self): + self._hooks = {"only_hook": HookPoint()} + self._hooks["only_hook"].add_hook(_my_named_hook, dir="fwd") + + @property + def hook_dict(self): + return self._hooks + + result = PropertyBag().list_hooks() + assert result["only_hook"][0].user_hook is _my_named_hook diff --git a/tests/unit/test_hook_points.py b/tests/unit/test_hook_points.py index 4e7828450..df85849f0 100644 --- a/tests/unit/test_hook_points.py +++ b/tests/unit/test_hook_points.py @@ -60,10 +60,11 @@ def hook2(activation, hook): # Make LensHandle constructor return a simple container capturing the pt_handle ('hook') class _LensHandleBox: - def __init__(self, handle, is_permanent, context_level): + def __init__(self, handle, is_permanent, context_level, user_hook=None): self.hook = handle self.is_permanent = is_permanent self.context_level = context_level + self.user_hook = user_hook mock_lens_handle.side_effect = _LensHandleBox diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index 674dabf43..712394f55 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -6,7 +6,7 @@ from __future__ import annotations -from functools import lru_cache +from functools import cached_property from typing import Any, List, Protocol, Tuple, Union, cast, overload, runtime_checkable import torch @@ -214,17 +214,15 @@ def BA(self) -> Float[torch.Tensor, "*leading_dims rdim ldim"]: def T(self) -> FactoredMatrix: return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1)) - @lru_cache(maxsize=None) - def svd( + @cached_property + def _svd_cached( self, ) -> Tuple[ Float[torch.Tensor, "*leading_dims ldim mdim"], Float[torch.Tensor, "*leading_dims mdim"], Float[torch.Tensor, "*leading_dims rdim mdim"], ]: - """Singular Value Decomposition: returns ``(U, S, V)`` such that ``U @ S.diag() @ V.transpose(-2, -1) == M``.""" - # Transpose Vh back to V — the long-standing return convention; downstream - # callers transposed the old `.Vh` result, so preserving V keeps them working. + # Cache on the instance (frees with it) rather than class-level — fixes the lru_cache leak. Ua, Sa, Vha = torch.linalg.svd(self.A, full_matrices=False) Ub, Sb, Vhb = torch.linalg.svd(self.B, full_matrices=False) Va = tensor_utils.transpose(Vha) @@ -232,10 +230,17 @@ def svd( middle = Sa[..., :, None] * tensor_utils.transpose(Va) @ Ub * Sb[..., None, :] Um, Sm, Vhm = torch.linalg.svd(middle, full_matrices=False) Vm = tensor_utils.transpose(Vhm) - U = Ua @ Um - V = Vb @ Vm - S = Sm - return U, S, V + return Ua @ Um, Sm, Vb @ Vm + + def svd( + self, + ) -> Tuple[ + Float[torch.Tensor, "*leading_dims ldim mdim"], + Float[torch.Tensor, "*leading_dims mdim"], + Float[torch.Tensor, "*leading_dims rdim mdim"], + ]: + """Singular Value Decomposition: returns ``(U, S, V)`` such that ``U @ S.diag() @ V.transpose(-2, -1) == M``.""" + return self._svd_cached @property def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]: diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 9a7db35ae..103e532bf 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1157,6 +1157,7 @@ def from_pretrained( refactor_factored_attn_matrices: bool = False, checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, + checkpoint_label: Optional[int] = None, hf_model: Optional[PreTrainedModel] = None, device: Optional[Union[str, torch.device]] = None, n_devices: int = 1, @@ -1254,6 +1255,8 @@ def from_pretrained( labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be ignored. + checkpoint_label: Alias for ``checkpoint_value`` kept for backwards compatibility with + older docs and downstream code. Cannot be combined with ``checkpoint_value``. hf_model: If you have already loaded in the HuggingFace model, you can pass it in here rather than needing to recreate the object. Defaults to None. @@ -1311,6 +1314,13 @@ def from_pretrained( 3. Global default ("right") first_n_layers: If specified, only load the first n layers of the model. """ + if checkpoint_value is not None and checkpoint_label is not None: + raise ValueError( + "Specify checkpoint_value or checkpoint_label, not both — they are aliases." + ) + elif checkpoint_label is not None: + checkpoint_value = checkpoint_label + if model_name.lower().startswith("t5"): raise RuntimeError( "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index a942ec39d..7ec1ed7c1 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -48,6 +48,9 @@ class LensHandle: context_level: Optional[int] = None """Context level associated with the hooks context manager for the given hook.""" + user_hook: Optional[Callable] = None + """The original hook callable, before ``add_hook`` wraps it.""" + # Define type aliases NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]] @@ -167,6 +170,14 @@ def __init__(self): # This scales the SUM of gradients, not element-wise (to avoid PyTorch bugs) self.backward_scale: float = 1.0 + def __repr__(self) -> str: + bits = [f"name={self.name!r}"] if self.name is not None else [] + if self.fwd_hooks: + bits.append(f"{len(self.fwd_hooks)} fwd") + if self.bwd_hooks: + bits.append(f"{len(self.bwd_hooks)} bwd") + return f"HookPoint({', '.join(bits)})" if bits else "HookPoint()" + def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None: self.add_hook(hook, dir=dir, is_permanent=True) @@ -273,7 +284,7 @@ def _bwd_hook_wrapper( else: raise ValueError(f"Invalid direction {dir}") - handle = LensHandle(pt_handle, is_permanent, level) + handle = LensHandle(pt_handle, is_permanent, level, user_hook=hook) if prepend: # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this... @@ -376,7 +387,55 @@ def layer(self): # %% -class HookedRootModule(nn.Module): +class HookIntrospectionMixin: + """``list_hooks()`` mixin for any class exposing a ``hook_dict``. + + Accessed via ``getattr`` so subclasses can provide ``hook_dict`` as either + an instance attribute (``HookedRootModule``) or a ``@property`` (``TransformerBridge``). + """ + + def list_hooks( + self, + name_filter: NamesFilter = None, + dir: Literal["fwd", "bwd", "both"] = "both", + including_permanent: bool = True, + ) -> dict[str, list[LensHandle]]: + """Return attached hooks grouped by HookPoint name; empty HookPoints are omitted. + + Args: + name_filter: A hook name, list of names, or predicate. ``None`` matches all. + dir: Restrict to forward, backward, or both directions. + including_permanent: If False, drop permanent hooks from the result. + """ + if name_filter is None: + matches: Callable[[str], bool] = lambda _: True + elif callable(name_filter): + matches = name_filter + elif isinstance(name_filter, str): + target = name_filter + matches = lambda n: n == target + else: + allowed = set(name_filter) + matches = lambda n: n in allowed + + out: dict[str, list[LensHandle]] = {} + hook_dict: dict[str, HookPoint] = getattr(self, "hook_dict") + for name, hp in hook_dict.items(): + if not matches(name): + continue + handles: list[LensHandle] = [] + if dir in ("fwd", "both"): + handles.extend(hp.fwd_hooks) + if dir in ("bwd", "both"): + handles.extend(hp.bwd_hooks) + if not including_permanent: + handles = [h for h in handles if not h.is_permanent] + if handles: + out[name] = handles + return out + + +class HookedRootModule(HookIntrospectionMixin, nn.Module): """A class building on nn.Module to interface nicely with HookPoints. Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index cb3895368..91e685204 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -33,7 +33,7 @@ from transformer_lens import utilities as utils from transformer_lens.ActivationCache import ActivationCache from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.hook_points import HookPoint +from transformer_lens.hook_points import HookIntrospectionMixin, HookPoint from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.component_setup import set_original_components from transformer_lens.model_bridge.composition_scores import CompositionScores @@ -94,7 +94,7 @@ def build_alias_to_canonical_map(hook_dict, prefix=""): return aliases -class TransformerBridge(nn.Module): +class TransformerBridge(HookIntrospectionMixin, nn.Module): """Bridge between HuggingFace and TransformerLens models. This class provides a standardized interface to access components of a transformer @@ -196,6 +196,9 @@ def boot_transformers( n_devices: Optional[int] = None, max_memory: Optional[Dict[Union[str, int], str]] = None, n_ctx: Optional[int] = None, + revision: Optional[str] = None, + checkpoint_index: Optional[int] = None, + checkpoint_value: Optional[int] = None, ) -> "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). @@ -204,6 +207,11 @@ def boot_transformers( Call ``enable_compatibility_mode()`` on the result for HookedTransformer- equivalent numerics. Generation, argmax, and CE loss are unaffected. + Attention implementation is forced to ``"eager"`` so hooks can capture scores + and patterns. For an apples-to-apples HF comparison, load the HF model with + ``attn_implementation="eager"`` too; comparing against the default ``"sdpa"`` + shows ~1e-3 fp32 drift from kernel-level op reordering, not a bridge bug. + Args: model_name: The name of the model to load. hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. @@ -231,6 +239,14 @@ def boot_transformers( n_ctx: Optional context length override. Writes to the appropriate HF config field for this model automatically (callers don't need to know the field name). Warns if larger than the model's default context length. + revision: Optional HF revision (branch, tag, or commit). Forwarded to the underlying + ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained`` calls. + Mutually exclusive with ``checkpoint_index`` / ``checkpoint_value``. + checkpoint_index: Index into the available training checkpoints for the model family + (currently ``EleutherAI/pythia*`` and ``stanford-crfm/*``). Resolved to a revision + string via known per-family naming conventions. + checkpoint_value: Training step or token count of the desired checkpoint. Alternative + to ``checkpoint_index``; must match an entry in the family's checkpoint label list. Returns: The bridge to the loaded model. @@ -251,6 +267,9 @@ def boot_transformers( n_devices=n_devices, max_memory=max_memory, n_ctx=n_ctx, + revision=revision, + checkpoint_index=checkpoint_index, + checkpoint_value=checkpoint_value, ) @property diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index b7c4656f4..be2659e89 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -277,6 +277,55 @@ def get_hf_model_class_for_architecture(architecture: str): return AutoModelForCausalLM +# Known training-checkpoint revision conventions on HF. +_CHECKPOINT_REVISION_FORMATS: dict[str, str] = { + "EleutherAI/pythia": "step{value}", + "stanford-crfm": "checkpoint-{value}", +} + + +def _resolve_checkpoint_to_revision( + model_name: str, + checkpoint_index: int | None, + checkpoint_value: int | None, +) -> str: + """Convert a checkpoint index/value into an HF revision string, validated against ``get_checkpoint_labels``.""" + if checkpoint_index is None and checkpoint_value is None: + raise ValueError("Must specify either checkpoint_index or checkpoint_value.") + + format_str: str | None = None + for prefix, fmt in _CHECKPOINT_REVISION_FORMATS.items(): + if model_name.startswith(prefix): + format_str = fmt + break + if format_str is None: + raise ValueError( + f"Model {model_name!r} does not have a known checkpoint revision convention. " + f"Pass revision= directly if your model uses HF revisions. Known checkpoint " + f"families: {list(_CHECKPOINT_REVISION_FORMATS.keys())}." + ) + + from transformer_lens.loading_from_pretrained import get_checkpoint_labels + + labels, _ = get_checkpoint_labels(model_name) + if checkpoint_value is not None: + if checkpoint_value not in labels: + raise ValueError( + f"checkpoint_value={checkpoint_value} not in available checkpoints for " + f"{model_name!r}. {len(labels)} labels available, " + f"first/last: {labels[0]}..{labels[-1]}." + ) + else: + assert checkpoint_index is not None # narrowed by initial guard + if not 0 <= checkpoint_index < len(labels): + raise ValueError( + f"checkpoint_index={checkpoint_index} out of range [0, {len(labels)}) " + f"for {model_name!r}." + ) + checkpoint_value = labels[checkpoint_index] + return format_str.format(value=checkpoint_value) + + def boot( model_name: str, hf_config_overrides: dict | None = None, @@ -288,6 +337,9 @@ def boot( model_class: Any | None = None, hf_model: Any | None = None, n_ctx: int | None = None, + revision: str | None = None, + checkpoint_index: int | None = None, + checkpoint_value: int | None = None, # Experimental – Have not been fully tested on multi-gpu devices # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues device_map: str | dict[str, str | int] | None = None, @@ -321,6 +373,15 @@ def boot( uses (n_positions / max_position_embeddings / etc.), so callers don't need to know the field name. If larger than the model's default, a warning is emitted — quality may degrade past the trained length for rotary models. + revision: Optional HF revision string (branch, tag, or commit). Forwarded to + ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained``. + Mutually exclusive with ``checkpoint_index`` and ``checkpoint_value``. + checkpoint_index: Index into the available training checkpoints for the model family. + Convenience over ``revision`` for checkpointed models like EleutherAI/pythia* and + stanford-crfm/*. Resolved to a revision string via the known per-family naming + conventions (``step{value}`` for Pythia, ``checkpoint-{value}`` for stanford-crfm). + checkpoint_value: Training step or token count of the desired checkpoint. Alternative to + ``checkpoint_index``; must be one of the labels returned by ``get_checkpoint_labels``. Returns: The bridge to the loaded model. @@ -332,6 +393,12 @@ def boot( ) model_name = official_name break + if checkpoint_index is not None or checkpoint_value is not None: + if revision is not None: + raise ValueError( + "Specify either revision= or checkpoint_index/checkpoint_value, not both." + ) + revision = _resolve_checkpoint_to_revision(model_name, checkpoint_index, checkpoint_value) # Pass HF token for gated model access (e.g. meta-llama/*) from transformer_lens.utilities.hf_utils import get_hf_token @@ -346,6 +413,7 @@ def boot( output_attentions=True, trust_remote_code=trust_remote_code, token=_hf_token, + revision=revision, ) _n_ctx_field: str | None = None if n_ctx is not None: @@ -505,6 +573,8 @@ def boot( model_kwargs["token"] = _hf_token if trust_remote_code: model_kwargs["trust_remote_code"] = True + if revision is not None: + model_kwargs["revision"] = revision if resolved_device_map is not None: model_kwargs["device_map"] = resolved_device_map if resolved_max_memory is not None: