Skip to content
125 changes: 125 additions & 0 deletions tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Asserts ``TransformerBridge`` reproduces ``AutoModelForCausalLM`` eager-attention logits.

Issue #385 reported drift between bridge and HF for rotary models like Pythia. The drift
was an attention-implementation mismatch — bridge always uses eager, default HF loads use
SDPA, which reorders ops in a fused kernel. Bridge vs HF *eager* matches to fp32-noise.
"""

from typing import Callable

import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformer_lens.model_bridge import TransformerBridge

# Hugging Face model id loaded by the tokenizer, bridge, and HF reference fixtures
# in this module, so all three are built from the same checkpoint name.
MODEL_NAME = "EleutherAI/pythia-70m"

# Op-reorder noise floor for fp32 transformer forward passes. We currently
# measure 0.0 on this model, but allow a small epsilon so harmless refactors
# (intermediate allocations, equivalent op reorderings) don't break the test.
# Applied as a strict upper bound (``<``) on the max absolute elementwise difference.
FP32_NOISE_TOL = 1e-5


@pytest.fixture(scope="module")
def tokenizer():
    """Load the Hugging Face tokenizer for ``MODEL_NAME`` once per test module."""
    hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return hf_tokenizer


@pytest.fixture(scope="module")
def bridge():
    """Boot a ``TransformerBridge`` for ``MODEL_NAME`` on CPU in float32, once per module."""
    boot_options = {"device": "cpu", "dtype": torch.float32}
    return TransformerBridge.boot_transformers(MODEL_NAME, **boot_options)


@pytest.fixture(scope="module")
def hf_eager():
    """Reference HF causal LM with eager attention, loaded separately from the bridge.

    The model is created by its own ``from_pretrained`` call (float32,
    ``attn_implementation="eager"``) and switched to eval mode before use.
    """
    reference_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        attn_implementation="eager",
    )
    return reference_model.eval()


@pytest.fixture
def tokenize(tokenizer) -> Callable[[str], torch.Tensor]:
    """Return a callable that maps a prompt string to its ``input_ids`` tensor."""

    def _encode(prompt: str) -> torch.Tensor:
        # ``return_tensors="pt"`` asks the tokenizer for PyTorch tensors.
        encoded = tokenizer(prompt, return_tensors="pt")
        return encoded.input_ids

    return _encode


@pytest.mark.parametrize("prompt", ["Hello, world!", "The quick brown fox jumps"])
def test_bridge_logits_match_hf_eager(bridge, hf_eager, tokenize, prompt):
    """Bridge logits must match HF eager-attention logits within ``FP32_NOISE_TOL``.

    Args:
        bridge: Module-scoped ``TransformerBridge`` under test.
        hf_eager: Independently loaded HF reference model using eager attention.
        tokenize: Callable turning a prompt string into an ``input_ids`` tensor.
        prompt: Text fed to both models.
    """
    tokens = tokenize(prompt)
    with torch.inference_mode():
        bridge_logits = bridge(tokens)
        hf_logits = hf_eager(tokens).logits

    # Tensor subtraction broadcasts, so a dropped or extra singleton dimension would
    # otherwise be compared silently. Fail loudly on any shape mismatch first.
    assert bridge_logits.shape == hf_logits.shape, (
        f"{MODEL_NAME!r} bridge logits shape {tuple(bridge_logits.shape)} != "
        f"HF eager logits shape {tuple(hf_logits.shape)} on {prompt!r}."
    )

    max_diff = (bridge_logits - hf_logits).abs().max().item()
    assert max_diff < FP32_NOISE_TOL, (
        f"{MODEL_NAME!r} bridge vs HF eager drift={max_diff:.2e} on {prompt!r} "
        f"exceeds fp32-noise tolerance {FP32_NOISE_TOL:.0e} — bridge's "
        f"_reconstruct_attention may have regressed (see issue #385)."
    )


def test_bridge_residual_stream_matches_hf_eager(bridge, hf_eager, tokenize):
    """Compare every layer's output between HF eager and the bridge.

    Checking each layer separately catches compensating errors that could
    cancel out by the time the final logits are produced.
    """
    tokens = tokenize("Hello, world!")
    hf_layers = hf_eager.gpt_neox.layers
    n_layers = len(hf_layers)

    hf_resid: dict[int, torch.Tensor] = {}
    bridge_resid: dict[int, torch.Tensor] = {}

    def _record_hf(layer_idx):
        def _hook(_module, _inputs, output):
            # A layer may return a tuple; its first element is the tensor we compare.
            hidden = output[0] if isinstance(output, tuple) else output
            hf_resid[layer_idx] = hidden.detach()

        return _hook

    def _record_bridge(layer_idx):
        def _hook(value, hook):
            bridge_resid[layer_idx] = value.detach()

        return _hook

    registered = []
    for layer_idx, layer in enumerate(hf_layers):
        registered.append(layer.register_forward_hook(_record_hf(layer_idx)))
    try:
        with torch.inference_mode():
            hf_eager(tokens)
    finally:
        # Always detach the hooks so the module-scoped model is left clean for other tests.
        for handle in registered:
            handle.remove()

    bridge_hooks = []
    for layer_idx in range(n_layers):
        bridge_hooks.append((f"blocks.{layer_idx}.hook_resid_post", _record_bridge(layer_idx)))
    with torch.inference_mode():
        bridge.run_with_hooks(tokens, fwd_hooks=bridge_hooks)

    for i in range(n_layers):
        delta = hf_resid[i] - bridge_resid[i]
        d = delta.abs().max().item()
        assert d < FP32_NOISE_TOL, (
            f"layer {i} residual drift={d:.2e} exceeds fp32-noise tolerance "
            f"{FP32_NOISE_TOL:.0e} — bridge layer output diverges from HF eager."
        )


def test_bridge_attention_reconstruction_actually_runs(bridge, tokenize):
    """Guard against tautology: prove bridge's custom attention path executes.

    If a future refactor made the bridge delegate to HF directly, the previous
    parity tests would pass trivially. This one fails fast in that case by
    asserting bridge-specific hooks fire during forward.

    Args:
        bridge: Module-scoped ``TransformerBridge`` under test.
        tokenize: Callable turning a prompt string into an ``input_ids`` tensor.
    """
    tokens = tokenize("Hello, world!")
    attn_scores_fired: list[bool] = []
    # Run under inference mode like the sibling parity tests: only the hook firing
    # matters here, so there is no reason to build an autograd graph.
    with torch.inference_mode():
        bridge.run_with_hooks(
            tokens,
            fwd_hooks=[
                ("blocks.0.attn.hook_attn_scores", lambda v, hook: attn_scores_fired.append(True)),
            ],
        )
    assert attn_scores_fired, (
        "blocks.0.attn.hook_attn_scores did not fire — bridge no longer runs its "
        "own attention reconstruction, making the parity tests tautological."
    )
59 changes: 41 additions & 18 deletions tests/unit/factored_matrix/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,52 @@ def test_transpose_property(self, factored_matrices):

def test_svd_property(self, factored_matrices):
    """SVD factors must reconstruct ``AB``, and ``U`` / ``V`` must have orthonormal columns.

    The third value returned by ``svd()`` is used as ``V``: it is transposed in
    the reconstruction ``U @ diag(S) @ V.T``. Each check runs once per matrix.

    Args:
        factored_matrices: Iterable of factored matrices exposing ``svd()`` and ``AB``.
    """
    for factored_matrix in factored_matrices:
        U, S, V = factored_matrix.svd()
        assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ V.T, atol=1e-5)
        # test that U and V are unitary
        assert torch.allclose(U.T @ U, torch.eye(U.shape[-1]), atol=1e-5)
        assert torch.allclose(V.T @ V, torch.eye(V.shape[-1]), atol=1e-5)

def test_svd_property_leading_ones(self, factored_matrices_leading_ones):
    """Batched variant of the SVD check, using ``.mT`` to transpose the last two dims.

    The third value returned by ``svd()`` is used as ``V``: it is transposed in
    the reconstruction ``U @ diag(S) @ V.mT``. Each check runs once per matrix.

    Args:
        factored_matrices_leading_ones: Iterable of factored matrices exposing
            ``svd()`` and ``AB``.
    """
    for factored_matrix in factored_matrices_leading_ones:
        U, S, V = factored_matrix.svd()
        assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ V.mT, atol=1e-5)
        # test that U and V are unitary
        assert torch.allclose(U.mT @ U, torch.eye(U.shape[-1]), atol=1e-5)
        assert torch.allclose(V.mT @ V, torch.eye(V.shape[-1]), atol=1e-5)

def test_V_and_Vh_alias_match(self, factored_matrices):
    """Reading ``Vh`` must emit a ``DeprecationWarning`` and equal ``V``."""
    import warnings

    for factored_matrix in factored_matrices:
        with warnings.catch_warnings(record=True) as recorded:
            # "always" defeats the once-per-location filter so every access is recorded.
            warnings.simplefilter("always")
            deprecated_alias = factored_matrix.Vh
        categories = [entry.category for entry in recorded]
        assert any(issubclass(category, DeprecationWarning) for category in categories)
        assert torch.equal(deprecated_alias, factored_matrix.V)

def test_svd_caches_per_instance(self):
    """Two consecutive ``svd()`` calls on one instance must return the very same tensor objects."""
    matrix = FactoredMatrix(randn(4, 3), randn(3, 4))
    u_first, s_first, v_first = matrix.svd()
    u_again, s_again, v_again = matrix.svd()
    # Identity (``is``), not equality: the cached objects themselves must come back.
    assert u_first is u_again
    assert s_first is s_again
    assert v_first is v_again

def test_svd_does_not_prevent_gc(self):
    """After ``svd()`` has run, dropping the last name for the instance must let it be collected."""
    import gc
    import weakref

    matrix = FactoredMatrix(randn(4, 3), randn(3, 4))
    tracker = weakref.ref(matrix)
    _ = matrix.svd()  # populate the cache
    del matrix
    gc.collect()
    # A dead weak reference returns None when called.
    still_alive = tracker() is not None
    assert (
        not still_alive
    ), "FactoredMatrix instance survived deletion — svd cache is leaking references."

def test_eigenvalues_property(self, factored_matrices):
for factored_matrix in factored_matrices:
Expand Down Expand Up @@ -141,16 +174,6 @@ def test_collapse_l(self, factored_matrices):
expected = factored_matrix.S[..., :, None] * utils.transpose(factored_matrix.V)
assert torch.allclose(result, expected)

def test_V_and_Vh_alias_match(self, factored_matrices):
    """The deprecated ``Vh`` accessor must warn and return the same values as ``V``."""
    # NOTE(review): a method with this same name also appears earlier in this view;
    # this copy looks like the pre-move side of a diff hunk — confirm only one
    # definition remains in the real file.
    import warnings

    for factored_matrix in factored_matrices:
        with warnings.catch_warnings(record=True) as log:
            warnings.simplefilter("always")
            alias_value = factored_matrix.Vh
        saw_deprecation = False
        for record in log:
            if issubclass(record.category, DeprecationWarning):
                saw_deprecation = True
                break
        assert saw_deprecation
        assert torch.equal(alias_value, factored_matrix.V)

def test_collapse_r(self, factored_matrices):
for factored_matrix in factored_matrices:
result = factored_matrix.collapse_r()
Expand Down
175 changes: 175 additions & 0 deletions tests/unit/model_bridge/test_checkpoint_revision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Unit tests for the bridge revision/checkpoint API (issue #453)."""

