diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index 70ffabbea..5e52962a7 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -1015,9 +1015,7 @@ "Mathematically, centering is a linear map, normalizing is *not* a linear map, and scaling and translation are linear maps. \n", "* **Centering:** LayerNorm is applied every time a layer reads from the residual stream, so the mean of any residual stream vector can never matter - `center_writing_weights` set every weight matrix writing to the residual to have zero mean. \n", "* **Normalizing:** Normalizing is not a linear map, and cannot be factored out. The `hook_scale` hook point lets you access and control for this.\n", - "* **Scaling and Translation:** Scaling and translation are linear maps, and are always followed by another linear map. The composition of two linear maps is another linear map, so we can *fold* the scaling and translation weights into the weights of the subsequent layer, and simplify things without changing the underlying computation. \n", - "\n", - "[See the docs for more details](https://github.com/TransformerLensOrg/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln)" + "* **Scaling and Translation:** Scaling and translation are linear maps, and are always followed by another linear map. The composition of two linear maps is another linear map, so we can *fold* the scaling and translation weights into the weights of the subsequent layer, and simplify things without changing the underlying computation. \n" ] }, { diff --git a/tests/unit/model_bridge/test_bridge_generate_bos.py b/tests/unit/model_bridge/test_bridge_generate_bos.py new file mode 100644 index 000000000..c3d41631f --- /dev/null +++ b/tests/unit/model_bridge/test_bridge_generate_bos.py @@ -0,0 +1,277 @@ +"""Unit tests for Bridge.generate() BOS handling and return_input_tokens. + +Tests cover: +- prepend_bos parameter being respected (not ignored) +- return_input_tokens flag returning input tokens +- return_input_tokens + return_cache combo +- generate_stream respecting prepend_bos +""" + +import pytest +import torch + +from transformer_lens.model_bridge import TransformerBridge + + +@pytest.fixture(scope="module") +def gpt2_bridge(): + """Load a small GPT-2 bridge for testing.""" + bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu") + if bridge.tokenizer.pad_token is None: + bridge.tokenizer.pad_token = bridge.tokenizer.eos_token + return bridge + + +class TestGeneratePrependBos: + """Test that generate() respects the prepend_bos parameter.""" + + def test_prepend_bos_true_adds_bos(self, gpt2_bridge): + """prepend_bos=True should add BOS token to the input.""" + bridge = gpt2_bridge + prompt = "Hello" + + _, input_tokens = bridge.generate( + prompt, + max_new_tokens=1, + prepend_bos=True, + return_input_tokens=True, + verbose=False, + ) + + assert input_tokens[0, 0].item() == bridge.tokenizer.bos_token_id + assert input_tokens.shape[1] >= 2 # At least BOS + one token + + def test_prepend_bos_false_no_bos(self, gpt2_bridge): + """prepend_bos=False should not add BOS token to the input.""" + bridge = gpt2_bridge + prompt = "Hello" + + _, input_tokens = bridge.generate( + prompt, + max_new_tokens=1, + prepend_bos=False, + return_input_tokens=True, + verbose=False, + ) + + # First token should NOT be BOS + assert input_tokens[0, 0].item() != bridge.tokenizer.bos_token_id + + def test_prepend_bos_difference_is_one_token(self, gpt2_bridge): + """The difference between prepend_bos=True and False should be exactly 1 token.""" + bridge = gpt2_bridge + prompt = "Hello" + + _, tokens_with_bos = bridge.generate( + prompt, + max_new_tokens=1, + prepend_bos=True, + return_input_tokens=True, + verbose=False, + ) + + _, tokens_no_bos = bridge.generate( + prompt, + max_new_tokens=1, + prepend_bos=False, + return_input_tokens=True, + verbose=False, + ) + + assert tokens_with_bos.shape[1] - tokens_no_bos.shape[1] == 1 + + def test_prepend_bos_none_uses_default(self, gpt2_bridge): + """prepend_bos=None should use cfg.default_prepend_bos.""" + bridge = gpt2_bridge + prompt = "Hello" + + _, tokens_default = bridge.generate( + prompt, + max_new_tokens=1, + prepend_bos=None, + return_input_tokens=True, + verbose=False, + ) + + _, tokens_explicit = bridge.generate( + prompt, + max_new_tokens=1, + prepend_bos=bridge.cfg.default_prepend_bos, + return_input_tokens=True, + verbose=False, + ) + + assert tokens_default.shape == tokens_explicit.shape + assert torch.equal(tokens_default, tokens_explicit) + + def test_prepend_bos_ignored_for_tensor_input(self, gpt2_bridge): + """prepend_bos should be ignored when input is already a token tensor.""" + bridge = gpt2_bridge + tokens = bridge.to_tokens("Hello", prepend_bos=False) + + # Pass tensor directly - prepend_bos should have no effect + _, input_tokens_true = bridge.generate( + tokens, + max_new_tokens=1, + prepend_bos=True, + return_input_tokens=True, + verbose=False, + ) + + _, input_tokens_false = bridge.generate( + tokens, + max_new_tokens=1, + prepend_bos=False, + return_input_tokens=True, + verbose=False, + ) + + # Both should be identical since input was already tokenized + assert torch.equal(input_tokens_true, input_tokens_false) + + +class TestReturnInputTokens: + """Test the return_input_tokens flag on generate().""" + + def test_return_input_tokens_returns_tuple(self, gpt2_bridge): + """return_input_tokens=True should return (output, input_tokens) tuple.""" + bridge = gpt2_bridge + + result = bridge.generate( + "Hello", + max_new_tokens=2, + return_input_tokens=True, + verbose=False, + ) + + assert isinstance(result, tuple) + assert len(result) == 2 + output, input_tokens = result + assert isinstance(input_tokens, torch.Tensor) + assert input_tokens.dim() == 2 # [batch, seq_len] + + def test_return_input_tokens_false_returns_single(self, gpt2_bridge): + """return_input_tokens=False should return just the output.""" + bridge = gpt2_bridge + + result = bridge.generate( + "Hello", + max_new_tokens=2, + return_input_tokens=False, + verbose=False, + ) + + # Should not be a tuple (or if it is, not from return_input_tokens) + assert not isinstance(result, tuple) or not isinstance(result[1], torch.Tensor) + + def test_return_input_tokens_matches_to_tokens(self, gpt2_bridge): + """Returned input_tokens should match what to_tokens() would produce.""" + bridge = gpt2_bridge + prompt = "Hello world" + + _, input_tokens = bridge.generate( + prompt, + max_new_tokens=1, + prepend_bos=True, + return_input_tokens=True, + verbose=False, + ) + + expected_tokens = bridge.to_tokens(prompt, prepend_bos=True) + + assert torch.equal(input_tokens, expected_tokens) + + def test_return_input_tokens_with_list_input(self, gpt2_bridge): + """return_input_tokens should work with list input.""" + bridge = gpt2_bridge + + _, input_tokens = bridge.generate( + ["Hello", "World"], + max_new_tokens=1, + return_input_tokens=True, + verbose=False, + ) + + assert input_tokens.shape[0] == 2 # Batch size 2 + + +class TestReturnInputTokensWithCache: + """Test return_input_tokens combined with return_cache.""" + + def test_return_cache_and_input_tokens(self, gpt2_bridge): + """return_cache=True and return_input_tokens=True should return 3-tuple.""" + bridge = gpt2_bridge + + result = bridge.generate( + "Hi", + max_new_tokens=2, + return_cache=True, + return_input_tokens=True, + verbose=False, + ) + + assert isinstance(result, tuple) + assert len(result) == 3 + output, cache, input_tokens = result + assert hasattr(cache, "keys") # ActivationCache is dict-like + assert isinstance(input_tokens, torch.Tensor) + + def test_return_cache_only(self, gpt2_bridge): + """return_cache=True alone should return 2-tuple (output, cache).""" + bridge = gpt2_bridge + + result = bridge.generate( + "Hi", + max_new_tokens=2, + return_cache=True, + return_input_tokens=False, + verbose=False, + ) + + assert isinstance(result, tuple) + assert len(result) == 2 + output, cache = result + assert hasattr(cache, "keys") # ActivationCache + + +class TestGenerateStreamPrependBos: + """Test that generate_stream() respects the prepend_bos parameter.""" + + def test_generate_stream_prepend_bos_true(self, gpt2_bridge): + """generate_stream with prepend_bos=True should include BOS in first yield.""" + bridge = gpt2_bridge + prompt = "Hello" + + # Get first yield which includes input tokens + first_yield = None + for tokens in bridge.generate_stream( + prompt, + max_new_tokens=3, + prepend_bos=True, + return_type="tokens", + verbose=False, + ): + first_yield = tokens + break + + assert first_yield is not None + assert first_yield[0, 0].item() == bridge.tokenizer.bos_token_id + + def test_generate_stream_prepend_bos_false(self, gpt2_bridge): + """generate_stream with prepend_bos=False should not include BOS.""" + bridge = gpt2_bridge + prompt = "Hello" + + first_yield = None + for tokens in bridge.generate_stream( + prompt, + max_new_tokens=3, + prepend_bos=False, + return_type="tokens", + verbose=False, + ): + first_yield = tokens + break + + assert first_yield is not None + assert first_yield[0, 0].item() != bridge.tokenizer.bos_token_id diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 077007cd5..4599c3b4e 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -3,6 +3,7 @@ This module provides the bridge components that wrap remote model components and provide a consistent interface for accessing their weights and performing operations. """ + import logging import re import warnings @@ -124,6 +125,29 @@ class TransformerBridge(HookIntrospectionMixin, nn.Module): BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and ``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt. + + BOS token and chat templates + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + ``model.tokenizer`` is configured with ``add_bos_token=True`` and is + **not** the stock HuggingFace tokenizer. Direct ``.encode()`` calls + will prepend BOS automatically. + + When passing pre-applied chat-template text (i.e., the output of + ``tokenizer.apply_chat_template(..., tokenize=False)``), pass + ``prepend_bos=False`` to :meth:`to_tokens` to avoid a double BOS:: + + # Correct pattern for chat templates: + text = model.tokenizer.apply_chat_template(messages, tokenize=False) + tokens = model.to_tokens(text, prepend_bos=False) + + The chat template already embeds the model's expected BOS token in + the rendered text; letting :meth:`to_tokens` add another would produce + a malformed sequence like ``[BOS, BOS, ...]``. + + To inspect what tokens will actually be fed to the model during + generation, use :meth:`to_tokens` directly or pass + ``return_input_tokens=True`` to :meth:`generate`. """ hook_aliases: Dict[str, Union[str, List[str]]] = { @@ -2506,18 +2530,22 @@ def _generate_tokens( temperature=temperature, freq_penalty=freq_penalty, repetition_penalty=repetition_penalty, - tokens=penalty_tokens - if _generate_from_embeds - else (decoder_tokens if is_encoder_decoder else current_tokens), + tokens=( + penalty_tokens + if _generate_from_embeds + else (decoder_tokens if is_encoder_decoder else current_tokens) + ), ).to(self.cfg.device) else: sampled_tokens = utils.sample_logits( final_logits, temperature=0.0, repetition_penalty=repetition_penalty, - tokens=penalty_tokens - if _generate_from_embeds - else (decoder_tokens if is_encoder_decoder else current_tokens), + tokens=( + penalty_tokens + if _generate_from_embeds + else (decoder_tokens if is_encoder_decoder else current_tokens) + ), ).to(self.cfg.device) # Handle EOS @@ -2570,12 +2598,18 @@ def generate( verbose: bool = True, output_logits: bool = False, return_cache: bool = False, + return_input_tokens: bool = False, names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None, device: Optional[Union[str, torch.device]] = None, pixel_values: Optional[torch.Tensor] = None, **multimodal_kwargs, ) -> ( - str | list[str] | torch.Tensor | Any | tuple[Any, ActivationCache] + str + | list[str] + | torch.Tensor + | Any + | tuple[Any, ActivationCache] + | tuple[Any, torch.Tensor] ): # Any for transformers.utils.ModelOutput # Any: beartype forward ref limitation (beartype#546) """Sample tokens from the model. @@ -2597,9 +2631,11 @@ def generate( repetition by dividing positive logits and multiplying negative logits for previously seen tokens. Default 1.0 (no penalty). use_past_kv_cache: If True, use KV caching for faster generation - prepend_bos: Accepted for API compatibility but not applied during generation. - The HF model expects tokens in its native format (tokenizer defaults). - Overriding BOS can silently degrade generation quality. + prepend_bos: Whether to prepend a BOS token when tokenizing string inputs. + Defaults to None (uses ``cfg.default_prepend_bos``, typically True). + Pass ``prepend_bos=False`` when the input is pre-formatted chat-template + text that already contains the BOS token to avoid double-BOS. + Ignored when input is already a token tensor. padding_side: Which side to pad when tokenizing multiple strings of different lengths. For batched list inputs, left-padding is forced internally for correct generation behavior. Defaults to None (tokenizer default). @@ -2614,6 +2650,11 @@ def generate( encoder-decoder, SSM, multimodal, batched, and inputs_embeds inputs raise NotImplementedError. The cache spans prompt + max_new_tokens and can be large, use ``names_filter`` to scope it and/or ``device`` to offload it. + return_input_tokens: If True, return an ``(output, input_tokens)`` tuple where + ``input_tokens`` is the token tensor that was actually fed to the model + (after BOS handling). Useful for debugging tokenization, especially when + using chat templates where BOS handling can be subtle. Can be combined + with ``return_cache`` to get ``(output, cache, input_tokens)``. names_filter: Passed to ``run_with_cache`` when ``return_cache=True``; restricts which activations are cached (str, list of str, or callable). device: Passed through when ``return_cache=True`` to offload the cached tensors @@ -2627,25 +2668,18 @@ def generate( If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes. If return_cache=True, returns an ``(output, ActivationCache)`` tuple where ``output`` is the value that would otherwise be returned and the cache equals ``run_with_cache(output)``. + If return_input_tokens=True, returns an ``(output, input_tokens)`` tuple. + If both return_cache and return_input_tokens are True, returns ``(output, cache, input_tokens)``. Example: ``out, cache = model.generate(prompt, max_new_tokens=20, return_cache=True)`` returns a normal ActivationCache over the full prompt + generated sequence (equivalent to ``run_with_cache(out)``). - """ - # prepend_bos is intentionally not applied during generation. - # The HF model expects tokens in its native format. Overriding BOS can silently - # degrade quality. - if prepend_bos is not None: - import warnings - warnings.warn( - "prepend_bos is ignored during TransformerBridge.generate(). " - "The HF model expects tokens with the tokenizer's default BOS handling. " - "To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the " - "resulting tensor to generate().", - stacklevel=2, - ) + ``out, input_tokens = model.generate(prompt, return_input_tokens=True)`` returns + the tokens that were fed to the model, useful for verifying BOS handling with + chat templates. + """ # padding_side is handled internally: for batched list inputs, left-padding # is forced to ensure correct generation. See _is_batched_list logic below. @@ -2657,7 +2691,9 @@ def generate( _generate_from_embeds = False if isinstance(input, str): - input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + input_tokens = self.to_tokens( + input, prepend_bos=prepend_bos, move_to_device=True, truncate=False + ) input_type = "str" elif isinstance(input, list): # Force left-padding for batched generation so real tokens are @@ -2665,7 +2701,9 @@ def generate( if _is_batched_list: _orig_padding_side = self.tokenizer.padding_side self.tokenizer.padding_side = "left" - input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + input_tokens = self.to_tokens( + input, prepend_bos=prepend_bos, move_to_device=True, truncate=False + ) if _is_batched_list: self.tokenizer.padding_side = _orig_padding_side input_type = "list" @@ -2919,15 +2957,21 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ... else: # return_type == "tokens" result = output_tokens - if not return_cache: + if not return_cache and not return_input_tokens: return result - # return_cache: recompute one clean forward over the full generated sequence so the - # cache is identical to run_with_cache(output_tokens) - all hook points, including - # attention patterns. The guards above restrict this to single-sequence, decoder-only - # text generation (see issue #697). - _, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device) - return result, cache + if return_cache: + # return_cache: recompute one clean forward over the full generated sequence so the + # cache is identical to run_with_cache(output_tokens) - all hook points, including + # attention patterns. The guards above restrict this to single-sequence, decoder-only + # text generation (see issue #697). + _, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device) + if return_input_tokens: + return result, cache, input_tokens + return result, cache + + # return_input_tokens only (no cache) + return result, input_tokens @torch.no_grad() def generate_stream( @@ -2967,7 +3011,11 @@ def generate_stream( freq_penalty: Frequency penalty for previous tokens. repetition_penalty: HF-style repetition penalty (>1.0 discourages repeats). use_past_kv_cache: Use KV caching for faster generation. - prepend_bos: Not applied (API compatibility). See generate() docstring. + prepend_bos: Whether to prepend a BOS token when tokenizing string inputs. + Defaults to None (uses ``cfg.default_prepend_bos``, typically True). + Pass ``prepend_bos=False`` when the input is pre-formatted chat-template + text that already contains the BOS token to avoid double-BOS. + Ignored when input is already a token tensor. padding_side: Which side to pad for batched list inputs. Left-padding is forced internally for batched generation. return_type: 'input' (match input type), 'str', or 'tokens'. @@ -2978,25 +3026,22 @@ def generate_stream( max_tokens_per_yield tokens between yields. First yield includes the input tokens; subsequent yields contain only new tokens. """ - if prepend_bos is not None: - warnings.warn( - "prepend_bos is ignored during TransformerBridge.generate_stream(). " - "The HF model expects tokens with the tokenizer's default BOS handling.", - stacklevel=2, - ) - # --- Input parsing (mirrors generate()) --- _is_batched_list = isinstance(input, list) and len(input) > 1 if isinstance(input, str): - input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + input_tokens = self.to_tokens( + input, prepend_bos=prepend_bos, move_to_device=True, truncate=False + ) input_type = "str" elif isinstance(input, list): if _is_batched_list: _orig_ps = self.tokenizer.padding_side self.tokenizer.padding_side = "left" try: - input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) + input_tokens = self.to_tokens( + input, prepend_bos=prepend_bos, move_to_device=True, truncate=False + ) finally: if _is_batched_list: self.tokenizer.padding_side = _orig_ps