9494import jax
9595from jetstream .core .lora import adapter_tensorstore as adapterstore
9696
97+ from jetstream .core import prefix_cache
9798from jetstream .core .proto import jetstream_pb2
9899from jetstream .core .proto import jetstream_pb2_grpc
99100from jetstream .core .utils import async_multifuture
100101from jetstream .core .utils .return_sample import ReturnSample
101- from jetstream .engine import engine_api , tokenizer_api , token_utils
102+ from jetstream .engine import chunked_prefill
103+ from jetstream .engine import engine_api
104+ from jetstream .engine import token_utils
105+ from jetstream .engine import tokenizer_api
102106from jetstream .core .metrics .prometheus import JetstreamMetricsCollector
103107import numpy as np
104108
@@ -261,6 +265,10 @@ class Driver:
261265 _prefill_adapterstore : list [adapterstore .AdapterTensorStore ] | None = None
262266 _generate_adapterstore : list [adapterstore .AdapterTensorStore ] | None = None
263267
268+ # Optional prefix cache for storing and retrieving KV caches of common
269+ # prompt prefixes to accelerate prefill. Only works with chunked prefill.
270+ _prefix_cache : prefix_cache .PrefixCache | None = None
271+
264272 def __init__ (
265273 self ,
266274 prefill_engines : Optional [list [engine_api .Engine ]] = None ,
@@ -278,6 +286,7 @@ def __init__(
278286 metrics_collector : JetstreamMetricsCollector | None = None ,
279287 is_ray_backend : bool = False ,
280288 multi_sampling : bool = False ,
289+ prefix_cache_inst : prefix_cache .PrefixCache | None = None ,
281290 ):
282291 if prefill_engines is None :
283292 raise ValueError ("No prefill engine provided." )
@@ -306,6 +315,7 @@ def __init__(
306315 self ._interleaved_mode = interleaved_mode
307316 self ._metrics_collector = metrics_collector
308317 self ._multi_sampling = multi_sampling
318+ self ._prefix_cache = prefix_cache_inst
309319
310320 # Stages 1-4 represent the life cycle of a request.
311321 # Stage 1
@@ -619,55 +629,158 @@ def _do_chunked_prefill(
619629 prefill_params : Any ,
620630 tokenizer : tokenizer_api .Tokenizer ,
621631 tokens : jax .Array | np .ndarray ,
632+ existing_prefix : Optional [engine_api .ExistingPrefix ] = None ,
622633 ) -> Tuple [engine_api .Prefix , engine_api .ResultTokens ]:
623- """Do chunked prefill.
634+ """Performs the prefill operation in chunks.
635+
636+ This method takes a sequence of tokens and processes them in chunks using
637+ the provided prefill engine. It can optionally start from an existing
638+ prefix (KVCache state).
639+
640+ Note: This method requires the `use_chunked_prefill` attribute to be True
641+ on the `prefill_engine`.
624642
625- Should not use without enabling use_chunked_prefill config.
643+ Args:
644+ prefill_engine: The engine instance responsible for prefilling.
645+ prefill_params: The parameters (e.g., model weights) for the prefill
646+ engine.
647+ tokenizer: The tokenizer used, primarily for padding information if
648+ needed during chunk generation.
649+ tokens: A JAX or NumPy array of input token IDs to be prefilled.
650+ existing_prefix: An optional `ExistingPrefix` object containing a
651+ previously computed KVCache and the corresponding common tokens. If
652+ provided, prefill starts from this state. Defaults to None.
653+
654+ Returns:
655+ A tuple containing:
656+ - The resulting prefix (engine-specific KVCache state) after
657+ processing the tokens.
658+ - The first token generated immediately after the prefill.
659+
660+ Raises:
661+ ValueError: If `use_chunked_prefill` is not enabled on the engine.
662+ (Implicitly raised by `chunked_prefill.do_chunked_prefill` if checks
663+ are present there.)
626664 """
627665
628- assert prefill_engine .use_chunked_prefill
666+ chunked_tokens_list = chunked_prefill .gen_chunked_padded_tokens (
667+ tokens ,
668+ prefill_engine .prefill_chunk_size ,
669+ tokenizer ,
670+ existing_prefix .common_prefix_tokens if existing_prefix else None ,
671+ jax_padding = self ._jax_padding ,
672+ )
673+ prefill_result , first_token = chunked_prefill .do_chunked_prefill (
674+ prefill_engine ,
675+ prefill_params ,
676+ chunked_tokens_list ,
677+ existing_prefix ,
678+ )
679+ return prefill_result , first_token
680+
681+ def _do_chunked_prefill_with_prefix_cache (
682+ self ,
683+ prefill_engine : engine_api .Engine ,
684+ prefill_params : Any ,
685+ tokenizer : tokenizer_api .Tokenizer ,
686+ tokens : jax .Array | np .ndarray ,
687+ ) -> Tuple [engine_api .Prefix , engine_api .ResultTokens ]:
688+ """Performs chunked prefill leveraging a prefix cache.
689+
690+ This method attempts to accelerate the prefill process by first loading
691+ the longest possible matching prefix (KV cache state) from the
692+ `self._prefix_cache`. It then performs chunked prefill only on the
693+ remaining portion of the input tokens. Finally, it saves the potentially
694+ updated or new prefix back into the cache.
695+
696+ Note:
697+ - This method requires `use_chunked_prefill` to be True on the
698+ `prefill_engine`.
699+ - This method requires `self._prefix_cache` to be initialized.
629700
630- prefill_result = None
631- first_token = None
701+ Args:
702+ prefill_engine: The engine instance responsible for prefilling.
703+ prefill_params: The parameters (e.g., model weights) for the prefill
704+ engine.
705+ tokenizer: The tokenizer used for padding and potentially during cache
706+ operations if needed.
707+ tokens: A JAX or NumPy array of input token IDs to be prefilled.
708+
709+ Returns:
710+ A tuple containing:
711+ - The resulting prefix (engine-specific KVCache state) after
712+ processing all tokens (either loaded from cache or computed).
713+ - The first token generated immediately after the full prefill.
714+
715+ Raises:
716+ ValueError: If `use_chunked_prefill` is not enabled on the engine or
717+ if `self._prefix_cache` is None.
718+ """
719+ if not prefill_engine .use_chunked_prefill :
720+ raise ValueError ("Chunked prefill must be enabled to use this function." )
721+ if self ._prefix_cache is None :
722+ raise ValueError ("Prefix cache is not initialized." )
723+
724+ # Ensure tokens are in tuple format for cache keys
725+ tuple_tokens = tuple (tokens .tolist ())
726+ chunk_size = prefill_engine .prefill_chunk_size
727+
728+ # 1. Load the longest possible prefix from the cache
729+ load_result = prefix_cache .load_existing_prefix (
730+ self ._prefix_cache , tuple_tokens , chunk_size
731+ )
632732
633733 existing_prefix = None
634- for start_pos in range (
635- 0 ,
636- len (tokens ),
637- prefill_engine .prefill_chunk_size ,
638- ):
639- input_token = tokens [
640- start_pos : min (
641- len (tokens ), start_pos + prefill_engine .prefill_chunk_size
642- )
643- ]
644- padded_input_token , input_true_length = token_utils .pad_tokens (
645- input_token ,
646- tokenizer .bos_id ,
647- tokenizer .pad_id ,
648- is_bos = False ,
649- max_prefill_length = prefill_engine .max_prefill_length ,
650- jax_padding = self ._jax_padding ,
734+ remain_tokens = tokens # Assume full prefill initially
735+ original_common_prefix_len = 0
736+
737+ if load_result :
738+ existing_prefix , original_common_prefix_len = load_result
739+ # Calculate the tokens that still need to be prefilled
740+ # common_prefix_tokens is already truncated to chunk_size multiple
741+ # and ensures at least one token remains.
742+ truncated_len = existing_prefix .common_prefix_tokens .shape [0 ]
743+ remain_tokens = tokens [truncated_len :]
744+ logger .debug (
745+ "Prefix cache hit. Original common length: %d, Truncated length: %d,"
746+ " Remaining tokens to prefill: %d" ,
747+ original_common_prefix_len ,
748+ truncated_len ,
749+ len (remain_tokens ),
651750 )
652- prefill_result , first_token = prefill_engine .prefill (
653- params = prefill_params ,
654- existing_prefix = existing_prefix ,
655- padded_tokens = padded_input_token ,
656- true_length = input_true_length ,
657- )
658- existing_prefix = engine_api .ExistingPrefix (
659- cache = prefill_result ["cache" ],
660- common_prefix_tokens = tokens [
661- 0 : min (
662- len (tokens ), start_pos + prefill_engine .prefill_chunk_size
663- )
664- ],
751+ else :
752+ logger .debug (
753+ "Prefix cache miss or prefix too short. Prefilling all tokens."
665754 )
666755
667- # Should assign in the loop
756+ # 2. Perform chunked prefill on the remaining tokens
757+ prefill_result , first_token = self ._do_chunked_prefill (
758+ prefill_engine = prefill_engine ,
759+ prefill_params = prefill_params ,
760+ tokenizer = tokenizer ,
761+ tokens = remain_tokens ,
762+ existing_prefix = existing_prefix ,
763+ )
764+
668765 assert prefill_result is not None
669766 assert first_token is not None
670767
768+ # 3. Save the potentially new prefix to the cache
769+ # save_existing_prefix handles truncation and checks if the key exists.
770+ # We pass the original full tokens and the final cache state.
771+ # copy_prefix=True because the subsequent insert will donate the result.
772+ # Assume the engine's max prefill length is used as the padded length for the cache.
773+ saved = prefix_cache .save_existing_prefix (
774+ prefix_cache = self ._prefix_cache ,
775+ tokens = tuple_tokens ,
776+ prefix = prefill_result ["cache" ],
777+ chunk_size = chunk_size ,
778+ padded_length = prefill_engine .max_prefill_length ,
779+ copy_prefix = True ,
780+ )
781+ if saved :
782+ logger .debug ("Saved new prefix to cache." )
783+
671784 return prefill_result , first_token
672785
673786 def _prefill_thread (self , idx : int ):
@@ -749,14 +862,28 @@ def _prefill_thread(self, idx: int):
749862 )
750863 request .complete = np .zeros ((request .num_samples ,), np .bool_ )
751864 else :
752- # if chunked_prefill is used,
753- if prefill_engine .use_chunked_prefill :
754- prefill_result , first_token = self ._do_chunked_prefill (
755- prefill_engine ,
756- final_prefill_params ,
757- tokenizer ,
758- padded_tokens [:true_length ],
759- )
865+ # if chunked_prefill is used, and the prompt is long enough
866+ if (
867+ prefill_engine .use_chunked_prefill
868+ and true_length >= prefill_engine .prefill_chunk_size
869+ ):
870+ if self ._prefix_cache is not None :
871+ prefill_result , first_token = (
872+ self ._do_chunked_prefill_with_prefix_cache (
873+ prefill_engine ,
874+ final_prefill_params ,
875+ tokenizer ,
876+ padded_tokens [:true_length ],
877+ )
878+ )
879+ else :
880+ prefill_result , first_token = self ._do_chunked_prefill (
881+ prefill_engine ,
882+ final_prefill_params ,
883+ tokenizer ,
884+ padded_tokens [:true_length ],
885+ )
886+
760887 else :
761888 # Compute new kv cache for the prefill_content.
762889 prefill_result , first_token = prefill_engine .prefill (
0 commit comments