🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18753

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures as of commit 8e7877f with merge base e109ac8.

NEW FAILURES - The following jobs have failed.
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D99900289.
Pull request overview
Refactors the Attention Sink KV-cache implementation from eviction/shift logic to a fixed-size ring buffer to improve torch.export compatibility, and updates the runtime and tests accordingly.
Changes:
- Implement ring-buffer-based Attention Sink KV cache + cache position management and integrate with `AttentionMHA.forward` via an `is_ring_buffer` mode.
- Update runner/config parsing to remove `eviction_batch_size` and adjust generation/prefill constraints for sliding-window models.
- Rewrite Attention Sink tests and add an example YAML config for export.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
Summary per file:
| File | Description |
|---|---|
| extension/llm/runner/text_llm_runner.cpp | Adds max_seq_len prefill-chunk validation and adjusts max_new_tokens budgeting for sliding-window/ring-buffer models. |
| extension/llm/export/config/llm_config.py | Updates use_attention_sink validation to the new 2-field format. |
| examples/models/llama/source_transformation/attention_sink.py | Replaces eviction-based Attention Sink with ring-buffer KV cache, adds sink-aware causal mask and position manager. |
| examples/models/llama/attention.py | Updates AttentionMHA.forward to natively support ring-buffer caches by deferring mask creation until after KV update. |
| examples/models/llama/source_transformation/custom_kv_cache.py | Prevents converting KVCacheWithAttentionSink into CustomKVCache. |
| examples/models/llama/source_transformation/test_attention_sink.py | Rewrites tests to cover ring-buffer behavior, sink preservation, wrapping, and masking. |
| examples/models/llama/model.py | Updates parsing/validation of use_attention_sink and enforces incompatibility with use_sdpa_with_kv_cache. |
| examples/models/llama/config/test_llm_config.py | Updates config validation tests for the new Attention Sink format. |
| examples/models/llama/config/llama_attention_sink.yaml | Adds an example config for Attention Sink export. |
```diff
 if self.use_attention_sink:
     attention_sink_params = self.use_attention_sink.split(",")
-    if len(attention_sink_params) != 3:
+    if len(attention_sink_params) < 2:
         raise ValueError(
-            "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+            "The value of use_attention_sink must be structured like '<sink_size>,<window_size>'"
         )
```
`ModelConfig._validate_attention_sink` currently allows 3+ comma-separated values (`len < 2`). Downstream code (e.g., `examples/models/llama/model.py`) now asserts exactly 2 values, so configs like `"4,2048,1024"` will pass validation here but fail later at runtime. Consider validating `len(attention_sink_params) == 2` (and optionally validating that both values parse as ints) to keep configuration errors localized and consistent.
```python
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)

orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos

# Sink tokens go to fixed slots; window tokens use ring buffer
indices = torch.where(
    orig_indices < self.sink_size,
    orig_indices,
    self.sink_size + (orig_indices - self.sink_size) % self.ring_size,
)
rerotation_sin = (
    new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin
)

# Update cache_positions exactly like original CachePositionsManager
full_t = torch.full((self.max_context_length,), -1, dtype=torch.long)
arange_tensor = torch.arange(self.max_context_length, dtype=torch.long)
cache_positions = torch.where(
    arange_tensor < start_pos, self.cache_positions, full_t
)
self.cache_positions.copy_(cache_positions)
self.cache_positions.index_copy_(0, indices, orig_indices)
```
`CachePositionsManagerWithSink.calculate_positions_and_update_indices` builds `orig_indices`/`full_t`/`arange_tensor` on the default device (CPU). If the module/buffers are moved to CUDA (e.g., `model.to('cuda')`), the subsequent `torch.where`/`copy_`/`index_copy_` calls will error due to a device mismatch. Consider creating these tensors on `self.cache_positions.device` (and with matching dtype) so the manager works on any device.
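A device-safe variant of the index computation might allocate everything on the registered buffer's device. The class below is a minimal standalone sketch, not the actual ExecuTorch implementation; names like `update` and the constructor signature are illustrative:

```python
import torch


class CachePositionsManagerSketch:
    """Illustrative sketch of sink + ring-buffer slot mapping that
    allocates scratch tensors on the cache buffer's own device."""

    def __init__(self, sink_size: int, ring_size: int, device="cpu"):
        self.sink_size = sink_size
        self.ring_size = ring_size
        self.max_context_length = sink_size + ring_size
        # -1 marks slots that have never been written.
        self.cache_positions = torch.full(
            (self.max_context_length,), -1, dtype=torch.long, device=device
        )

    def update(self, start_pos: int, seq_len: int) -> torch.Tensor:
        dev = self.cache_positions.device  # follow the buffer, not the default device
        orig_indices = (
            torch.arange(seq_len, dtype=torch.long, device=dev) + start_pos
        )
        # Sink tokens keep fixed slots; window tokens wrap in the ring region.
        indices = torch.where(
            orig_indices < self.sink_size,
            orig_indices,
            self.sink_size + (orig_indices - self.sink_size) % self.ring_size,
        )
        self.cache_positions.index_copy_(0, indices, orig_indices)
        return indices
```

Because every tensor is created on `self.cache_positions.device`, `model.to('cuda')` (which moves registered buffers) keeps all operands on one device.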
```python
assert seq_len <= self.cache_positions_manager.ring_size, (
    f"Prefill sequence length ({seq_len}) exceeds ring buffer capacity "
    f"({self.cache_positions_manager.ring_size}), which would cause "
```
`KVCacheWithAttentionSink.update` asserts `seq_len <= ring_size`. This rejects valid prefills that include sink tokens (e.g., starting at pos=0 with `seq_len == sink_size + ring_size`), and it will fail with the current tests that prefill the entire cache. If the goal is to prevent duplicate indices within a single `index_copy_` call, consider making the constraint depend on `start_pos` (e.g., only enforce the `ring_size` limit for the window portion) or otherwise ensure indices are unique before `index_copy_`.
Suggested change:

```diff
-assert seq_len <= self.cache_positions_manager.ring_size, (
-    f"Prefill sequence length ({seq_len}) exceeds ring buffer capacity "
-    f"({self.cache_positions_manager.ring_size}), which would cause "
+start_pos = int(input_pos.reshape(-1)[0].item())
+sink_tokens_remaining = max(0, self.sink_size - start_pos)
+ring_write_len = max(0, seq_len - sink_tokens_remaining)
+assert ring_write_len <= self.cache_positions_manager.ring_size, (
+    f"Update writes {ring_write_len} tokens into the ring buffer, "
+    f"which exceeds ring buffer capacity "
+    f"({self.cache_positions_manager.ring_size}) and would cause "
```
Summary: Rewrite the Attention Sink KV cache implementation from eviction-based to a ring buffer approach for torch.export compatibility.

Key changes:
- Ring buffer KV cache: Replace dynamic eviction (`torch.cat`, `narrow`, shift) with a fixed-size ring buffer using `index_copy_`. Cache layout: [sink slots | ring buffer slots]. Sink tokens (e.g., BOS) occupy fixed positions; window tokens wrap around in the ring buffer region.
- Remove `eviction_batch_size`: No longer needed -- the ring buffer overwrites old entries automatically. Removed from all interfaces (attention_sink.py, model.py, llm_config.py, yaml config).
- Remove `attention_sink_forward`: No more monkey-patching `AttentionMHA.forward`. Instead, `KVCacheWithAttentionSink` sets `is_ring_buffer=True`, and `AttentionMHA.forward` handles ring buffer models natively (skip `start_pos` bounds check, compute mask after KV update).
- Remove `rerotate_k` / position shifting: The ring buffer uses original positions for RoPE -- no re-rotation needed.
- Fix C++ runner: Remove TEMPORARY `max_new_tokens` hack. Add `max_seq_len` prefill check. Make the context length check conditional for sliding window models.
- Rewrite tests: Replace 16 eviction-based tests with 18 ring buffer tests covering sink preservation, ring wrapping, causal masking, and degenerate (`sink_size=0`) cases.
- Add llama_attention_sink.yaml: Example config for attention sink export.

Differential Revision: D99900289
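The [sink slots | ring buffer slots] layout described above can be illustrated with a small standalone sketch of the slot-mapping arithmetic (the sizes `sink_size=4`, `ring_size=8` are hypothetical, chosen only for the demo):

```python
import torch

sink_size, ring_size = 4, 8  # illustrative sizes, not from the PR


def slot_for(pos: torch.Tensor) -> torch.Tensor:
    """Map absolute token positions to cache slots.

    Sink tokens (pos < sink_size) keep fixed slots forever; later tokens
    wrap around inside the ring-buffer region, silently overwriting the
    oldest window entries -- no torch.cat/narrow/shift required.
    """
    return torch.where(
        pos < sink_size,
        pos,
        sink_size + (pos - sink_size) % ring_size,
    )


positions = torch.arange(16)
print(slot_for(positions).tolist())
# → [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 4, 5, 6, 7]
# Slots 0-3 (sink) are never reused; positions 12-15 wrap back onto
# slots 4-7, evicting positions 4-7 by plain overwrite.
```

Because slots are computed purely from positions with fixed-shape ops (`torch.where`, `%`), the update is expressible as a single `index_copy_` into a statically sized buffer, which is what makes the approach torch.export-friendly.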
7cff409 to 8e7877f