from unittest.mock import patch

import pytest

from transformer_lens.model_bridge.sources.transformers import (
_CHECKPOINT_REVISION_FORMATS,
_resolve_checkpoint_to_revision,
)


class TestResolveCheckpointToRevision:
    """Checks ``_resolve_checkpoint_to_revision`` for known families and its error paths."""

    # Dotted path of the label lookup that is replaced with canned data.
    _LABELS_TARGET = "transformer_lens.loading_from_pretrained.get_checkpoint_labels"

    @classmethod
    def _fake_labels(cls, labels):
        """Return a patcher making the label lookup yield ``(labels, "step")``."""
        return patch(cls._LABELS_TARGET, return_value=(labels, "step"))

    def test_pythia_index_resolves_to_step_revision(self):
        with self._fake_labels([0, 1000, 3000, 10000]):
            resolved = _resolve_checkpoint_to_revision(
                "EleutherAI/pythia-70m", checkpoint_index=2, checkpoint_value=None
            )
        assert resolved == "step3000"

    def test_pythia_value_resolves_to_step_revision(self):
        with self._fake_labels([0, 1000, 3000, 10000]):
            resolved = _resolve_checkpoint_to_revision(
                "EleutherAI/pythia-70m", checkpoint_index=None, checkpoint_value=10000
            )
        assert resolved == "step10000"

    def test_stanford_crfm_uses_checkpoint_prefix(self):
        with self._fake_labels([100, 200, 400]):
            resolved = _resolve_checkpoint_to_revision(
                "stanford-crfm/alias-gpt2-small-x21", checkpoint_index=1, checkpoint_value=None
            )
        assert resolved == "checkpoint-200"

    def test_unknown_family_raises(self):
        with pytest.raises(ValueError, match="known checkpoint revision convention"):
            _resolve_checkpoint_to_revision("gpt2", checkpoint_index=0, checkpoint_value=None)

    def test_index_out_of_range_raises(self):
        with self._fake_labels([0, 1000]), pytest.raises(ValueError, match="out of range"):
            _resolve_checkpoint_to_revision(
                "EleutherAI/pythia-70m", checkpoint_index=5, checkpoint_value=None
            )

    def test_unknown_value_raises(self):
        with self._fake_labels([0, 1000]), pytest.raises(
            ValueError, match="not in available checkpoints"
        ):
            _resolve_checkpoint_to_revision(
                "EleutherAI/pythia-70m", checkpoint_index=None, checkpoint_value=99999
            )

    def test_neither_provided_raises(self):
        with pytest.raises(ValueError, match="Must specify"):
            _resolve_checkpoint_to_revision(
                "EleutherAI/pythia-70m", checkpoint_index=None, checkpoint_value=None
            )

    def test_known_families_registered(self):
        for family in ("EleutherAI/pythia", "stanford-crfm"):
            assert family in _CHECKPOINT_REVISION_FORMATS


