tests/unit/model_bridge/test_bridge_generate_no_tokenizer.py (189 additions, 0 deletions)
@@ -0,0 +1,189 @@
"""Tests for TransformerBridge.generate() and generate_stream() when no tokenizer is set.

Bridge counterpart to tests/unit/test_generate_no_tokenizer.py — regression
coverage for https://github.com/TransformerLensOrg/TransformerLens/issues/483.

The bridge can be constructed via boot_transformers() with a tokenizer loaded
from HF; tests then clear ``bridge.tokenizer`` to exercise the tokenizer-free
generation path (algorithmic/custom-tokenized use cases).
"""

import pytest
import torch

from transformer_lens.model_bridge import TransformerBridge


@pytest.fixture(scope="module")
def tokenizer_free_bridge():
bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
bridge.tokenizer = None
return bridge


def test_generate_without_tokenizer_stop_at_eos_false_kv_cache(tokenizer_free_bridge):
"""generate() with no tokenizer, stop_at_eos=False, use_past_kv_cache=True."""
bridge = tokenizer_free_bridge
assert bridge.tokenizer is None

tokens = torch.zeros((1, 5), dtype=torch.long)
output = bridge.generate(
tokens,
max_new_tokens=3,
stop_at_eos=False,
use_past_kv_cache=True,
return_type="tokens",
verbose=False,
)
assert output.shape == (1, 8), f"Expected shape (1, 8), got {output.shape}"


def test_generate_without_tokenizer_stop_at_eos_false_no_kv_cache(tokenizer_free_bridge):
"""generate() with no tokenizer, stop_at_eos=False, use_past_kv_cache=False."""
bridge = tokenizer_free_bridge
assert bridge.tokenizer is None

tokens = torch.zeros((1, 5), dtype=torch.long)
output = bridge.generate(
tokens,
max_new_tokens=3,
stop_at_eos=False,
use_past_kv_cache=False,
return_type="tokens",
verbose=False,
)
assert output.shape == (1, 8), f"Expected shape (1, 8), got {output.shape}"


def test_generate_without_tokenizer_explicit_eos_kv_cache(tokenizer_free_bridge):
"""generate() with no tokenizer, explicit eos_token_id, use_past_kv_cache=True.

    Uses a high-valued eos_token_id that greedy decoding from zero input is
    unlikely to produce, so generation runs to ``max_new_tokens`` and we can
    assert the exact output shape.
"""
bridge = tokenizer_free_bridge
assert bridge.tokenizer is None

tokens = torch.zeros((1, 5), dtype=torch.long)
output = bridge.generate(
tokens,
max_new_tokens=3,
stop_at_eos=True,
eos_token_id=50256,
do_sample=False,
use_past_kv_cache=True,
return_type="tokens",
verbose=False,
)
assert output.shape == (1, 8), f"Expected shape (1, 8), got {output.shape}"


def test_generate_without_tokenizer_explicit_eos_no_kv_cache(tokenizer_free_bridge):
"""generate() with no tokenizer, explicit eos_token_id, use_past_kv_cache=False."""
bridge = tokenizer_free_bridge
assert bridge.tokenizer is None

tokens = torch.zeros((1, 5), dtype=torch.long)
output = bridge.generate(
tokens,
max_new_tokens=3,
stop_at_eos=True,
eos_token_id=50256,
do_sample=False,
use_past_kv_cache=False,
return_type="tokens",
verbose=False,
)
assert output.shape == (1, 8), f"Expected shape (1, 8), got {output.shape}"


def test_generate_without_tokenizer_stop_at_eos_requires_eos_id(tokenizer_free_bridge):
"""generate() must still error when stop_at_eos=True, no eos_token_id, no tokenizer."""
bridge = tokenizer_free_bridge
assert bridge.tokenizer is None

tokens = torch.zeros((1, 5), dtype=torch.long)
with pytest.raises(AssertionError, match="eos_token_id"):
bridge.generate(
tokens, max_new_tokens=3, stop_at_eos=True, return_type="tokens", verbose=False
)


def test_generate_string_input_without_tokenizer_errors(tokenizer_free_bridge):
"""generate() must still error when string input is used without a tokenizer."""
bridge = tokenizer_free_bridge
assert bridge.tokenizer is None

with pytest.raises(AssertionError, match="to_tokens without a tokenizer"):
bridge.generate("hello", max_new_tokens=3, verbose=False)


def test_generate_return_type_str_without_tokenizer_errors(tokenizer_free_bridge):
"""generate(return_type='str') must error when no tokenizer is set.

Generation itself succeeds (tensor input, stop_at_eos=False); the assert
fires only at the decode step, proving the str-decode path is guarded.
"""
bridge = tokenizer_free_bridge
assert bridge.tokenizer is None

tokens = torch.zeros((1, 5), dtype=torch.long)
with pytest.raises(AssertionError):
bridge.generate(
tokens,
max_new_tokens=3,
stop_at_eos=False,
return_type="str",
verbose=False,
)


def test_generate_stream_without_tokenizer_explicit_eos(tokenizer_free_bridge):
"""generate_stream() with no tokenizer; verify all max_new_tokens land in the chunks.

First chunk contains input + at least one generated token; later chunks contain
only new tokens. Greedy + high eos_token_id keeps generation from halting early
so total yielded length is exactly input + max_new_tokens.
"""
bridge = tokenizer_free_bridge
assert bridge.tokenizer is None

input_len = 5
max_new = 4
tokens = torch.zeros((1, input_len), dtype=torch.long)
chunks = list(
bridge.generate_stream(
tokens,
max_new_tokens=max_new,
max_tokens_per_yield=2,
stop_at_eos=True,
eos_token_id=50256,
do_sample=False,
use_past_kv_cache=True,
return_type="tokens",
verbose=False,
)
)
total_yielded = sum(chunk.shape[-1] for chunk in chunks)
assert (
total_yielded == input_len + max_new
), f"Expected {input_len + max_new} tokens across chunks, got {total_yielded}"


def test_generate_stream_without_tokenizer_stop_at_eos_requires_eos_id(tokenizer_free_bridge):
"""generate_stream() must error when stop_at_eos=True with no eos_token_id and no tokenizer."""
bridge = tokenizer_free_bridge
assert bridge.tokenizer is None

tokens = torch.zeros((1, 5), dtype=torch.long)
with pytest.raises(AssertionError, match="eos_token_id"):
# Generator is lazy — must consume to trigger the assert.
list(
bridge.generate_stream(
tokens,
max_new_tokens=3,
stop_at_eos=True,
return_type="tokens",
verbose=False,
)
)
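For orientation, a minimal sketch of the tokenizer-free pattern these tests exercise (a hypothetical standalone script, not part of the PR; it assumes distilgpt2 weights are downloadable and that the bridge exposes cfg.d_vocab as other TransformerLens configs do):

import torch

from transformer_lens.model_bridge import TransformerBridge

# Boot with a tokenizer, then drop it to mimic a custom-tokenized pipeline.
bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
bridge.tokenizer = None

# Callers supply raw token ids directly; generation succeeds as long as
# stop_at_eos is disabled or an explicit eos_token_id is passed.
tokens = torch.randint(0, bridge.cfg.d_vocab, (1, 5))
out = bridge.generate(
    tokens,
    max_new_tokens=3,
    stop_at_eos=False,
    return_type="tokens",
    verbose=False,
)
assert out.shape == (1, 8)  # 5 input positions + 3 generated tokens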
transformer_lens/model_bridge/bridge.py (25 additions, 14 deletions)
@@ -925,6 +925,7 @@ def to_tokens(
         Returns:
             Token tensor of shape ``[batch, pos]``.
         """
+        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
         if prepend_bos is None:
             prepend_bos = getattr(self.cfg, "default_prepend_bos", True)
         if padding_side is None:
@@ -2548,22 +2549,26 @@ def generate(
         stop_tokens = []
         eos_token_for_padding = 0
         if stop_at_eos:
+            tokenizer_has_eos_token = (
+                self.tokenizer is not None and self.tokenizer.eos_token_id is not None
+            )
             if eos_token_id is None:
                 assert (
-                    self.tokenizer.eos_token_id is not None
-                ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id"
+                    tokenizer_has_eos_token
+                ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
+                assert self.tokenizer is not None
                 eos_token_id = self.tokenizer.eos_token_id

             if isinstance(eos_token_id, int):
                 stop_tokens = [eos_token_id]
                 eos_token_for_padding = eos_token_id
             else:
                 stop_tokens = list(eos_token_id)
-                eos_token_for_padding = (
-                    self.tokenizer.eos_token_id
-                    if self.tokenizer.eos_token_id is not None
-                    else eos_token_id[0]
-                )
+                if tokenizer_has_eos_token:
+                    assert self.tokenizer is not None
+                    eos_token_for_padding = self.tokenizer.eos_token_id
+                else:
+                    eos_token_for_padding = eos_token_id[0]

         # Track which sequences have finished
         finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
@@ -2728,6 +2733,7 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]

         # Format output
         if return_type == "str":
+            assert self.tokenizer is not None
             if input_type == "str":
                 return self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
             else:
@@ -2824,21 +2830,25 @@ def generate_stream(
         stop_tokens: List[int] = []
         eos_token_for_padding = 0
         if stop_at_eos:
+            tokenizer_has_eos_token = (
+                self.tokenizer is not None and self.tokenizer.eos_token_id is not None
+            )
             if eos_token_id is None:
                 assert (
-                    self.tokenizer.eos_token_id is not None
-                ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id"
+                    tokenizer_has_eos_token
+                ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
+                assert self.tokenizer is not None
                 eos_token_id = self.tokenizer.eos_token_id
             if isinstance(eos_token_id, int):
                 stop_tokens = [eos_token_id]
                 eos_token_for_padding = eos_token_id
             else:
                 stop_tokens = list(eos_token_id)
-                eos_token_for_padding = (
-                    self.tokenizer.eos_token_id
-                    if self.tokenizer.eos_token_id is not None
-                    else eos_token_id[0]
-                )
+                if tokenizer_has_eos_token:
+                    assert self.tokenizer is not None
+                    eos_token_for_padding = self.tokenizer.eos_token_id
+                else:
+                    eos_token_for_padding = eos_token_id[0]

         finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

@@ -2859,6 +2869,7 @@ def _maybe_decode(
             tokens: torch.Tensor,
         ) -> Union[torch.Tensor, str]:
             if return_type == "str":
+                assert self.tokenizer is not None
                 return self.tokenizer.decode(tokens[0], skip_special_tokens=True)
             return tokens

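Both generate() and generate_stream() now resolve stop tokens the same way; a standalone sketch of that shared guard logic (a hypothetical helper for illustration only, not part of the bridge API):

from typing import List, Optional, Sequence, Tuple, Union

def resolve_stop_tokens(
    tokenizer_eos: Optional[int],
    eos_token_id: Optional[Union[int, Sequence[int]]],
) -> Tuple[List[int], int]:
    # tokenizer_eos stands in for tokenizer.eos_token_id when a tokenizer
    # is set, and for None when the tokenizer is absent.
    if eos_token_id is None:
        assert tokenizer_eos is not None, (
            "Must pass eos_token_id if stop_at_eos is True and tokenizer "
            "is None or has no eos_token_id"
        )
        eos_token_id = tokenizer_eos
    if isinstance(eos_token_id, int):
        return [eos_token_id], eos_token_id
    # Prefer the tokenizer's eos id for padding finished sequences; fall
    # back to the first user-supplied stop token when there is none.
    stop_tokens = list(eos_token_id)
    eos_for_padding = tokenizer_eos if tokenizer_eos is not None else stop_tokens[0]
    return stop_tokens, eos_for_padding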