
Commit f7427db
Author: jetstream authors (committed)
Parents: f71349a + ad7e494

Merge pull request #243 from AI-Hypercomputer:yuyan-prefix-cache

PiperOrigin-RevId: 750638778

File tree

8 files changed (+993, -176 lines)


jetstream/core/config_lib.py

Lines changed: 13 additions & 0 deletions
@@ -59,6 +59,19 @@ class MetricsServerConfig:
   model_name: Optional[str] = None
 
 
+@dataclasses.dataclass
+class PrefixCachingConfig:
+  """Config for prefix caching.
+
+  Attributes:
+    max_hbm_byte: the maximum size to save in HBM, in bytes.
+    max_dram_byte: the maximum size to save in host DRAM, in bytes.
+  """
+
+  max_hbm_byte: int
+  max_dram_byte: int
+
+
 # ▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼#
 
 
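As a usage sketch, here is how the new config could be constructed. The byte budgets below are arbitrary example values, not defaults from this change:

from jetstream.core import config_lib

prefix_caching_config = config_lib.PrefixCachingConfig(
    max_hbm_byte=4 * 1024**3,  # e.g. 4 GiB of device HBM for hot cached prefixes
    max_dram_byte=32 * 1024**3,  # e.g. 32 GiB of host DRAM for colder prefixes
)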
jetstream/core/orchestrator.py

Lines changed: 172 additions & 45 deletions
@@ -94,11 +94,15 @@
 import jax
 from jetstream.core.lora import adapter_tensorstore as adapterstore
 
+from jetstream.core import prefix_cache
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.core.utils import async_multifuture
 from jetstream.core.utils.return_sample import ReturnSample
-from jetstream.engine import engine_api, tokenizer_api, token_utils
+from jetstream.engine import chunked_prefill
+from jetstream.engine import engine_api
+from jetstream.engine import token_utils
+from jetstream.engine import tokenizer_api
 from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
 import numpy as np

@@ -261,6 +265,10 @@ class Driver:
   _prefill_adapterstore: list[adapterstore.AdapterTensorStore] | None = None
   _generate_adapterstore: list[adapterstore.AdapterTensorStore] | None = None
 
+  # Optional prefix cache for storing and retrieving KV caches of common
+  # prompt prefixes to accelerate prefill. Only works with chunked prefill.
+  _prefix_cache: prefix_cache.PrefixCache | None = None
+
   def __init__(
       self,
       prefill_engines: Optional[list[engine_api.Engine]] = None,
@@ -278,6 +286,7 @@ def __init__(
       metrics_collector: JetstreamMetricsCollector | None = None,
       is_ray_backend: bool = False,
       multi_sampling: bool = False,
+      prefix_cache_inst: prefix_cache.PrefixCache | None = None,
   ):
     if prefill_engines is None:
       raise ValueError("No prefill engine provided.")
@@ -306,6 +315,7 @@ def __init__(
     self._interleaved_mode = interleaved_mode
     self._metrics_collector = metrics_collector
     self._multi_sampling = multi_sampling
+    self._prefix_cache = prefix_cache_inst
 
     # Stages 1-4 represent the life cycle of a request.
     # Stage 1
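For context, a hypothetical wiring of the new constructor argument. Only prefix_cache_inst itself appears in this diff; the PrefixCache constructor arguments are elided because its signature is defined in jetstream/core/prefix_cache.py, which this change does not show, and the other Driver arguments follow the pre-existing constructor:

from jetstream.core import orchestrator, prefix_cache

# Constructor arguments elided; see jetstream/core/prefix_cache.py.
cache = prefix_cache.PrefixCache(...)

driver = orchestrator.Driver(
    prefill_engines=[prefill_engine],
    generate_engines=[generate_engine],
    prefill_params=[prefill_params],
    generate_params=[generate_params],
    prefix_cache_inst=cache,  # defaults to None, which disables prefix caching
)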
@@ -619,55 +629,158 @@ def _do_chunked_prefill(
       prefill_params: Any,
       tokenizer: tokenizer_api.Tokenizer,
       tokens: jax.Array | np.ndarray,
+      existing_prefix: Optional[engine_api.ExistingPrefix] = None,
   ) -> Tuple[engine_api.Prefix, engine_api.ResultTokens]:
-    """Do chunked prefill.
+    """Performs the prefill operation in chunks.
+
+    This method takes a sequence of tokens and processes them in chunks using
+    the provided prefill engine. It can optionally start from an existing
+    prefix (KVCache state).
+
+    Note: This method requires the `use_chunked_prefill` attribute to be True
+    on the `prefill_engine`.
 
-    Should not use without enabling use_chunked_prefill config.
+    Args:
+      prefill_engine: The engine instance responsible for prefilling.
+      prefill_params: The parameters (e.g., model weights) for the prefill
+        engine.
+      tokenizer: The tokenizer used, primarily for padding information if
+        needed during chunk generation.
+      tokens: A JAX or NumPy array of input token IDs to be prefilled.
+      existing_prefix: An optional `ExistingPrefix` object containing a
+        previously computed KVCache and the corresponding common tokens. If
+        provided, prefill starts from this state. Defaults to None.
+
+    Returns:
+      A tuple containing:
+        - The resulting prefix (engine-specific KVCache state) after
+          processing the tokens.
+        - The first token generated immediately after the prefill.
+
+    Raises:
+      ValueError: If `use_chunked_prefill` is not enabled on the engine.
+        (Implicitly raised by `chunked_prefill.do_chunked_prefill` if checks
+        are present there.)
     """
 
-    assert prefill_engine.use_chunked_prefill
+    chunked_tokens_list = chunked_prefill.gen_chunked_padded_tokens(
+        tokens,
+        prefill_engine.prefill_chunk_size,
+        tokenizer,
+        existing_prefix.common_prefix_tokens if existing_prefix else None,
+        jax_padding=self._jax_padding,
+    )
+    prefill_result, first_token = chunked_prefill.do_chunked_prefill(
+        prefill_engine,
+        prefill_params,
+        chunked_tokens_list,
+        existing_prefix,
+    )
+    return prefill_result, first_token
+
+  def _do_chunked_prefill_with_prefix_cache(
+      self,
+      prefill_engine: engine_api.Engine,
+      prefill_params: Any,
+      tokenizer: tokenizer_api.Tokenizer,
+      tokens: jax.Array | np.ndarray,
+  ) -> Tuple[engine_api.Prefix, engine_api.ResultTokens]:
+    """Performs chunked prefill leveraging a prefix cache.
+
+    This method attempts to accelerate the prefill process by first loading
+    the longest possible matching prefix (KV cache state) from
+    `self._prefix_cache`. It then performs chunked prefill only on the
+    remaining portion of the input tokens. Finally, it saves the potentially
+    updated or new prefix back into the cache.
+
+    Note:
+      - This method requires `use_chunked_prefill` to be True on the
+        `prefill_engine`.
+      - This method requires `self._prefix_cache` to be initialized.
 
-    prefill_result = None
-    first_token = None
+    Args:
+      prefill_engine: The engine instance responsible for prefilling.
+      prefill_params: The parameters (e.g., model weights) for the prefill
+        engine.
+      tokenizer: The tokenizer used for padding and potentially during cache
+        operations if needed.
+      tokens: A JAX or NumPy array of input token IDs to be prefilled.
+
+    Returns:
+      A tuple containing:
+        - The resulting prefix (engine-specific KVCache state) after
+          processing all tokens (either loaded from cache or computed).
+        - The first token generated immediately after the full prefill.
+
+    Raises:
+      ValueError: If `use_chunked_prefill` is not enabled on the engine or
+        if `self._prefix_cache` is None.
+    """
+    if not prefill_engine.use_chunked_prefill:
+      raise ValueError("Chunked prefill must be enabled to use this function.")
+    if self._prefix_cache is None:
+      raise ValueError("Prefix cache is not initialized.")
+
+    # Ensure tokens are in tuple format for cache keys.
+    tuple_tokens = tuple(tokens.tolist())
+    chunk_size = prefill_engine.prefill_chunk_size
+
+    # 1. Load the longest possible prefix from the cache.
+    load_result = prefix_cache.load_existing_prefix(
+        self._prefix_cache, tuple_tokens, chunk_size
+    )
 
     existing_prefix = None
-    for start_pos in range(
-        0,
-        len(tokens),
-        prefill_engine.prefill_chunk_size,
-    ):
-      input_token = tokens[
-          start_pos : min(
-              len(tokens), start_pos + prefill_engine.prefill_chunk_size
-          )
-      ]
-      padded_input_token, input_true_length = token_utils.pad_tokens(
-          input_token,
-          tokenizer.bos_id,
-          tokenizer.pad_id,
-          is_bos=False,
-          max_prefill_length=prefill_engine.max_prefill_length,
-          jax_padding=self._jax_padding,
+    remain_tokens = tokens  # Assume full prefill initially.
+    original_common_prefix_len = 0
+
+    if load_result:
+      existing_prefix, original_common_prefix_len = load_result
+      # Calculate the tokens that still need to be prefilled.
+      # common_prefix_tokens is already truncated to a multiple of chunk_size
+      # and ensures at least one token remains to prefill.
+      truncated_len = existing_prefix.common_prefix_tokens.shape[0]
+      remain_tokens = tokens[truncated_len:]
+      logger.debug(
+          "Prefix cache hit. Original common length: %d, Truncated length: %d,"
+          " Remaining tokens to prefill: %d",
+          original_common_prefix_len,
+          truncated_len,
+          len(remain_tokens),
       )
-      prefill_result, first_token = prefill_engine.prefill(
-          params=prefill_params,
-          existing_prefix=existing_prefix,
-          padded_tokens=padded_input_token,
-          true_length=input_true_length,
-      )
-      existing_prefix = engine_api.ExistingPrefix(
-          cache=prefill_result["cache"],
-          common_prefix_tokens=tokens[
-              0 : min(
-                  len(tokens), start_pos + prefill_engine.prefill_chunk_size
-              )
-          ],
+    else:
+      logger.debug(
+          "Prefix cache miss or prefix too short. Prefilling all tokens."
       )
 
-    # Should assign in the loop
+    # 2. Perform chunked prefill on the remaining tokens.
+    prefill_result, first_token = self._do_chunked_prefill(
+        prefill_engine=prefill_engine,
+        prefill_params=prefill_params,
+        tokenizer=tokenizer,
+        tokens=remain_tokens,
+        existing_prefix=existing_prefix,
+    )
+
     assert prefill_result is not None
     assert first_token is not None
 
+    # 3. Save the potentially new prefix to the cache.
+    # save_existing_prefix handles truncation and checks if the key exists.
+    # We pass the original full tokens and the final cache state.
+    # copy_prefix=True because the insert that follows will donate the result.
+    # The max prefill length is assumed to be the padded length for the cache.
+    saved = prefix_cache.save_existing_prefix(
+        prefix_cache=self._prefix_cache,
+        tokens=tuple_tokens,
+        prefix=prefill_result["cache"],
+        chunk_size=chunk_size,
+        padded_length=prefill_engine.max_prefill_length,
+        copy_prefix=True,
+    )
+    if saved:
+      logger.debug("Saved new prefix to cache.")
+
     return prefill_result, first_token
 
   def _prefill_thread(self, idx: int):
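The chunk-multiple truncation mentioned in the comments above can be illustrated standalone. The actual rule lives in prefix_cache.load_existing_prefix, which this diff does not show; the sketch below infers its semantics from the comments (a cache hit is cut down to a multiple of the chunk size, and at least one token is always left to prefill so the engine can emit a first token):

def truncated_prefix_len(common_len: int, total_len: int, chunk_size: int) -> int:
  # Round the matched prefix down to a whole number of chunks.
  truncated = (common_len // chunk_size) * chunk_size
  # If that would consume the entire prompt, back off one chunk so the
  # final prefill step still has at least one token to process.
  if truncated >= total_len:
    truncated -= chunk_size
  return max(truncated, 0)

# Example: a 700-token cache hit on a 1000-token prompt with chunk_size=256
# reuses a 512-token prefix and leaves 488 tokens to prefill.
assert truncated_prefix_len(700, 1000, 256) == 512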
@@ -749,14 +862,28 @@ def _prefill_thread(self, idx: int):
         )
         request.complete = np.zeros((request.num_samples,), np.bool_)
       else:
-        # if chunked_prefill is used,
-        if prefill_engine.use_chunked_prefill:
-          prefill_result, first_token = self._do_chunked_prefill(
-              prefill_engine,
-              final_prefill_params,
-              tokenizer,
-              padded_tokens[:true_length],
-          )
+        # If chunked prefill is used and the prompt is long enough.
+        if (
+            prefill_engine.use_chunked_prefill
+            and true_length >= prefill_engine.prefill_chunk_size
+        ):
+          if self._prefix_cache is not None:
+            prefill_result, first_token = (
+                self._do_chunked_prefill_with_prefix_cache(
+                    prefill_engine,
+                    final_prefill_params,
+                    tokenizer,
+                    padded_tokens[:true_length],
+                )
+            )
+          else:
+            prefill_result, first_token = self._do_chunked_prefill(
+                prefill_engine,
+                final_prefill_params,
+                tokenizer,
+                padded_tokens[:true_length],
+            )
+
         else:
           # Compute new kv cache for the prefill_content.
           prefill_result, first_token = prefill_engine.prefill(
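Condensing the branch added above: long prompts take the chunked path (with the prefix cache when one was passed to the Driver), and prompts shorter than one chunk fall through to a single ordinary prefill. A schematic restatement, not the actual method body (names match the diff; the final engine.prefill call follows the pre-existing keyword usage):

def route_prefill(driver, engine, params, tokenizer, padded_tokens, true_length):
  # Chunked prefill only pays off when the prompt spans at least one chunk.
  if engine.use_chunked_prefill and true_length >= engine.prefill_chunk_size:
    if driver._prefix_cache is not None:
      return driver._do_chunked_prefill_with_prefix_cache(
          engine, params, tokenizer, padded_tokens[:true_length]
      )
    return driver._do_chunked_prefill(
        engine, params, tokenizer, padded_tokens[:true_length]
    )
  # Short prompt: compute a fresh KV cache in one prefill call.
  return engine.prefill(
      params=params, padded_tokens=padded_tokens, true_length=true_length
  )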
