Skip to content

Commit 44b6a1a

Browse files
committed
A
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
1 parent 39d0d39 commit 44b6a1a

4 files changed

Lines changed: 30 additions & 12 deletions

File tree

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
322322
.def_rw("missed_blocks", &tbk::KvCacheStats::missedBlocks)
323323
.def_rw("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
324324
.def_rw("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize)
325-
.def_ro("allocated_bytes", &tbk::KvCacheStats::allocatedBytes);
325+
.def_rw("allocated_bytes", &tbk::KvCacheStats::allocatedBytes);
326326

327327
nb::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
328328
.def(nb::init<>())

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor,
1717
get_size_in_bytes, mpi_comm, mpi_disabled,
1818
torch_comm)
19+
from tensorrt_llm.bindings.internal.batch_manager import KvCacheStats
1920
from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import (
2021
IndexMapper, copy_batch_block_offsets_to_device)
2122
from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
@@ -1961,13 +1962,14 @@ def _kv_connector_should_add_sequence(self, request: LlmRequest) -> bool:
19611962
request)
19621963

19631964
def get_kv_cache_stats(self):
1965+
# TODO: Remove this once we have a proper way to shutdown the kv cache manager
1966+
if hasattr(self, "kv_cache_stats"):
1967+
return self.kv_cache_stats
19641968

1965-
class KVCacheStatus:
1969+
kv_cache_stats = KvCacheStats()
1970+
kv_cache_stats.allocated_bytes = self.impl.get_quota(GPU_LEVEL)
19661971

1967-
def __init__(self, allocated_bytes: int):
1968-
self.allocated_bytes = allocated_bytes
1969-
1970-
return KVCacheStatus(allocated_bytes=self.impl.get_quota(GPU_LEVEL))
1972+
return kv_cache_stats
19711973

19721974
def get_block_ids_per_seq(self, request_ids: List[int]) -> torch.Tensor:
19731975
block_ids_per_seq = self.get_batch_cache_indices(request_ids)
@@ -2208,7 +2210,11 @@ def shutdown(self):
22082210
for kv_cache in self.kv_cache_map.values():
22092211
kv_cache.close()
22102212
self.kv_cache_map.clear()
2211-
self.impl.clear_reusable_blocks()
2213+
self.kv_cache_stats = self.get_kv_cache_stats()
2214+
if hasattr(self, "impl"):
2215+
# TODO: Use self.impl.shutdown() instead of del self.impl
2216+
self.impl.clear_reusable_blocks()
2217+
del self.impl
22122218

22132219
def get_max_resource_count(self) -> int:
22142220
# TODO: implement this
@@ -2279,7 +2285,8 @@ def update_resources(self,
22792285
req.get_tokens(DEFAULT_BEAM_INDEX)
22802286
[kv_cache.num_committed_tokens:req.
22812287
context_current_position])
2282-
kv_cache.stop_committing()
2288+
if req.context_remaining_length == 0:
2289+
kv_cache.stop_committing()
22832290
else:
22842291
success = kv_cache.resize(None, req.context_current_position)
22852292
if not success:

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,15 @@ def get_test_config(test_desc, example_dir, test_root):
222222
def get_extra_llm_config(config, suffix, cwd):
223223
extra_llm_config = {
224224
'orchestrator_type': 'ray',
225+
'kv_cache_config': {
226+
'use_kv_cache_manager_v2': False
227+
}
225228
}
226229
for key, value in config.items():
227230
if key not in ['num_instances', 'urls']:
228231
extra_llm_config[key] = value
232+
if key == 'kv_cache_config':
233+
extra_llm_config[key]['use_kv_cache_manager_v2'] = False
229234

230235
temp_fd, extra_config_file = tempfile.mkstemp(suffix='_%s.yaml' % suffix,
231236
dir=cwd)

tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,10 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
172172
disable_overlap_scheduler=not generation_overlap,
173173
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))
174174

175-
kv_cache_configs = [KvCacheConfig(max_tokens=2048 * 8) for _ in range(2)]
175+
kv_cache_configs = [
176+
KvCacheConfig(max_tokens=2048 * 8, use_kv_cache_manager_v2=False)
177+
for _ in range(2)
178+
]
176179
cache_transceiver_configs = [
177180
CacheTransceiverConfig(backend="DEFAULT") for _ in range(2)
178181
]
@@ -318,8 +321,10 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
318321
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))
319322

320323
kv_cache_configs = [
321-
KvCacheConfig(max_tokens=128, enable_block_reuse=False, dtype="auto")
322-
for _ in range(2)
324+
KvCacheConfig(max_tokens=128,
325+
enable_block_reuse=False,
326+
dtype="auto",
327+
use_kv_cache_manager_v2=False) for _ in range(2)
323328
]
324329
cache_transceiver_configs = [
325330
CacheTransceiverConfig(backend="DEFAULT") for _ in range(2)
@@ -429,7 +434,8 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
429434
kv_cache_configs = [
430435
KvCacheConfig(max_tokens=128,
431436
enable_block_reuse=False,
432-
free_gpu_memory_fraction=0.4) for _ in range(2)
437+
free_gpu_memory_fraction=0.4,
438+
use_kv_cache_manager_v2=False) for _ in range(2)
433439
]
434440
cache_transceiver_configs = [
435441
CacheTransceiverConfig(backend="DEFAULT") for _ in range(2)

0 commit comments

Comments
 (0)