class _AbortBoot(Exception):
    """Sentinel exception raised by the patched model loader.

    Raising it stops ``boot()`` before any real load takes place.
    """


class TestBootRevisionPlumbing:
    """Check that ``revision`` and ``checkpoint_*`` arguments reach HF's ``from_pretrained`` calls.

    The real cached pythia-70m config is used (avoiding MagicMock fragility through the
    adapter/config-mapping path), and ``boot()`` is aborted at the model-load step.
    """

    def _patched_boot(self, **boot_kwargs):
        """Run ``boot()`` with both HF loaders intercepted and return the kwargs they received.

        Returns:
            A dict with ``"autoconfig_kwargs"`` (a copy of what ``AutoConfig.from_pretrained``
            was given) and ``"model_kwargs"`` (what the model loader was given).
        """
        from transformer_lens.model_bridge.sources import transformers as bridge_src

        seen: dict = {}
        # Grab the genuine loader before patching so the spy can delegate to it.
        genuine_config_loader = bridge_src.AutoConfig.from_pretrained

        def _spy_on_config(name, **kwargs):
            seen["autoconfig_kwargs"] = dict(kwargs)
            # Strip the (possibly fake) revision so the real call hits the CI cache.
            kwargs.pop("revision", None)
            return genuine_config_loader(name, **kwargs)

        def _spy_on_model(*args, **kwargs):
            seen["model_kwargs"] = kwargs
            raise _AbortBoot()

        config_patch = patch.object(
            bridge_src.AutoConfig, "from_pretrained", side_effect=_spy_on_config
        )
        model_patch = patch(
            "transformers.AutoModelForCausalLM.from_pretrained",
            side_effect=_spy_on_model,
        )
        with config_patch, model_patch, pytest.raises(_AbortBoot):
            bridge_src.boot(model_name="EleutherAI/pythia-70m", device="cpu", **boot_kwargs)

        return seen

    def test_revision_forwarded_to_autoconfig(self):
        seen = self._patched_boot(revision="step3000")
        assert seen["autoconfig_kwargs"].get("revision") == "step3000"

    def test_revision_forwarded_to_model_load(self):
        seen = self._patched_boot(revision="step3000")
        model_kwargs = seen.get("model_kwargs", {})
        assert model_kwargs.get("revision") == "step3000"

    def test_checkpoint_index_resolves_to_revision(self):
        fake_lookup = patch(
            "transformer_lens.loading_from_pretrained.get_checkpoint_labels",
            return_value=([0, 1000, 3000, 10000], "step"),
        )
        with fake_lookup:
            seen = self._patched_boot(checkpoint_index=2)
        assert seen["autoconfig_kwargs"].get("revision") == "step3000"
        assert seen.get("model_kwargs", {}).get("revision") == "step3000"

    def test_conflicting_revision_and_checkpoint_raises(self):
        from transformer_lens.model_bridge.sources import transformers as bridge_src

        conflicting = {"revision": "step1000", "checkpoint_index": 2}
        with pytest.raises(ValueError, match="not both"):
            bridge_src.boot(model_name="EleutherAI/pythia-70m", **conflicting)

    def test_default_revision_is_none(self):
        """With no revision/checkpoint args, revision is not added to model_kwargs."""
        seen = self._patched_boot()
        assert seen["autoconfig_kwargs"].get("revision") is None
        model_kwargs = seen.get("model_kwargs", {})
        assert "revision" not in model_kwargs


class TestHookedTransformerCheckpointLabelAlias:
    """Checks how ``HookedTransformer.from_pretrained`` handles ``checkpoint_label``."""

    def test_checkpoint_label_routes_to_checkpoint_value(self):
        from transformer_lens import HookedTransformer

        # Make the config lookup raise so from_pretrained stops right after calling it.
        stop_signal = RuntimeError("stop after config call")
        with patch(
            "transformer_lens.loading.get_pretrained_model_config", side_effect=stop_signal
        ) as config_spy:
            with pytest.raises(RuntimeError, match="stop after config call"):
                HookedTransformer.from_pretrained("EleutherAI/pythia-70m", checkpoint_label=3000)

        _positional, forwarded = config_spy.call_args
        assert forwarded["checkpoint_value"] == 3000

    def test_checkpoint_label_and_value_together_raises(self):
        from transformer_lens import HookedTransformer

        both_given = {"checkpoint_label": 3000, "checkpoint_value": 1000}
        with pytest.raises(ValueError, match="aliases"):
            HookedTransformer.from_pretrained("EleutherAI/pythia-70m", **both_given)
Loading
Loading