Skip to content

Commit 978ddf7

Browse files
committed
feat(core): overhaul generate and eval for hybrid model support (Qwen3-Next, Qwen3.5, etc.)
- Integrated `HybridCheckpointCache` into the generation loop to support state rollback for recurrent/hybrid architectures.
- Implemented context shift (sliding window) in `eval` to gracefully prevent OOM when exceeding `n_ctx`.
- Adapted `eval` to use the newly vectorized `LlamaBatch.add_sequence` API with dynamic `logits_array` configuration.
- Fixed the full-prefix-match bug by forcing a 1-token re-evaluation to refresh logits.
- Disabled speculative decoding for hybrid models to prevent irreversible state pollution.
- Wrapped the generation loop in a `try...finally` block to guarantee safe checkpoint saving.

Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent 9262fc0 commit 978ddf7

File tree

2 files changed

+207
-108
lines changed

2 files changed

+207
-108
lines changed

llama_cpp/llama.py

Lines changed: 182 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(
8585
# Context Params
8686
seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
8787
n_ctx: int = 512,
88+
n_keep: int = 256,
8889
n_batch: int = 2048,
8990
n_ubatch: int = 512,
9091
n_seq_max: int = 1,
@@ -177,6 +178,7 @@ def __init__(
177178
kv_overrides: Key-value overrides for the model.
178179
seed: RNG seed, -1 for random
179180
n_ctx: Text context, 0 = from model
181+
n_keep: Number of tokens to keep from initial prompt
180182
n_batch: Prompt processing maximum batch size
181183
n_ubatch: Physical batch size
182184
n_seq_max: max number of sequences (i.e. distinct states for recurrent models)
@@ -328,6 +330,7 @@ def __init__(
328330
self.model_params.kv_overrides = self._kv_overrides_array
329331

330332
self.n_batch = min(n_ctx, n_batch) # ???
333+
self.n_keep = n_keep if n_keep > 0 else 256
331334
self.n_seq_max = n_seq_max
332335
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
333336
self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
@@ -778,35 +781,69 @@ def eval(self, tokens: Sequence[int]):
778781
if len(tokens) == 0:
779782
return
780783
n_eval = len(tokens)
781-
current_pos = self.n_tokens
784+
if n_eval == 0:
785+
return
786+
787+
# Context Shift
788+
if self.n_tokens + n_eval > self._n_ctx:
789+
if self.is_hybrid:
790+
raise RuntimeError(
791+
f"Context length exceeded for Hybrid/SWA model! "
792+
f"(n_tokens: {self.n_tokens}, new: {n_eval}, max: {self._n_ctx})"
793+
)
794+
else:
795+
_n_keep = min(self.n_keep, self.n_tokens)
796+
# number of tokens after n_keep that may be discarded when shifting context
797+
# defaults to half
798+
_n_discard = (self.n_tokens - _n_keep) // 2
782799

783-
if self._ctx:
784-
# Standard cleanup by current_pos
785-
is_success = self._ctx.memory_seq_rm(0, current_pos, -1)
786-
# Fallback: Broad cleanup
787-
if not is_success:
788800
if self.verbose:
789-
print(f"WARN: memory_seq_rm(0, {current_pos}, -1) failed. Executing fallback: memory_seq_rm(0, 0, -1)")
790-
is_success = self._ctx.memory_seq_rm(0, 0, -1)
801+
print(f"Llama.eval: Context limit reached. Shifting context: "
802+
f"discarding {_n_discard} tokens...", file=sys.stderr)
803+
804+
self._ctx.memory_seq_rm(0, _n_keep, _n_keep + _n_discard)
805+
self._ctx.memory_seq_add(0, _n_keep + _n_discard, self.n_tokens, -_n_discard)
806+
807+
remaining_len = self.n_tokens - (_n_keep + _n_discard)
808+
if remaining_len > 0:
809+
self.input_ids[_n_keep : _n_keep + remaining_len] = self.input_ids[_n_keep + _n_discard : self.n_tokens]
810+
811+
self.n_tokens -= _n_discard
791812

792813
for i in range(0, n_eval, self.n_batch):
793-
batch = tokens[i : min(n_eval, i + self.n_batch)]
814+
batch_tokens = tokens[i : min(n_eval, i + self.n_batch)]
815+
n_batch_tokens = len(batch_tokens)
794816
n_past = self.n_tokens
795-
n_batch_tokens = len(batch)
796-
self._batch.set_batch(
797-
batch=batch, n_past=n_past, logits_all=self._logits_all
817+
818+
self._batch.reset()
819+
820+
pos_array = [self.n_tokens + j for j in range(n_batch_tokens)]
821+
822+
if self._logits_all:
823+
logits_array = [True] * n_batch_tokens
824+
else:
825+
logits_array = [False] * n_batch_tokens
826+
if i + n_batch_tokens == n_eval:
827+
logits_array[-1] = True
828+
829+
self._batch.add_sequence(
830+
token_array=batch_tokens,
831+
pos_array=pos_array,
832+
seq_ids=[0],
833+
logits_array=logits_array
798834
)
835+
current_batch_size = n_batch_tokens
799836
try:
800837
self._ctx.decode(self._batch)
801838
except Exception as e:
802839
raise RuntimeError(
803-
f"Decode Failed at Pos {current_pos}. "
840+
f"Decode Failed at "
804841
f"Batch size: {n_batch_tokens}. "
805-
f"Result of memory_seq_rm: {is_success}. "
806842
f"Error: {str(e)}."
807843
) from e
844+
808845
# Save tokens
809-
self.input_ids[n_past : n_past + n_batch_tokens] = batch
846+
self.input_ids[n_past : n_past + n_batch_tokens] = batch_tokens
810847

811848
# Save logits
812849
logits_ptr = self._ctx.get_logits()
@@ -820,8 +857,8 @@ def eval(self, tokens: Sequence[int]):
820857
self.scores[0, :] = logits_view
821858

822859
# Update n_tokens
823-
current_pos += n_batch_tokens
824-
self.n_tokens = current_pos
860+
self.n_tokens += current_batch_size
861+
i += current_batch_size
825862

826863
# Helper method: Convert dict logit_bias to List[llama_logit_bias]
827864
def _convert_logit_bias(self, logit_bias: Optional[Dict[int, float]]) -> List[llama_cpp.llama_logit_bias]:
@@ -1026,6 +1063,63 @@ def generate(
10261063
Yields:
10271064
The generated tokens.
10281065
"""
1066+
original_tokens = list(tokens)
1067+
# Check for kv cache prefix match
1068+
if reset and self.n_tokens > 0:
1069+
longest_prefix = self.longest_token_prefix(self._input_ids, tokens[:-1])
1070+
if longest_prefix > 0:
1071+
reset = False
1072+
1073+
if longest_prefix == len(tokens):
1074+
if self.verbose:
1075+
print(f"Llama.generate: Full match. Forcing prefix-- to evaluate 1 token.", file=sys.stderr)
1076+
longest_prefix -= 1
1077+
1078+
# Physically erase trailing "ghost" tokens from the C++ KV cache
1079+
# to prevent attention misalignment in multi-round chats.
1080+
if longest_prefix < self.n_tokens:
1081+
if self.is_hybrid and self._hybrid_cache_mgr is not None:
1082+
if self.verbose:
1083+
print(f"Llama.generate: Hybrid model rollback triggered.", file=sys.stderr)
1084+
1085+
best_ckpt = self._hybrid_cache_mgr.find_best_checkpoint(original_tokens, 0)
1086+
if best_ckpt is not None and self._hybrid_cache_mgr.restore_checkpoint(best_ckpt, seq_id=0):
1087+
actual_prefix = best_ckpt.pos
1088+
else:
1089+
actual_prefix = 0
1090+
self._hybrid_cache_mgr.clear()
1091+
self._ctx.memory_clear(True)
1092+
1093+
self.n_tokens = actual_prefix
1094+
tokens = original_tokens[actual_prefix:]
1095+
if self.verbose:
1096+
print(
1097+
f"Llama.generate: {actual_prefix} prefix-match hit, "
1098+
f"remaining {len(tokens)} prompt tokens to eval",
1099+
file=sys.stderr,
1100+
)
1101+
else:
1102+
if self.verbose:
1103+
print(f"Llama.generate: Truncating KV cache size from {self.n_tokens} to {longest_prefix}", file=sys.stderr)
1104+
self._ctx.memory_seq_rm(0, longest_prefix, -1)
1105+
1106+
# Adjust the tokens array and cursor to reuse the matched cache
1107+
self.n_tokens = longest_prefix
1108+
tokens = tokens[longest_prefix:]
1109+
1110+
if self.verbose:
1111+
print(
1112+
f"Llama.generate: {longest_prefix} prefix-match hit, "
1113+
f"remaining {len(tokens)} prompt tokens to eval",
1114+
file=sys.stderr,
1115+
)
1116+
else:
1117+
# No prefix matched. Completely clear the KV cache to prevent context poisoning.
1118+
self.n_tokens = 0
1119+
self._ctx.memory_clear(True)
1120+
if self.is_hybrid and self._hybrid_cache_mgr is not None:
1121+
self._hybrid_cache_mgr.clear()
1122+
10291123
# Reset mirostat sampling
10301124
params = LlamaSamplingParams(
10311125
# Core Sampling
@@ -1101,88 +1195,84 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
11011195

11021196
self._sampling_ctx = LlamaSamplingContext(params, self._model)
11031197

1104-
# Check for kv cache prefix match
1105-
if reset and self.n_tokens > 0:
1106-
longest_prefix = self.longest_token_prefix(self._input_ids, tokens[:-1])
1107-
if longest_prefix > 0:
1108-
reset = False
1109-
1110-
# Physically erase trailing "ghost" tokens from the C++ KV cache
1111-
# to prevent attention misalignment in multi-round chats.
1112-
if longest_prefix < self.n_tokens:
1113-
if self.verbose:
1114-
print(f"Llama.generate: Truncating KV cache size from {self.n_tokens} to {longest_prefix}", file=sys.stderr)
1115-
self._ctx.memory_seq_rm(0, longest_prefix, -1)
1116-
1117-
# Adjust the tokens array and cursor to reuse the matched cache
1118-
tokens = tokens[longest_prefix:]
1119-
self.n_tokens = longest_prefix
1120-
1121-
if self.verbose:
1122-
print(
1123-
f"Llama.generate: {longest_prefix} prefix-match hit, "
1124-
f"remaining {len(tokens)} prompt tokens to eval",
1125-
file=sys.stderr,
1126-
)
1127-
else:
1128-
# No prefix matched. Completely clear the KV cache to prevent context poisoning.
1129-
self.n_tokens = 0
1130-
self._ctx.memory_clear(True)
1131-
if self.is_hybrid and self._hybrid_cache_mgr is not None:
1132-
self._hybrid_cache_mgr.clear()
1133-
1134-
# Reset the model state
1135-
if reset:
1136-
self.reset()
1137-
11381198
sample_idx = self.n_tokens + len(tokens) - 1
11391199
tokens = list(tokens)
11401200

11411201
# Eval and sample
1142-
while True:
1143-
self.eval(tokens)
1144-
while sample_idx < self.n_tokens:
1145-
token = self._sampling_ctx.sample(self._ctx, idx=-1)
1146-
self._sampling_ctx.accept(token, False if grammar is None else True)
1147-
1148-
sample_idx += 1
1149-
if stopping_criteria is not None:
1150-
if self._logits_all:
1151-
logits_idx = sample_idx - self.n_tokens
1152-
check_stopping = True
1153-
else:
1154-
if sample_idx == self.n_tokens:
1155-
logits_idx = 0
1202+
try:
1203+
while True:
1204+
if len(tokens) > 0:
1205+
self.eval(tokens)
1206+
while sample_idx < self.n_tokens:
1207+
token = self._sampling_ctx.sample(self._ctx, idx=-1)
1208+
self._sampling_ctx.accept(token, False if grammar is None else True)
1209+
1210+
sample_idx += 1
1211+
1212+
if stopping_criteria is not None:
1213+
if self._logits_all:
1214+
logits_idx = sample_idx - self.n_tokens
11561215
check_stopping = True
11571216
else:
1158-
check_stopping = False
1159-
1160-
if check_stopping and stopping_criteria(
1161-
self._input_ids[: sample_idx],
1162-
self._scores[logits_idx, :]
1163-
):
1164-
return
1165-
tokens_or_none = yield token
1166-
tokens.clear()
1167-
tokens.append(token)
1168-
1169-
if tokens_or_none is not None:
1170-
tokens.extend(tokens_or_none)
1171-
1172-
if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
1173-
self.n_tokens = sample_idx
1174-
self._ctx.memory_seq_rm(0, self.n_tokens, -1)
1175-
break
1217+
if sample_idx == self.n_tokens:
1218+
logits_idx = 0
1219+
check_stopping = True
1220+
else:
1221+
check_stopping = False
1222+
1223+
if check_stopping and stopping_criteria(
1224+
self._input_ids[: sample_idx],
1225+
self._scores[logits_idx, :]
1226+
):
1227+
return
1228+
1229+
tokens_or_none = yield token
1230+
tokens.clear()
1231+
tokens.append(token)
1232+
1233+
if tokens_or_none is not None:
1234+
tokens.extend(tokens_or_none)
1235+
1236+
if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
1237+
self.n_tokens = sample_idx
1238+
if self.is_hybrid:
1239+
if self.verbose:
1240+
print("Llama.generate: Draft token rejected for Hybrid model. Rolling back via Checkpoint.", file=sys.stderr)
1241+
if self._hybrid_cache_mgr:
1242+
best_ckpt = self._hybrid_cache_mgr.find_best_checkpoint(self._input_ids[:self.n_tokens].tolist(), 0)
1243+
if best_ckpt and self._hybrid_cache_mgr.restore_checkpoint(best_ckpt, seq_id=0):
1244+
self.n_tokens = best_ckpt.pos
1245+
else:
1246+
self._hybrid_cache_mgr.clear()
1247+
self._ctx.memory_clear(True)
1248+
self.n_tokens = 0
1249+
else:
1250+
self._ctx.memory_seq_rm(0, self.n_tokens, -1)
11761251

1177-
if self.draft_model is not None:
1178-
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
1179-
draft_tokens = self.draft_model(
1180-
self.input_ids[: self.n_tokens + len(tokens)]
1181-
)
1182-
tokens.extend(
1183-
draft_tokens.astype(int)[
1184-
: self._n_ctx - self.n_tokens - len(tokens)
1185-
]
1252+
break
1253+
1254+
if self.draft_model is not None:
1255+
if self.is_hybrid:
1256+
if self.verbose:
1257+
print("Llama.generate: Speculative decoding is skipped for Hybrid models.", file=sys.stderr)
1258+
else:
1259+
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
1260+
draft_tokens = self.draft_model(
1261+
self.input_ids[: self.n_tokens + len(tokens)]
1262+
)
1263+
tokens.extend(
1264+
draft_tokens.astype(int)[
1265+
: self._n_ctx - self.n_tokens - len(tokens)
1266+
]
1267+
)
1268+
finally:
1269+
if self.is_hybrid and self._hybrid_cache_mgr is not None:
1270+
current_history = self._input_ids[:self.n_tokens].tolist()
1271+
1272+
self._hybrid_cache_mgr.save_checkpoint(
1273+
current_pos=self.n_tokens,
1274+
tokens=current_history,
1275+
seq_id=0
11861276
)
11871277

11881278
def create_embedding(

0 commit comments

Comments
 (0)