[TRTLLM-9526][feat] optimize host perf for python cache transceiver #12273
chuangz0 merged 6 commits into NVIDIA:main from
Conversation
46e84f8 to
a271356
Compare
📝 Walkthrough
This pull request systematically converts Python list-based data structures to NumPy arrays throughout the disaggregated cache transmission system, including updates to C++ bindings, base classes, implementations, serialization logic, and tests to support vectorized operations.
Changes
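Not part of the walkthrough itself, but a minimal sketch of what the list-to-NumPy conversion looks like in practice. The names `block_ids`, `base_ptr`, and `block_size` are hypothetical, not taken from the PR:

```python
import numpy as np

# Hypothetical block ids and layout; not the PR's actual data.
block_ids = [3, 7, 1, 12]
base_ptr, block_size = 0x10000, 4096

# List-based version: a per-element Python loop.
addrs_list = [base_ptr + b * block_size for b in block_ids]

# NumPy version: one vectorized expression over an int64 array,
# the direction this PR moves the cache-transceiver code.
ids = np.asarray(block_ids, dtype=np.int64)
addrs = base_ptr + ids * block_size

assert addrs.tolist() == addrs_list
```

The vectorized form avoids per-element Python interpreter overhead, which is where the host-side perf win comes from.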
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 1
🧹 Nitpick comments (3)
tensorrt_llm/_torch/disaggregation/base/agent.py (1)
41-48: Consider adding `strict=True` to `zip()` for safety.
The `zip()` call on line 45 silently truncates if `descs_or_addrs`, `sizes`, and `device_ids` have different lengths. While upstream code typically ensures equal lengths, adding `strict=True` (Python 3.10+) would catch mismatches early:
💡 Proposed fix
 def __init__(self, type, descs_or_addrs, sizes=None, device_ids=None):
     self.type = type
     if sizes is not None:
         self.descs = [
-            (int(a), int(s), int(d)) for a, s, d in zip(descs_or_addrs, sizes, device_ids)
+            (int(a), int(s), int(d)) for a, s, d in zip(descs_or_addrs, sizes, device_ids, strict=True)
         ]
     else:
         self.descs = descs_or_addrs
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/disaggregation/base/agent.py` around lines 41-48: In the __init__ method of the agent (where self.descs is built), change the zip call that combines descs_or_addrs, sizes, and device_ids to use zip(..., strict=True) so mismatched input lengths raise immediately; update the tuple-building logic around the zip in __init__ (referencing the __init__ method and self.descs) to use strict=True to catch length mismatches early.

tensorrt_llm/_torch/disaggregation/native/transfer.py (1)
1640-1645: Minor: Consider using `.size` instead of `len()` for consistency.
While `len()` works on NumPy arrays, `.size` is the more idiomatic NumPy approach and would be consistent with the rest of this PR's changes.
💡 Proposed fix
 def _register_aux_buffer(self):
     aux_meta = self._aux_buffer.meta
-    ptr_num = len(aux_meta.ptrs)
+    ptr_num = aux_meta.ptrs.size
     ptr_descs = []
     for i in range(ptr_num):
🤖 Prompt for AI Agents
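For background on the suggestion (illustrative arrays, not the actual `aux_meta`): `len()` and `.size` agree on 1-D NumPy arrays, but diverge on multi-dimensional ones, where `len()` is only the first dimension:

```python
import numpy as np

# 1-D: len() and .size are the same number.
ptrs = np.asarray([0x1000, 0x2000, 0x3000], dtype=np.int64)
assert len(ptrs) == ptrs.size == 3

# Multi-dimensional: len() is the first dimension, .size the total count.
m = np.zeros((2, 4))
assert len(m) == 2
assert m.size == 8
```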
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/disaggregation/native/transfer.py` around lines 1640-1645: Replace the len() call with the NumPy .size property for consistency: change the assignment to ptr_num so it uses aux_meta.ptrs.size (referencing the local variable ptr_num and the attribute aux_meta.ptrs in transfer.py) and keep the rest of the loop building ptr_descs unchanged.

tests/unittest/disaggregated/test_kv_transfer.py (1)
444-473: Return type annotation is outdated.
The function now returns `List[np.ndarray]` (each element is `np.asarray(..., dtype=np.int64)`), but the type hint still declares `List[List[int]]`.
📝 Proposed fix for type annotation
 def get_block_ids_per_layer_groups(
     kv_cache_manager, transfer_worker, request_id: int, use_v2: bool, tokens_per_block: int
-) -> List[List[int]]:
+) -> List[np.ndarray]:
     """Get block_ids for each layer group with window_size filtering."""
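A brief illustration of why the annotation matters (hypothetical values): an ndarray and a list can be value-equal yet are not interchangeable types:

```python
import numpy as np
from typing import List

a = np.asarray([1, 2, 3], dtype=np.int64)  # what the function now returns
lst: List[int] = [1, 2, 3]                 # what the old hint promised

assert a.tolist() == lst        # same values
assert (a == lst).all()         # == is element-wise on ndarrays
assert not isinstance(a, list)  # isinstance(..., list) checks would now fail
```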
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/disaggregated/test_kv_transfer.py` around lines 444 - 473, The return type annotation for get_block_ids_per_layer_groups is outdated (it returns numpy arrays); update its signature to reflect List[np.ndarray] (or Sequence[np.ndarray] if you prefer immutability) instead of List[List[int]]. Locate get_block_ids_per_layer_groups and change the annotation and any related docstring/comment to List[np.ndarray]; ensure imports include numpy as np and typing.List is used consistently with np.ndarray. Verify callers of get_block_ids_per_layer_groups (e.g., uses of block_ids_per_layer_groups) still work with numpy arrays and adjust any type checks if necessary.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindings.cpp`:
- Around line 87-108: The lambda __init__ for kvc::MemoryDescs reads n elements
from addrs, sizes, and deviceIds but only uses addrs.shape(0) for n; add
explicit validation at the start of that lambda: compute n = addrs.shape(0) then
assert sizes.shape(0) == n and deviceIds.shape(0) == n and if not, throw a clear
exception (e.g., std::invalid_argument or nb::value_error) indicating mismatched
array lengths; keep the rest of the logic unchanged so you avoid out-of-bounds
reads when constructing descs.
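A Python sketch of the length check this comment asks for; the real fix belongs in the C++ nanobind lambda, and `validate_lengths` here is a hypothetical stand-in, not the binding code:

```python
import numpy as np

def validate_lengths(addrs, sizes, device_ids):
    # Take n from addrs, as the binding does, then reject mismatches
    # up front instead of reading out of bounds later.
    n = addrs.shape[0]
    if sizes.shape[0] != n or device_ids.shape[0] != n:
        raise ValueError(
            f"mismatched array lengths: addrs={n}, "
            f"sizes={sizes.shape[0]}, device_ids={device_ids.shape[0]}")
    return n

ok_n = validate_lengths(np.zeros(3), np.zeros(3), np.zeros(3))
try:
    validate_lengths(np.zeros(3), np.zeros(2), np.zeros(3))
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```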
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c636e8b5-bfb5-4228-a6c5-d26fe10c0491
📒 Files selected for processing (14)
cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindings.cpp
tensorrt_llm/_torch/disaggregation/base/agent.py
tensorrt_llm/_torch/disaggregation/base/region.py
tensorrt_llm/_torch/disaggregation/base/transfer.py
tensorrt_llm/_torch/disaggregation/native/auxiliary.py
tensorrt_llm/_torch/disaggregation/native/mixers/attention/peer.py
tensorrt_llm/_torch/disaggregation/native/mixers/ssm/peer.py
tensorrt_llm/_torch/disaggregation/native/py_cache_transceiver.py
tensorrt_llm/_torch/disaggregation/native/transfer.py
tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py
tests/unittest/disaggregated/region/test_block.py
tests/unittest/disaggregated/test_extractor.py
tests/unittest/disaggregated/test_kv_transfer.py
tests/unittest/disaggregated/test_kv_transfer_mp.py
/bot run
PR_Github #39355 [ run ] triggered by Bot. Commit:
PR_Github #39355 [ run ] completed with state
215aed1 to a271356 Compare
/bot run
PR_Github #39444 [ run ] triggered by Bot. Commit:
PR_Github #39444 [ run ] completed with state
31748f2 to 83d0a29 Compare
/bot run
PR_Github #39575 [ run ] triggered by Bot. Commit:
PR_Github #39575 [ run ] completed with state
83d0a29 to 225c455 Compare
/bot run
PR_Github #39678 [ run ] triggered by Bot. Commit:
PR_Github #39678 [ run ] completed with state
225c455 to 68aba48 Compare
/bot help
GitHub Bot Help
Provide a user-friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.
run
Launch build/test pipelines. All previously running jobs will be killed.
kill
Kill all running builds associated with pull request.
skip
Skip testing for latest commit on pull request.
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
265b3eb to 93728ee Compare
eb1465f to 7062fcb Compare
PR_Github #40669 [ run ] triggered by Bot. Commit:
/bot run --disable-fail-fast
PR_Github #40675 [ run ] triggered by Bot. Commit:
PR_Github #40669 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #40679 [ run ] triggered by Bot. Commit:
PR_Github #40675 [ run ] completed with state
PR_Github #40679 [ run ] completed with state
98d4f41 to 62ed37c Compare
/bot run --stage-list "A30-PyTorch-1,A30-PyTorch-2, DGX_B200-8_GPUs-PyTorch-1, GB200-8_GPUs-2_Nodes-PyTorch-2"
PR_Github #40841 [ run ] triggered by Bot. Commit:
/bot run --stage-list "A30-PyTorch-1, A30-PyTorch-2, DGX_B200-8_GPUs-PyTorch-1, GB200-8_GPUs-2_Nodes-PyTorch-2"
PR_Github #40841 [ run ] completed with state
/bot run --stage-list "A30-PyTorch-1, DGX_B200-8_GPUs-PyTorch-1, GB200-8_GPUs-2_Nodes-PyTorch-2"
PR_Github #40896 [ run ] triggered by Bot. Commit:
PR_Github #40896 [ run ] completed with state
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
62ed37c to b220b62 Compare
/bot skip --comment "all test has passed"
PR_Github #41049 [ skip ] triggered by Bot. Commit:
PR_Github #41049 [ skip ] completed with state
…VIDIA#12273) Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Summary by CodeRabbit
Release Notes
Description
With this PR and #12490: optimize host performance for the Python KV transfer.
model and config:
opt1 (4e505e2)
opt2 (b10c86d)
KV transfer perf
E2E
Concurrency = 1127
Concurrency = 1229
If KV transmission speeds up, the batch size on the generation side grows faster and the number of iterations drops slightly, but TPOT increases.
config 2 deepseek
deepseek R1 ctx4_dep4_gen1_dep8
8k1k
Baseline is py_cache transceiver without host optimization
1. Output Token Throughput (tok/s)
Output throughput is essentially identical between the two configurations (within +/-0.8% noise).
2. User Token Throughput (tok/s)
kv transfer
config 3 deepseek
deepseek R1 ctx4_dep4_gen1_dep8
8k1k
Baseline is cpp cache transceiver
1. Output Token Throughput (tok/s)
2. User Token Throughput (tok/s)
3. TTFT(ms)
The C++ cache transceiver has better KV-transfer perf, but gets worse context forward perf.
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.