diff --git a/tests/unit/model_bridge/test_bridge_generate_no_tokenizer.py b/tests/unit/model_bridge/test_bridge_generate_no_tokenizer.py
new file mode 100644
index 000000000..5b0f7771b
--- /dev/null
+++ b/tests/unit/model_bridge/test_bridge_generate_no_tokenizer.py
@@ -0,0 +1,189 @@
+"""Tests for TransformerBridge.generate() and generate_stream() when no tokenizer is set.
+
+Bridge counterpart to tests/unit/test_generate_no_tokenizer.py — regression
+coverage for https://github.com/TransformerLensOrg/TransformerLens/issues/483.
+
+The bridge is constructed via boot_transformers(), which loads a tokenizer
+from HF; the module-scoped fixture then clears ``bridge.tokenizer`` to
+exercise the tokenizer-free generation path (algorithmic/custom-tokenized
+use cases).
+"""
+
+import pytest
+import torch
+
+from transformer_lens.model_bridge import TransformerBridge
+
+
+@pytest.fixture(scope="module")
+def tokenizer_free_bridge():
+    bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
+    bridge.tokenizer = None
+    return bridge
+
+
+def test_generate_without_tokenizer_stop_at_eos_false_kv_cache(tokenizer_free_bridge):
+    """generate() with no tokenizer, stop_at_eos=False, use_past_kv_cache=True."""
+    bridge = tokenizer_free_bridge
+    assert bridge.tokenizer is None
+
+    tokens = torch.zeros((1, 5), dtype=torch.long)
+    output = bridge.generate(
+        tokens,
+        max_new_tokens=3,
+        stop_at_eos=False,
+        use_past_kv_cache=True,
+        return_type="tokens",
+        verbose=False,
+    )
+    assert output.shape == (1, 8), f"Expected shape (1, 8), got {output.shape}"
+
+
+def test_generate_without_tokenizer_stop_at_eos_false_no_kv_cache(tokenizer_free_bridge):
+    """generate() with no tokenizer, stop_at_eos=False, use_past_kv_cache=False."""
+    bridge = tokenizer_free_bridge
+    assert bridge.tokenizer is None
+
+    tokens = torch.zeros((1, 5), dtype=torch.long)
+    output = bridge.generate(
+        tokens,
+        max_new_tokens=3,
+        stop_at_eos=False,
+        use_past_kv_cache=False,
+        return_type="tokens",
+        verbose=False,
+    )
+    assert output.shape == (1, 8), f"Expected shape (1, 8), got {output.shape}"
+
+
+def test_generate_without_tokenizer_explicit_eos_kv_cache(tokenizer_free_bridge):
+    """generate() with no tokenizer, explicit eos_token_id, use_past_kv_cache=True.
+
+    Uses greedy decoding with a high-valued eos_token_id that the all-zero
+    prompt does not produce, so generation runs to ``max_new_tokens`` and we
+    can assert the exact output shape.
+ """ + bridge = tokenizer_free_bridge + assert bridge.tokenizer is None + + tokens = torch.zeros((1, 5), dtype=torch.long) + output = bridge.generate( + tokens, + max_new_tokens=3, + stop_at_eos=True, + eos_token_id=50256, + do_sample=False, + use_past_kv_cache=True, + return_type="tokens", + verbose=False, + ) + assert output.shape == (1, 8), f"Expected shape (1, 8), got {output.shape}" + + +def test_generate_without_tokenizer_explicit_eos_no_kv_cache(tokenizer_free_bridge): + """generate() with no tokenizer, explicit eos_token_id, use_past_kv_cache=False.""" + bridge = tokenizer_free_bridge + assert bridge.tokenizer is None + + tokens = torch.zeros((1, 5), dtype=torch.long) + output = bridge.generate( + tokens, + max_new_tokens=3, + stop_at_eos=True, + eos_token_id=50256, + do_sample=False, + use_past_kv_cache=False, + return_type="tokens", + verbose=False, + ) + assert output.shape == (1, 8), f"Expected shape (1, 8), got {output.shape}" + + +def test_generate_without_tokenizer_stop_at_eos_requires_eos_id(tokenizer_free_bridge): + """generate() must still error when stop_at_eos=True, no eos_token_id, no tokenizer.""" + bridge = tokenizer_free_bridge + assert bridge.tokenizer is None + + tokens = torch.zeros((1, 5), dtype=torch.long) + with pytest.raises(AssertionError, match="eos_token_id"): + bridge.generate( + tokens, max_new_tokens=3, stop_at_eos=True, return_type="tokens", verbose=False + ) + + +def test_generate_string_input_without_tokenizer_errors(tokenizer_free_bridge): + """generate() must still error when string input is used without a tokenizer.""" + bridge = tokenizer_free_bridge + assert bridge.tokenizer is None + + with pytest.raises(AssertionError, match="to_tokens without a tokenizer"): + bridge.generate("hello", max_new_tokens=3, verbose=False) + + +def test_generate_return_type_str_without_tokenizer_errors(tokenizer_free_bridge): + """generate(return_type='str') must error when no tokenizer is set. + + Generation itself succeeds (tensor input, stop_at_eos=False); the assert + fires only at the decode step, proving the str-decode path is guarded. + """ + bridge = tokenizer_free_bridge + assert bridge.tokenizer is None + + tokens = torch.zeros((1, 5), dtype=torch.long) + with pytest.raises(AssertionError): + bridge.generate( + tokens, + max_new_tokens=3, + stop_at_eos=False, + return_type="str", + verbose=False, + ) + + +def test_generate_stream_without_tokenizer_explicit_eos(tokenizer_free_bridge): + """generate_stream() with no tokenizer; verify all max_new_tokens land in the chunks. + + First chunk contains input + at least one generated token; later chunks contain + only new tokens. Greedy + high eos_token_id keeps generation from halting early + so total yielded length is exactly input + max_new_tokens. 
+ """ + bridge = tokenizer_free_bridge + assert bridge.tokenizer is None + + input_len = 5 + max_new = 4 + tokens = torch.zeros((1, input_len), dtype=torch.long) + chunks = list( + bridge.generate_stream( + tokens, + max_new_tokens=max_new, + max_tokens_per_yield=2, + stop_at_eos=True, + eos_token_id=50256, + do_sample=False, + use_past_kv_cache=True, + return_type="tokens", + verbose=False, + ) + ) + total_yielded = sum(chunk.shape[-1] for chunk in chunks) + assert ( + total_yielded == input_len + max_new + ), f"Expected {input_len + max_new} tokens across chunks, got {total_yielded}" + + +def test_generate_stream_without_tokenizer_stop_at_eos_requires_eos_id(tokenizer_free_bridge): + """generate_stream() must error when stop_at_eos=True with no eos_token_id and no tokenizer.""" + bridge = tokenizer_free_bridge + assert bridge.tokenizer is None + + tokens = torch.zeros((1, 5), dtype=torch.long) + with pytest.raises(AssertionError, match="eos_token_id"): + # Generator is lazy — must consume to trigger the assert. + list( + bridge.generate_stream( + tokens, + max_new_tokens=3, + stop_at_eos=True, + return_type="tokens", + verbose=False, + ) + ) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 30d58363c..cb3895368 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -925,6 +925,7 @@ def to_tokens( Returns: Token tensor of shape ``[batch, pos]``. """ + assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" if prepend_bos is None: prepend_bos = getattr(self.cfg, "default_prepend_bos", True) if padding_side is None: @@ -2548,10 +2549,14 @@ def generate( stop_tokens = [] eos_token_for_padding = 0 if stop_at_eos: + tokenizer_has_eos_token = ( + self.tokenizer is not None and self.tokenizer.eos_token_id is not None + ) if eos_token_id is None: assert ( - self.tokenizer.eos_token_id is not None - ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id" + tokenizer_has_eos_token + ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" + assert self.tokenizer is not None eos_token_id = self.tokenizer.eos_token_id if isinstance(eos_token_id, int): @@ -2559,11 +2564,11 @@ def generate( eos_token_for_padding = eos_token_id else: stop_tokens = list(eos_token_id) - eos_token_for_padding = ( - self.tokenizer.eos_token_id - if self.tokenizer.eos_token_id is not None - else eos_token_id[0] - ) + if tokenizer_has_eos_token: + assert self.tokenizer is not None + eos_token_for_padding = self.tokenizer.eos_token_id + else: + eos_token_for_padding = eos_token_id[0] # Track which sequences have finished finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) @@ -2728,6 +2733,7 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ... 
 
         # Format output
         if return_type == "str":
+            assert self.tokenizer is not None
             if input_type == "str":
                 return self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
             else:
@@ -2824,21 +2830,25 @@
         stop_tokens: List[int] = []
         eos_token_for_padding = 0
         if stop_at_eos:
+            tokenizer_has_eos_token = (
+                self.tokenizer is not None and self.tokenizer.eos_token_id is not None
+            )
             if eos_token_id is None:
                 assert (
-                    self.tokenizer.eos_token_id is not None
-                ), "Must pass eos_token_id if stop_at_eos is True and tokenizer has no eos_token_id"
+                    tokenizer_has_eos_token
+                ), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
+                assert self.tokenizer is not None
                 eos_token_id = self.tokenizer.eos_token_id
 
             if isinstance(eos_token_id, int):
                 stop_tokens = [eos_token_id]
                 eos_token_for_padding = eos_token_id
             else:
                 stop_tokens = list(eos_token_id)
-                eos_token_for_padding = (
-                    self.tokenizer.eos_token_id
-                    if self.tokenizer.eos_token_id is not None
-                    else eos_token_id[0]
-                )
+                if tokenizer_has_eos_token:
+                    assert self.tokenizer is not None
+                    eos_token_for_padding = self.tokenizer.eos_token_id
+                else:
+                    eos_token_for_padding = eos_token_id[0]
 
         finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
@@ -2859,6 +2869,7 @@ def _maybe_decode(
             tokens: torch.Tensor,
         ) -> Union[torch.Tensor, str]:
             if return_type == "str":
+                assert self.tokenizer is not None
                 return self.tokenizer.decode(tokens[0], skip_special_tokens=True)
             return tokens
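
For reference, a minimal sketch of the tokenizer-free path this patch enables. The snippet below is illustrative only, not part of the patch; it assumes a checkout with these changes applied and simply mirrors the test fixture and assertions above:

    import torch
    from transformer_lens.model_bridge import TransformerBridge

    # Boot with a real tokenizer, then drop it to model the
    # algorithmic / custom-tokenized case from issue #483.
    bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
    bridge.tokenizer = None

    tokens = torch.zeros((1, 5), dtype=torch.long)

    # Tokenizer-free generation works for tensor input with return_type="tokens".
    out = bridge.generate(
        tokens, max_new_tokens=3, stop_at_eos=False, return_type="tokens", verbose=False
    )
    assert out.shape == (1, 8)

    # Stopping at EOS without a tokenizer still requires an explicit id;
    # omitting eos_token_id here would raise an AssertionError by design.
    out = bridge.generate(
        tokens,
        max_new_tokens=3,
        stop_at_eos=True,
        eos_token_id=50256,
        do_sample=False,
        return_type="tokens",
        verbose=False,
    )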