@@ -85,6 +85,7 @@ def __init__(
8585 # Context Params
8686 seed : int = llama_cpp .LLAMA_DEFAULT_SEED ,
8787 n_ctx : int = 512 ,
88+ n_keep : int = 256 ,
8889 n_batch : int = 2048 ,
8990 n_ubatch : int = 512 ,
9091 n_seq_max : int = 1 ,
@@ -177,6 +178,7 @@ def __init__(
177178 kv_overrides: Key-value overrides for the model.
178179 seed: RNG seed, -1 for random
179180 n_ctx: Text context, 0 = from model
181+ n_keep: Number of tokens to keep from initial prompt
180182 n_batch: Prompt processing maximum batch size
181183 n_ubatch: Physical batch size
182184 n_seq_max: max number of sequences (i.e. distinct states for recurrent models)
@@ -328,6 +330,7 @@ def __init__(
328330 self .model_params .kv_overrides = self ._kv_overrides_array
329331
330332 self .n_batch = min (n_ctx , n_batch ) # ???
333+ self .n_keep = n_keep if n_keep > 0 else 256
331334 self .n_seq_max = n_seq_max
332335 self .n_threads = n_threads or max (multiprocessing .cpu_count () // 2 , 1 )
333336 self .n_threads_batch = n_threads_batch or multiprocessing .cpu_count ()
@@ -778,35 +781,69 @@ def eval(self, tokens: Sequence[int]):
778781 if len (tokens ) == 0 :
779782 return
780783 n_eval = len (tokens )
781- current_pos = self .n_tokens
784+ if n_eval == 0 :
785+ return
786+
787+ # Context Shift
788+ if self .n_tokens + n_eval > self ._n_ctx :
789+ if self .is_hybrid :
790+ raise RuntimeError (
791+ f"Context length exceeded for Hybrid/SWA model! "
792+ f"(n_tokens: { self .n_tokens } , new: { n_eval } , max: { self ._n_ctx } )"
793+ )
794+ else :
795+ _n_keep = min (self .n_keep , self .n_tokens )
796+ # number of tokens after n_keep that may be discarded when shifting context
797+ # defaults to half
798+ _n_discard = (self .n_tokens - _n_keep ) // 2
782799
783- if self ._ctx :
784- # Standard cleanup by current_pos
785- is_success = self ._ctx .memory_seq_rm (0 , current_pos , - 1 )
786- # Fallback: Broad cleanup
787- if not is_success :
788800 if self .verbose :
789- print (f"WARN: memory_seq_rm(0, { current_pos } , -1) failed. Executing fallback: memory_seq_rm(0, 0, -1)" )
790- is_success = self ._ctx .memory_seq_rm (0 , 0 , - 1 )
801+ print (f"Llama.eval: Context limit reached. Shifting context: "
802+ f"discarding { _n_discard } tokens..." , file = sys .stderr )
803+
804+ self ._ctx .memory_seq_rm (0 , _n_keep , _n_keep + _n_discard )
805+ self ._ctx .memory_seq_add (0 , _n_keep + _n_discard , self .n_tokens , - _n_discard )
806+
807+ remaining_len = self .n_tokens - (_n_keep + _n_discard )
808+ if remaining_len > 0 :
809+ self .input_ids [_n_keep : _n_keep + remaining_len ] = self .input_ids [_n_keep + _n_discard : self .n_tokens ]
810+
811+ self .n_tokens -= _n_discard
791812
792813 for i in range (0 , n_eval , self .n_batch ):
793- batch = tokens [i : min (n_eval , i + self .n_batch )]
814+ batch_tokens = tokens [i : min (n_eval , i + self .n_batch )]
815+ n_batch_tokens = len (batch_tokens )
794816 n_past = self .n_tokens
795- n_batch_tokens = len (batch )
796- self ._batch .set_batch (
797- batch = batch , n_past = n_past , logits_all = self ._logits_all
817+
818+ self ._batch .reset ()
819+
820+ pos_array = [self .n_tokens + j for j in range (n_batch_tokens )]
821+
822+ if self ._logits_all :
823+ logits_array = [True ] * n_batch_tokens
824+ else :
825+ logits_array = [False ] * n_batch_tokens
826+ if i + n_batch_tokens == n_eval :
827+ logits_array [- 1 ] = True
828+
829+ self ._batch .add_sequence (
830+ token_array = batch_tokens ,
831+ pos_array = pos_array ,
832+ seq_ids = [0 ],
833+ logits_array = logits_array
798834 )
835+ current_batch_size = n_batch_tokens
799836 try :
800837 self ._ctx .decode (self ._batch )
801838 except Exception as e :
802839 raise RuntimeError (
803- f"Decode Failed at Pos { current_pos } . "
840+ f"Decode Failed at "
804841 f"Batch size: { n_batch_tokens } . "
805- f"Result of memory_seq_rm: { is_success } . "
806842 f"Error: { str (e )} ."
807843 ) from e
844+
808845 # Save tokens
809- self .input_ids [n_past : n_past + n_batch_tokens ] = batch
846+ self .input_ids [n_past : n_past + n_batch_tokens ] = batch_tokens
810847
811848 # Save logits
812849 logits_ptr = self ._ctx .get_logits ()
@@ -820,8 +857,8 @@ def eval(self, tokens: Sequence[int]):
820857 self .scores [0 , :] = logits_view
821858
822859 # Update n_tokens
823- current_pos += n_batch_tokens
824- self . n_tokens = current_pos
860+ self . n_tokens += current_batch_size
861+ i += current_batch_size
825862
826863 # Helper method: Convert dict logit_bias to List[llama_logit_bias]
827864 def _convert_logit_bias (self , logit_bias : Optional [Dict [int , float ]]) -> List [llama_cpp .llama_logit_bias ]:
@@ -1026,6 +1063,63 @@ def generate(
10261063 Yields:
10271064 The generated tokens.
10281065 """
1066+ original_tokens = list (tokens )
1067+ # Check for kv cache prefix match
1068+ if reset and self .n_tokens > 0 :
1069+ longest_prefix = self .longest_token_prefix (self ._input_ids , tokens [:- 1 ])
1070+ if longest_prefix > 0 :
1071+ reset = False
1072+
1073+ if longest_prefix == len (tokens ):
1074+ if self .verbose :
1075+ print (f"Llama.generate: Full match. Forcing prefix-- to evaluate 1 token." , file = sys .stderr )
1076+ longest_prefix -= 1
1077+
1078+ # Physically erase trailing "ghost" tokens from the C++ KV cache
1079+ # to prevent attention misalignment in multi-round chats.
1080+ if longest_prefix < self .n_tokens :
1081+ if self .is_hybrid and self ._hybrid_cache_mgr is not None :
1082+ if self .verbose :
1083+ print (f"Llama.generate: Hybrid model rollback triggered." , file = sys .stderr )
1084+
1085+ best_ckpt = self ._hybrid_cache_mgr .find_best_checkpoint (original_tokens , 0 )
1086+ if best_ckpt is not None and self ._hybrid_cache_mgr .restore_checkpoint (best_ckpt , seq_id = 0 ):
1087+ actual_prefix = best_ckpt .pos
1088+ else :
1089+ actual_prefix = 0
1090+ self ._hybrid_cache_mgr .clear ()
1091+ self ._ctx .memory_clear (True )
1092+
1093+ self .n_tokens = actual_prefix
1094+ tokens = original_tokens [actual_prefix :]
1095+ if self .verbose :
1096+ print (
1097+ f"Llama.generate: { actual_prefix } prefix-match hit, "
1098+ f"remaining { len (tokens )} prompt tokens to eval" ,
1099+ file = sys .stderr ,
1100+ )
1101+ else :
1102+ if self .verbose :
1103+ print (f"Llama.generate: Truncating KV cache size from { self .n_tokens } to { longest_prefix } " , file = sys .stderr )
1104+ self ._ctx .memory_seq_rm (0 , longest_prefix , - 1 )
1105+
1106+ # Adjust the tokens array and cursor to reuse the matched cache
1107+ self .n_tokens = longest_prefix
1108+ tokens = tokens [longest_prefix :]
1109+
1110+ if self .verbose :
1111+ print (
1112+ f"Llama.generate: { longest_prefix } prefix-match hit, "
1113+ f"remaining { len (tokens )} prompt tokens to eval" ,
1114+ file = sys .stderr ,
1115+ )
1116+ else :
1117+ # No prefix matched. Completely clear the KV cache to prevent context poisoning.
1118+ self .n_tokens = 0
1119+ self ._ctx .memory_clear (True )
1120+ if self .is_hybrid and self ._hybrid_cache_mgr is not None :
1121+ self ._hybrid_cache_mgr .clear ()
1122+
10291123 # Reset mirostat sampling
10301124 params = LlamaSamplingParams (
10311125 # Core Sampling
@@ -1101,88 +1195,84 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
11011195
11021196 self ._sampling_ctx = LlamaSamplingContext (params , self ._model )
11031197
1104- # Check for kv cache prefix match
1105- if reset and self .n_tokens > 0 :
1106- longest_prefix = self .longest_token_prefix (self ._input_ids , tokens [:- 1 ])
1107- if longest_prefix > 0 :
1108- reset = False
1109-
1110- # Physically erase trailing "ghost" tokens from the C++ KV cache
1111- # to prevent attention misalignment in multi-round chats.
1112- if longest_prefix < self .n_tokens :
1113- if self .verbose :
1114- print (f"Llama.generate: Truncating KV cache size from { self .n_tokens } to { longest_prefix } " , file = sys .stderr )
1115- self ._ctx .memory_seq_rm (0 , longest_prefix , - 1 )
1116-
1117- # Adjust the tokens array and cursor to reuse the matched cache
1118- tokens = tokens [longest_prefix :]
1119- self .n_tokens = longest_prefix
1120-
1121- if self .verbose :
1122- print (
1123- f"Llama.generate: { longest_prefix } prefix-match hit, "
1124- f"remaining { len (tokens )} prompt tokens to eval" ,
1125- file = sys .stderr ,
1126- )
1127- else :
1128- # No prefix matched. Completely clear the KV cache to prevent context poisoning.
1129- self .n_tokens = 0
1130- self ._ctx .memory_clear (True )
1131- if self .is_hybrid and self ._hybrid_cache_mgr is not None :
1132- self ._hybrid_cache_mgr .clear ()
1133-
1134- # Reset the model state
1135- if reset :
1136- self .reset ()
1137-
11381198 sample_idx = self .n_tokens + len (tokens ) - 1
11391199 tokens = list (tokens )
11401200
11411201 # Eval and sample
1142- while True :
1143- self .eval (tokens )
1144- while sample_idx < self .n_tokens :
1145- token = self ._sampling_ctx .sample (self ._ctx , idx = - 1 )
1146- self ._sampling_ctx .accept (token , False if grammar is None else True )
1147-
1148- sample_idx += 1
1149- if stopping_criteria is not None :
1150- if self ._logits_all :
1151- logits_idx = sample_idx - self .n_tokens
1152- check_stopping = True
1153- else :
1154- if sample_idx == self .n_tokens :
1155- logits_idx = 0
1202+ try :
1203+ while True :
1204+ if len (tokens ) > 0 :
1205+ self .eval (tokens )
1206+ while sample_idx < self .n_tokens :
1207+ token = self ._sampling_ctx .sample (self ._ctx , idx = - 1 )
1208+ self ._sampling_ctx .accept (token , False if grammar is None else True )
1209+
1210+ sample_idx += 1
1211+
1212+ if stopping_criteria is not None :
1213+ if self ._logits_all :
1214+ logits_idx = sample_idx - self .n_tokens
11561215 check_stopping = True
11571216 else :
1158- check_stopping = False
1159-
1160- if check_stopping and stopping_criteria (
1161- self ._input_ids [: sample_idx ],
1162- self ._scores [logits_idx , :]
1163- ):
1164- return
1165- tokens_or_none = yield token
1166- tokens .clear ()
1167- tokens .append (token )
1168-
1169- if tokens_or_none is not None :
1170- tokens .extend (tokens_or_none )
1171-
1172- if sample_idx < self .n_tokens and token != self ._input_ids [sample_idx ]:
1173- self .n_tokens = sample_idx
1174- self ._ctx .memory_seq_rm (0 , self .n_tokens , - 1 )
1175- break
1217+ if sample_idx == self .n_tokens :
1218+ logits_idx = 0
1219+ check_stopping = True
1220+ else :
1221+ check_stopping = False
1222+
1223+ if check_stopping and stopping_criteria (
1224+ self ._input_ids [: sample_idx ],
1225+ self ._scores [logits_idx , :]
1226+ ):
1227+ return
1228+
1229+ tokens_or_none = yield token
1230+ tokens .clear ()
1231+ tokens .append (token )
1232+
1233+ if tokens_or_none is not None :
1234+ tokens .extend (tokens_or_none )
1235+
1236+ if sample_idx < self .n_tokens and token != self ._input_ids [sample_idx ]:
1237+ self .n_tokens = sample_idx
1238+ if self .is_hybrid :
1239+ if self .verbose :
1240+ print ("Llama.generate: Draft token rejected for Hybrid model. Rolling back via Checkpoint." , file = sys .stderr )
1241+ if self ._hybrid_cache_mgr :
1242+ best_ckpt = self ._hybrid_cache_mgr .find_best_checkpoint (self ._input_ids [:self .n_tokens ].tolist (), 0 )
1243+ if best_ckpt and self ._hybrid_cache_mgr .restore_checkpoint (best_ckpt , seq_id = 0 ):
1244+ self .n_tokens = best_ckpt .pos
1245+ else :
1246+ self ._hybrid_cache_mgr .clear ()
1247+ self ._ctx .memory_clear (True )
1248+ self .n_tokens = 0
1249+ else :
1250+ self ._ctx .memory_seq_rm (0 , self .n_tokens , - 1 )
11761251
1177- if self .draft_model is not None :
1178- self .input_ids [self .n_tokens : self .n_tokens + len (tokens )] = tokens
1179- draft_tokens = self .draft_model (
1180- self .input_ids [: self .n_tokens + len (tokens )]
1181- )
1182- tokens .extend (
1183- draft_tokens .astype (int )[
1184- : self ._n_ctx - self .n_tokens - len (tokens )
1185- ]
1252+ break
1253+
1254+ if self .draft_model is not None :
1255+ if self .is_hybrid :
1256+ if self .verbose :
1257+ print ("Llama.generate: Speculative decoding is skipped for Hybrid models." , file = sys .stderr )
1258+ else :
1259+ self .input_ids [self .n_tokens : self .n_tokens + len (tokens )] = tokens
1260+ draft_tokens = self .draft_model (
1261+ self .input_ids [: self .n_tokens + len (tokens )]
1262+ )
1263+ tokens .extend (
1264+ draft_tokens .astype (int )[
1265+ : self ._n_ctx - self .n_tokens - len (tokens )
1266+ ]
1267+ )
1268+ finally :
1269+ if self .is_hybrid and self ._hybrid_cache_mgr is not None :
1270+ current_history = self ._input_ids [:self .n_tokens ].tolist ()
1271+
1272+ self ._hybrid_cache_mgr .save_checkpoint (
1273+ current_pos = self .n_tokens ,
1274+ tokens = current_history ,
1275+ seq_id = 0
11861276 )
11871277
11881278 def create_embedding (
0 commit comments