Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions demos/Main_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1015,9 +1015,7 @@
"Mathematically, centering is a linear map, normalizing is *not* a linear map, and scaling and translation are linear maps. \n",
"* **Centering:** LayerNorm is applied every time a layer reads from the residual stream, so the mean of any residual stream vector can never matter - `center_writing_weights` sets every weight matrix writing to the residual to have zero mean. \n",
"* **Normalizing:** Normalizing is not a linear map, and cannot be factored out. The `hook_scale` hook point lets you access and control for this.\n",
"* **Scaling and Translation:** Scaling and translation are linear maps, and are always followed by another linear map. The composition of two linear maps is another linear map, so we can *fold* the scaling and translation weights into the weights of the subsequent layer, and simplify things without changing the underlying computation. \n",
"\n",
"[See the docs for more details](https://github.com/TransformerLensOrg/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln)"
"* **Scaling and Translation:** Scaling and translation are linear maps, and are always followed by another linear map. The composition of two linear maps is another linear map, so we can *fold* the scaling and translation weights into the weights of the subsequent layer, and simplify things without changing the underlying computation. \n"
]
},
{
Expand Down
107 changes: 75 additions & 32 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,29 @@ class TransformerBridge(HookIntrospectionMixin, nn.Module):
BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and
``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize
as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt.

BOS token and chat templates
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``model.tokenizer`` is configured with ``add_bos_token=True`` and is
**not** the stock HuggingFace tokenizer. Direct ``.encode()`` calls
will prepend BOS automatically.

When passing pre-applied chat-template text (i.e., the output of
``tokenizer.apply_chat_template(..., tokenize=False)``), pass
``prepend_bos=False`` to :meth:`to_tokens` to avoid a double BOS::

# Correct pattern for chat templates:
text = model.tokenizer.apply_chat_template(messages, tokenize=False)
tokens = model.to_tokens(text, prepend_bos=False)

A chat template that uses a BOS token already embeds it in the
rendered text; letting :meth:`to_tokens` add another would produce
a malformed sequence like ``[BOS, BOS, ...]``.

To inspect what tokens will actually be fed to the model during
generation, use :meth:`to_tokens` directly or pass
``return_input_tokens=True`` to :meth:`generate`.
"""

hook_aliases: Dict[str, Union[str, List[str]]] = {
Expand Down Expand Up @@ -2506,18 +2529,22 @@ def _generate_tokens(
temperature=temperature,
freq_penalty=freq_penalty,
repetition_penalty=repetition_penalty,
tokens=penalty_tokens
if _generate_from_embeds
else (decoder_tokens if is_encoder_decoder else current_tokens),
tokens=(
penalty_tokens
if _generate_from_embeds
else (decoder_tokens if is_encoder_decoder else current_tokens)
),
).to(self.cfg.device)
else:
sampled_tokens = utils.sample_logits(
final_logits,
temperature=0.0,
repetition_penalty=repetition_penalty,
tokens=penalty_tokens
if _generate_from_embeds
else (decoder_tokens if is_encoder_decoder else current_tokens),
tokens=(
penalty_tokens
if _generate_from_embeds
else (decoder_tokens if is_encoder_decoder else current_tokens)
),
).to(self.cfg.device)

# Handle EOS
Expand Down Expand Up @@ -2570,12 +2597,18 @@ def generate(
verbose: bool = True,
output_logits: bool = False,
return_cache: bool = False,
return_input_tokens: bool = False,
names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
device: Optional[Union[str, torch.device]] = None,
pixel_values: Optional[torch.Tensor] = None,
**multimodal_kwargs,
) -> (
str | list[str] | torch.Tensor | Any | tuple[Any, ActivationCache]
str
| list[str]
| torch.Tensor
| Any
| tuple[Any, ActivationCache]
| tuple[Any, torch.Tensor]
): # Any for transformers.utils.ModelOutput
# Any: beartype forward ref limitation (beartype#546)
"""Sample tokens from the model.
Expand All @@ -2597,9 +2630,11 @@ def generate(
repetition by dividing positive logits and multiplying negative logits for
previously seen tokens. Default 1.0 (no penalty).
use_past_kv_cache: If True, use KV caching for faster generation
prepend_bos: Accepted for API compatibility but not applied during generation.
The HF model expects tokens in its native format (tokenizer defaults).
Overriding BOS can silently degrade generation quality.
prepend_bos: Whether to prepend a BOS token when tokenizing string inputs.
Defaults to None (uses ``cfg.default_prepend_bos``, typically True).
Pass ``prepend_bos=False`` when the input is pre-formatted chat-template
text that already contains the BOS token to avoid double-BOS.
Ignored when input is already a token tensor.
padding_side: Which side to pad when tokenizing multiple strings of different
lengths. For batched list inputs, left-padding is forced internally for
correct generation behavior. Defaults to None (tokenizer default).
Expand All @@ -2614,6 +2649,11 @@ def generate(
encoder-decoder, SSM, multimodal, batched, and inputs_embeds inputs raise
NotImplementedError. The cache spans prompt + max_new_tokens and can be large,
use ``names_filter`` to scope it and/or ``device`` to offload it.
return_input_tokens: If True, return an ``(output, input_tokens)`` tuple where
``input_tokens`` is the token tensor that was actually fed to the model
(after BOS handling). Useful for debugging tokenization, especially when
using chat templates where BOS handling can be subtle. Can be combined
with ``return_cache`` to get ``(output, cache, input_tokens)``.
names_filter: Passed to ``run_with_cache`` when ``return_cache=True``; restricts
which activations are cached (str, list of str, or callable).
device: Passed through when ``return_cache=True`` to offload the cached tensors
Expand All @@ -2627,25 +2667,18 @@ def generate(
If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
If return_cache=True, returns an ``(output, ActivationCache)`` tuple where ``output`` is the
value that would otherwise be returned and the cache equals ``run_with_cache(output)``.
If return_input_tokens=True, returns an ``(output, input_tokens)`` tuple.
If both return_cache and return_input_tokens are True, returns ``(output, cache, input_tokens)``.

Example:
``out, cache = model.generate(prompt, max_new_tokens=20, return_cache=True)`` returns a
normal ActivationCache over the full prompt + generated sequence (equivalent to
``run_with_cache(out)``).
"""
# prepend_bos is intentionally not applied during generation.
# The HF model expects tokens in its native format. Overriding BOS can silently
# degrade quality.
if prepend_bos is not None:
import warnings

warnings.warn(
"prepend_bos is ignored during TransformerBridge.generate(). "
"The HF model expects tokens with the tokenizer's default BOS handling. "
"To control BOS, tokenize with to_tokens(prepend_bos=...) and pass the "
"resulting tensor to generate().",
stacklevel=2,
)
``out, input_tokens = model.generate(prompt, return_input_tokens=True)`` returns
the tokens that were fed to the model, useful for verifying BOS handling with
chat templates.
"""
# padding_side is handled internally: for batched list inputs, left-padding
# is forced to ensure correct generation. See _is_batched_list logic below.

Expand All @@ -2657,15 +2690,19 @@ def generate(

_generate_from_embeds = False
if isinstance(input, str):
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
input_tokens = self.to_tokens(
input, prepend_bos=prepend_bos, move_to_device=True, truncate=False
)
input_type = "str"
elif isinstance(input, list):
# Force left-padding for batched generation so real tokens are
# flush-right and logits[:, -1, :] is always the last real token.
if _is_batched_list:
_orig_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = "left"
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
input_tokens = self.to_tokens(
input, prepend_bos=prepend_bos, move_to_device=True, truncate=False
)
if _is_batched_list:
self.tokenizer.padding_side = _orig_padding_side
input_type = "list"
Expand Down Expand Up @@ -2919,15 +2956,21 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...
else: # return_type == "tokens"
result = output_tokens

if not return_cache:
if not return_cache and not return_input_tokens:
return result

# return_cache: recompute one clean forward over the full generated sequence so the
# cache is identical to run_with_cache(output_tokens) - all hook points, including
# attention patterns. The guards above restrict this to single-sequence, decoder-only
# text generation (see issue #697).
_, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device)
return result, cache
if return_cache:
# return_cache: recompute one clean forward over the full generated sequence so the
# cache is identical to run_with_cache(output_tokens) - all hook points, including
# attention patterns. The guards above restrict this to single-sequence, decoder-only
# text generation (see issue #697).
_, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device)
if return_input_tokens:
return result, cache, input_tokens
return result, cache

# return_input_tokens only (no cache)
return result, input_tokens

@torch.no_grad()
def generate_stream(
Expand Down
Loading