
Attention sink support for LLM runner#18753

Open
kirklandsign wants to merge 1 commit into main from export-D99900289

Conversation

@kirklandsign
Contributor

Summary:
Rewrite the Attention Sink KV cache implementation from an eviction-based approach to a ring buffer, for torch.export compatibility.

Key changes:

  • Ring buffer KV cache: Replace dynamic eviction (torch.cat, narrow, shift) with fixed-size ring buffer using index_copy_. Cache layout: [sink slots | ring buffer slots]. Sink tokens (e.g., BOS) occupy fixed positions; window tokens wrap around in the ring buffer region.
  • Remove eviction_batch_size: No longer needed -- ring buffer overwrites old entries automatically. Removed from all interfaces (attention_sink.py, model.py, llm_config.py, yaml config).
  • Remove attention_sink_forward: No more monkey-patching AttentionMHA.forward. Instead, KVCacheWithAttentionSink sets is_ring_buffer=True, and AttentionMHA.forward handles ring buffer models natively (skip start_pos bounds check, compute mask after KV update).
  • Remove rerotate_k / position shifting: Ring buffer uses original positions for RoPE -- no re-rotation needed.
  • Fix C++ runner: Remove TEMPORARY max_new_tokens hack. Add max_seq_len prefill check. Make context length check conditional for sliding window models.
  • Rewrite tests: Replace 16 eviction-based tests with 18 ring buffer tests covering sink preservation, ring wrapping, causal masking, and degenerate (sink_size=0) cases.
  • Add llama_attention_sink.yaml: Example config for attention sink export.
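The "[sink slots | ring buffer slots]" layout described above can be sketched in plain Python. This is a hypothetical standalone helper, not the PR's code; the actual implementation expresses the same mapping with torch.where and index_copy_ so it stays torch.export-friendly.

```python
def ring_slot(pos: int, sink_size: int, ring_size: int) -> int:
    """Map an absolute token position to its cache slot.

    Positions < sink_size keep a fixed slot (sink tokens such as BOS are
    never overwritten); later positions wrap around inside the ring-buffer
    region [sink_size, sink_size + ring_size).
    """
    if pos < sink_size:
        return pos
    return sink_size + (pos - sink_size) % ring_size

# With 2 sink slots and a ring of 4, position 6 overwrites the slot that
# position 2 used earlier, while slots 0 and 1 stay pinned:
print([ring_slot(p, sink_size=2, ring_size=4) for p in range(10)])
# -> [0, 1, 2, 3, 4, 5, 2, 3, 4, 5]
```

Because old entries are simply overwritten in place, there is no eviction step to schedule, which is why eviction_batch_size disappears from the interfaces.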

Differential Revision: D99900289

Copilot AI review requested due to automatic review settings April 7, 2026 23:51
@pytorch-bot

pytorch-bot bot commented Apr 7, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18753

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures

As of commit 8e7877f with merge base e109ac8:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Apr 7, 2026
@meta-codesync
Contributor

meta-codesync bot commented Apr 7, 2026

@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D99900289.

@github-actions

github-actions bot commented Apr 7, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

Refactors the Attention Sink KV-cache implementation from eviction/shift logic to a fixed-size ring buffer to improve torch.export compatibility, and updates the runtime and tests accordingly.

Changes:

  • Implement ring-buffer-based Attention Sink KV cache + cache position management and integrate with AttentionMHA.forward via an is_ring_buffer mode.
  • Update runner/config parsing to remove eviction_batch_size and adjust generation/prefill constraints for sliding-window models.
  • Rewrite Attention Sink tests and add an example YAML config for export.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.

Summary per file:

  • extension/llm/runner/text_llm_runner.cpp: Adds max_seq_len prefill-chunk validation and adjusts max_new_tokens budgeting for sliding-window/ring-buffer models.
  • extension/llm/export/config/llm_config.py: Updates use_attention_sink validation to the new 2-field format.
  • examples/models/llama/source_transformation/attention_sink.py: Replaces eviction-based Attention Sink with ring-buffer KV cache, adds sink-aware causal mask and position manager.
  • examples/models/llama/attention.py: Updates AttentionMHA.forward to natively support ring-buffer caches by deferring mask creation until after KV update.
  • examples/models/llama/source_transformation/custom_kv_cache.py: Prevents converting KVCacheWithAttentionSink into CustomKVCache.
  • examples/models/llama/source_transformation/test_attention_sink.py: Rewrites tests to cover ring-buffer behavior, sink preservation, wrapping, and masking.
  • examples/models/llama/model.py: Updates parsing/validation of use_attention_sink and enforces incompatibility with use_sdpa_with_kv_cache.
  • examples/models/llama/config/test_llm_config.py: Updates config validation tests for the new Attention Sink format.
  • examples/models/llama/config/llama_attention_sink.yaml: Adds an example config for Attention Sink export.
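The "sink-aware causal mask" mentioned for attention_sink.py can be illustrated with a hypothetical helper: once the ring wraps, slot order no longer implies causal order, so visibility must be decided from the original token position each slot holds (the PR tracks these in cache_positions, with -1 marking empty slots). This sketch uses assumed names and plain Python, not the PR's actual mask construction.

```python
def visible_slots(cache_positions: list[int], q: int) -> list[bool]:
    """Causal visibility over a ring-buffer KV cache.

    cache_positions[i] is the original token position stored in slot i,
    or -1 if the slot is empty. A query at position q may attend to a
    slot only if it holds a real position no later than q -- the slot
    index itself carries no causal meaning once the ring has wrapped.
    """
    return [p != -1 and p <= q for p in cache_positions]

# Layout [sink | ring]: sink slot holds pos 0; ring slots hold positions
# 5, 2, 3, 4 after wrapping. A query at position 5 sees all of them:
print(visible_slots([0, 5, 2, 3, 4], q=5))
# -> [True, True, True, True, True]
# Earlier query at position 3 must not see position 5 or the empty slot:
print(visible_slots([0, 5, 2, 3, -1], q=3))
# -> [True, False, True, True, False]
```

This is also why the PR computes the mask after the KV update: the mask depends on which positions the update just wrote into the ring.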


Comment on lines 219 to 224
  if self.use_attention_sink:
      attention_sink_params = self.use_attention_sink.split(",")
-     if len(attention_sink_params) != 3:
+     if len(attention_sink_params) < 2:
          raise ValueError(
-             "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+             "The value of use_attention_sink must be structured like '<sink_size>,<window_size>'"
          )

Copilot AI Apr 7, 2026


ModelConfig._validate_attention_sink currently allows 3+ comma-separated values (len < 2). Downstream code (e.g., examples/models/llama/model.py) now asserts exactly 2 values, so configs like "4,2048,1024" will pass validation here but fail later at runtime. Consider validating len(attention_sink_params) == 2 (and optionally validating both parse as ints) to keep configuration errors localized and consistent.
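A stricter validation along the lines Copilot suggests could look like the following. This is a hypothetical standalone helper (parse_attention_sink is an assumed name, not the PR's API), shown to make the "exactly 2 values, both ints" constraint concrete.

```python
def parse_attention_sink(value: str) -> tuple[int, int]:
    """Parse '<sink_size>,<window_size>', rejecting anything else.

    Validating the exact field count and int-parsing both fields here
    keeps configuration errors localized, instead of letting a stale
    3-field config like '4,2048,1024' fail later downstream.
    """
    parts = value.split(",")
    if len(parts) != 2:
        raise ValueError(
            "use_attention_sink must be structured like '<sink_size>,<window_size>'"
        )
    try:
        sink_size, window_size = (int(p) for p in parts)
    except ValueError:
        raise ValueError("use_attention_sink fields must be integers") from None
    return sink_size, window_size

print(parse_attention_sink("4,2048"))  # -> (4, 2048)
# parse_attention_sink("4,2048,1024") now raises ValueError at config
# time rather than tripping an assertion downstream in model.py.
```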

Comment on lines +137 to +156
    start_pos = input_pos[0].item()
    torch._check_is_size(start_pos)

    orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos

    # Sink tokens go to fixed slots; window tokens use ring buffer
    indices = torch.where(
        orig_indices < self.sink_size,
        orig_indices,
        self.sink_size + (orig_indices - self.sink_size) % self.ring_size,
    )

    # Update cache_positions exactly like original CachePositionsManager
    full_t = torch.full((self.max_context_length,), -1, dtype=torch.long)
    arange_tensor = torch.arange(self.max_context_length, dtype=torch.long)
    cache_positions = torch.where(
        arange_tensor < start_pos, self.cache_positions, full_t
    )
    self.cache_positions.copy_(cache_positions)
    self.cache_positions.index_copy_(0, indices, orig_indices)

Copilot AI Apr 7, 2026


CachePositionsManagerWithSink.calculate_positions_and_update_indices builds orig_indices/full_t/arange_tensor on the default device (CPU). If the module/buffers are moved to CUDA (e.g., model.to('cuda')), the subsequent torch.where/copy_/index_copy_ will error due to device mismatch. Consider creating these tensors on self.cache_positions.device (and matching dtype) so the manager works on any device.
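A device-safe variant of the index computation could look like this. It is a sketch with an assumed name (make_indices is hypothetical, extracted from the snippet above), showing the fix Copilot describes: allocate the helper tensors on the cache buffer's device instead of the default one.

```python
import torch

def make_indices(cache_positions: torch.Tensor, seq_len: int, start_pos: int,
                 sink_size: int, ring_size: int) -> torch.Tensor:
    """Compute ring-buffer write indices on the cache buffer's own device.

    Passing device=cache_positions.device when building orig_indices (and
    any full/arange helpers) avoids the CPU/CUDA mismatch that occurs when
    the module is moved with model.to('cuda').
    """
    dev = cache_positions.device
    orig_indices = torch.arange(seq_len, dtype=torch.long, device=dev) + start_pos
    return torch.where(
        orig_indices < sink_size,
        orig_indices,
        sink_size + (orig_indices - sink_size) % ring_size,
    )

cache = torch.full((8,), -1, dtype=torch.long)  # CPU here; works on CUDA too
print(make_indices(cache, seq_len=2, start_pos=5, sink_size=2, ring_size=4).tolist())
# -> [5, 2]  (position 6 wraps back onto ring slot 2)
```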

Comment on lines +234 to +236
    assert seq_len <= self.cache_positions_manager.ring_size, (
        f"Prefill sequence length ({seq_len}) exceeds ring buffer capacity "
        f"({self.cache_positions_manager.ring_size}), which would cause "

Copilot AI Apr 7, 2026


KVCacheWithAttentionSink.update asserts seq_len <= ring_size. This rejects valid prefills that include sink tokens (e.g., starting at pos=0 with seq_len == sink_size + ring_size), and it will fail with the current tests that prefill the entire cache. If the goal is to prevent duplicate indices within a single index_copy_ call, consider making the constraint depend on start_pos (e.g., only enforce the ring_size limit for the window portion) or otherwise ensure indices are unique before index_copy_.

Suggested change
-     assert seq_len <= self.cache_positions_manager.ring_size, (
-         f"Prefill sequence length ({seq_len}) exceeds ring buffer capacity "
-         f"({self.cache_positions_manager.ring_size}), which would cause "
+     start_pos = int(input_pos.reshape(-1)[0].item())
+     sink_tokens_remaining = max(0, self.sink_size - start_pos)
+     ring_write_len = max(0, seq_len - sink_tokens_remaining)
+     assert ring_write_len <= self.cache_positions_manager.ring_size, (
+         f"Update writes {ring_write_len} tokens into the ring buffer, "
+         f"which exceeds ring buffer capacity "
+         f"({self.cache_positions_manager.ring_size}) and would cause "

@meta-codesync bot force-pushed the export-D99900289 branch from 7cff409 to 8e7877f on April 8, 2026 06:58

Labels: CLA Signed, fb-exported, meta-exported
