diff --git a/examples/mem_scheduler/task_stop_rerun.py b/examples/mem_scheduler/task_stop_rerun.py
index db8dd8807..809e625ae 100644
--- a/examples/mem_scheduler/task_stop_rerun.py
+++ b/examples/mem_scheduler/task_stop_rerun.py
@@ -25,9 +25,9 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
         task_id = str(msg.item_id)
         file_path = tmp_dir / f"{task_id}.txt"
         try:
-            print(f"writing {file_path}...")
-            file_path.write_text(f"Task {task_id} processed.\n")
             sleep(5)
+            file_path.write_text(f"Task {task_id} processed.\n")
+            print(f"writing {file_path} done")
         except Exception as e:
             print(f"Failed to write {file_path}: {e}")
 
@@ -89,4 +89,5 @@ def submit_tasks():
 
 # 7. Stop the scheduler
 print("Stopping the scheduler...")
+sleep(5)
 mem_scheduler.stop()
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 7c2e5b558..9a83ab16e 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -453,7 +453,7 @@ def get_memory(
     @staticmethod
     def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]:
         """Parse index-keyed JSON from hallucination filter response.
-        Expected shape: { "0": {"if_delete": bool, "rewritten memory content": str}, ... }
+        Expected shape: { "0": {"delete": bool, "rewritten": str, "reason": str}, ... }
         Returns (success, parsed_dict) with int keys.
         """
         try:
@@ -476,54 +476,82 @@ def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dic
                 continue
             if not isinstance(v, dict):
                 continue
-            delete_flag = v.get("delete_flag")
-            rewritten = v.get("rewritten memory content", "")
-            if isinstance(delete_flag, bool) and isinstance(rewritten, str):
-                result[idx] = {"delete_flag": delete_flag, "rewritten memory content": rewritten}
+            delete_flag = v.get("delete")
+            rewritten = v.get("rewritten", "")
+            reason = v.get("reason", "")
+            if (
+                isinstance(delete_flag, bool)
+                and isinstance(rewritten, str)
+                and isinstance(reason, str)
+            ):
+                result[idx] = {"delete": delete_flag, "rewritten": rewritten, "reason": reason}
         return (len(result) > 0), result
 
     def filter_hallucination_in_memories(
-        self, user_messages: list[str], memory_list: list[list[TextualMemoryItem]]
-    ):
-        filtered_memory_list = []
-        for group in memory_list:
-            try:
-                flat_memories = [one.memory for one in group]
-                template = PROMPT_MAPPING["hallucination_filter"]
-                prompt_args = {
-                    "user_messages_inline": "\n".join(user_messages),
-                    "memories_inline": json.dumps(flat_memories, ensure_ascii=False, indent=2),
-                }
-                prompt = template.format(**prompt_args)
+        self, user_messages: list[str], memory_list: list[TextualMemoryItem]
+    ) -> list[TextualMemoryItem]:
+        flat_memories = [one.memory for one in memory_list]
+        template = PROMPT_MAPPING["hallucination_filter"]
+        prompt_args = {
+            "user_messages_inline": "\n".join([f"- {memory}" for memory in user_messages]),
+            "memories_inline": json.dumps(
+                {str(i): memory for i, memory in enumerate(flat_memories)},
+                ensure_ascii=False,
+                indent=2,
+            ),
+        }
+        prompt = template.format(**prompt_args)
 
-                # Optionally run filter and parse the output
-                try:
-                    raw = self.llm.generate(prompt)
-                    success, parsed = self._parse_hallucination_filter_response(raw)
-                    logger.info(f"Hallucination filter parsed successfully: {success}")
-                    new_mem_list = []
-                    if success:
-                        logger.info(f"Hallucination filter result: {parsed}")
-                        for mem_idx, (delete_flag, rewritten_mem_content) in parsed.items():
-                            if not delete_flag:
-                                group[mem_idx].memory = rewritten_mem_content
-                                new_mem_list.append(group[mem_idx])
-                        filtered_memory_list.append(new_mem_list)
-                        logger.info(
-                            f"Successfully transform origianl memories from {group} to {new_mem_list}."
-                        )
-                    else:
+        # Optionally run filter and parse the output
+        try:
+            raw = self.llm.generate([{"role": "user", "content": prompt}])
+            success, parsed = self._parse_hallucination_filter_response(raw)
+            logger.info(
+                f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success}"
+            )
+            if success:
+                logger.info(f"Hallucination filter result: {parsed}")
+                total = len(memory_list)
+                keep_flags = [True] * total
+                for mem_idx, content in parsed.items():
+                    # Validate index bounds
+                    if not isinstance(mem_idx, int) or mem_idx < 0 or mem_idx >= total:
                         logger.warning(
-                            "Hallucination filter parsing failed or returned empty result."
+                            f"[filter_hallucination_in_memories] Ignoring out-of-range index: {mem_idx}"
                         )
-                except Exception as e:
-                    logger.error(f"Hallucination filter execution error: {e}", stack_info=True)
-                    filtered_memory_list.append(group)
-            except Exception:
-                logger.error("Fail to filter memories", stack_info=True)
-                filtered_memory_list.append(group)
-        return filtered_memory_list
+                        continue
+
+                    delete_flag = content.get("delete", False)
+                    rewritten = content.get("rewritten", None)
+                    reason = content.get("reason", "")
+
+                    logger.info(
+                        f"[filter_hallucination_in_memories] index={mem_idx}, delete={delete_flag}, rewritten='{(rewritten or '')[:100]}', reason='{reason[:120]}'"
+                    )
+
+                    if delete_flag is True and rewritten is not None:
+                        # Mark for deletion
+                        keep_flags[mem_idx] = False
+                    else:
+                        # Apply rewrite if provided (safe-by-default: keep item when not mentioned or delete=False)
+                        try:
+                            if isinstance(rewritten, str):
+                                memory_list[mem_idx].memory = rewritten
+                        except Exception as e:
+                            logger.warning(
+                                f"[filter_hallucination_in_memories] Failed to apply rewrite for index {mem_idx}: {e}"
+                            )
+
+                # Build result, preserving original order; keep items not mentioned by LLM by default
+                new_mem_list = [memory_list[i] for i in range(total) if keep_flags[i]]
+                return new_mem_list
+            else:
+                logger.warning("Hallucination filter parsing failed or returned empty result.")
+        except Exception as e:
+            logger.error(f"Hallucination filter execution error: {e}", stack_info=True)
+
+        return memory_list
 
     def _read_memory(
         self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine"
@@ -572,11 +600,16 @@ def _read_memory(
 
         if os.getenv("SIMPLE_STRUCT_ADD_FILTER", "false") == "true":
             # Build inputs
-            user_messages = [msg.content for msg in messages if msg.role == "user"]
-            memory_list = self.filter_hallucination_in_memories(
-                user_messages=user_messages, memory_list=memory_list
-            )
-
+            new_memory_list = []
+            for unit_messages, unit_memory_list in zip(messages, memory_list, strict=False):
+                unit_user_messages = [
+                    msg["content"] for msg in unit_messages if msg["role"] == "user"
+                ]
+                unit_memory_list = self.filter_hallucination_in_memories(
+                    user_messages=unit_user_messages, memory_list=unit_memory_list
+                )
+                new_memory_list.append(unit_memory_list)
+            memory_list = new_memory_list
         return memory_list
 
     def fine_transfer_simple_mem(
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index 8f3eccecf..59bd1c0a2 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -66,7 +66,7 @@ def __init__(self, config: GeneralSchedulerConfig):
     def long_memory_update_process(
         self, user_id: 
str, mem_cube_id: str, messages: list[ScheduleMessageItem] ): - mem_cube = self.current_mem_cube + mem_cube = self.mem_cube # update query monitors for msg in messages: @@ -109,8 +109,8 @@ def long_memory_update_process( query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] query_db_manager.obj.put(item=item) - # Sync with database after adding new item - query_db_manager.sync_with_orm() + # Sync with database after adding new item + query_db_manager.sync_with_orm() logger.debug( f"Queries in monitor for user_id={user_id}, mem_cube_id={mem_cube_id}: {query_db_manager.obj.get_queries_with_timesort()}" ) @@ -162,7 +162,7 @@ def long_memory_update_process( label=QUERY_TASK_LABEL, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, ) def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: @@ -249,7 +249,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: to_memory_type=NOT_APPLICABLE_TYPE, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, memcube_log_content=[ { "content": f"[User] {msg.content}", @@ -305,7 +305,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: to_memory_type=NOT_APPLICABLE_TYPE, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, memcube_log_content=[ { "content": f"[Assistant] {msg.content}", @@ -338,7 +338,7 @@ def log_add_messages(self, msg: ScheduleMessageItem): try: # This mem_item represents the NEW content that was just added/processed mem_item: TextualMemoryItem | None = None - mem_item = self.current_mem_cube.text_mem.get( + mem_item = self.mem_cube.text_mem.get( memory_id=memory_id, user_name=msg.mem_cube_id ) if mem_item is None: @@ -352,8 +352,8 @@ def log_add_messages(self, msg: ScheduleMessageItem): original_item_id = None # Only check graph_store if a key exists and the text_mem has a graph_store - if key and hasattr(self.current_mem_cube.text_mem, "graph_store"): - candidates = self.current_mem_cube.text_mem.graph_store.get_by_metadata( + if key and hasattr(self.mem_cube.text_mem, "graph_store"): + candidates = self.mem_cube.text_mem.graph_store.get_by_metadata( [ {"field": "key", "op": "=", "value": key}, { @@ -368,7 +368,7 @@ def log_add_messages(self, msg: ScheduleMessageItem): original_item_id = candidates[0] # Crucial step: Fetch the original content for updates # This `get` is for the *existing* memory that will be updated - original_mem_item = self.current_mem_cube.text_mem.get( + original_mem_item = self.mem_cube.text_mem.get( memory_id=original_item_id, user_name=msg.mem_cube_id ) original_content = original_mem_item.memory @@ -481,7 +481,7 @@ def send_add_log_messages_to_local_env( to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, memcube_log_content=add_content_legacy, metadata=add_meta_legacy, memory_len=len(add_content_legacy), @@ -496,7 +496,7 @@ def send_add_log_messages_to_local_env( to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, memcube_log_content=update_content_legacy, metadata=update_meta_legacy, memory_len=len(update_content_legacy), @@ -562,7 +562,7 @@ def send_add_log_messages_to_cloud_env( to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=msg.user_id, 
mem_cube_id=msg.mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, memcube_log_content=kb_log_content, metadata=None, memory_len=len(kb_log_content), @@ -577,7 +577,7 @@ def _mem_feedback_message_consumer(self, messages: list[ScheduleMessageItem]) -> if not messages: return message = messages[0] - mem_cube = self.current_mem_cube + mem_cube = self.mem_cube user_id = message.user_id mem_cube_id = message.mem_cube_id @@ -744,9 +744,9 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = self.current_mem_cube + mem_cube = self.mem_cube if mem_cube is None: - logger.warning( + logger.error( f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing", stack_info=True, ) @@ -923,7 +923,7 @@ def _process_memories_with_reader( to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, memcube_log_content=kb_log_content, metadata=None, memory_len=len(kb_log_content), @@ -968,7 +968,7 @@ def _process_memories_with_reader( to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, memcube_log_content=add_content_legacy, metadata=add_meta_legacy, memory_len=len(add_content_legacy), @@ -1036,7 +1036,7 @@ def _process_memories_with_reader( to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, memcube_log_content=kb_log_content, metadata=None, memory_len=len(kb_log_content), @@ -1054,7 +1054,7 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = self.current_mem_cube + mem_cube = self.mem_cube if mem_cube is None: logger.warning( f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing" @@ -1284,7 +1284,7 @@ def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> Non def process_message(message: ScheduleMessageItem): try: - mem_cube = self.current_mem_cube + mem_cube = self.mem_cube if mem_cube is None: logger.warning( f"mem_cube is None for user_id={message.user_id}, mem_cube_id={message.mem_cube_id}, skipping processing" diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index 693816fd8..c3f5891ae 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -230,7 +230,7 @@ def update_search_memories_to_redis( memories: list[TextualMemoryItem] = self.search_memories( search_req=APISearchRequest(**content_dict["search_req"]), user_context=UserContext(**content_dict["user_context"]), - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, mode=SearchMode.FAST, ) formatted_memories = [format_textual_memory_item(data) for data in memories] diff --git a/src/memos/mem_scheduler/schemas/task_schemas.py b/src/memos/mem_scheduler/schemas/task_schemas.py index fb3a5931a..5439cf225 100644 --- a/src/memos/mem_scheduler/schemas/task_schemas.py +++ b/src/memos/mem_scheduler/schemas/task_schemas.py @@ -60,6 +60,16 @@ class TaskPriorityLevel(Enum): # Interval in seconds for batching and cleaning up deletions (xdel) DEFAULT_DELETE_CLEANUP_INTERVAL_SEC = 30.0 +# Inactivity threshold for stream deletion +# Delete streams whose last message ID timestamp is older than this threshold. +# Unit: seconds. 
Default: 1 day. +DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS = 86_400.0 + +# Recency threshold for active streams +# Consider a stream "active" if its last message is within this window. +# Unit: seconds. Default: 30 minutes. +DEFAULT_STREAM_RECENT_ACTIVE_SECONDS = 1_800.0 + # task queue DEFAULT_STREAM_KEY_PREFIX = os.getenv( diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index b9cab4ff8..36fe3c553 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -18,8 +18,10 @@ from memos.log import get_logger from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import ( + DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS, DEFAULT_STREAM_KEY_PREFIX, DEFAULT_STREAM_KEYS_REFRESH_INTERVAL_SEC, + DEFAULT_STREAM_RECENT_ACTIVE_SECONDS, ) from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker @@ -64,15 +66,17 @@ def __init__( max_len: Maximum length of the stream (for memory management) maxsize: Maximum size of the queue (for Queue compatibility, ignored) auto_delete_acked: Whether to automatically delete acknowledged messages from stream - status_tracker: TaskStatusTracker instance for tracking task status """ super().__init__() # Stream configuration self.stream_key_prefix = stream_key_prefix + # Precompile regex for prefix filtering to reduce repeated compilation overhead + self.stream_prefix_regex_pattern = re.compile(f"^{re.escape(self.stream_key_prefix)}:") self.consumer_group = consumer_group self.consumer_name = f"{consumer_name}_{uuid4().hex[:8]}" self.max_len = max_len self.auto_delete_acked = auto_delete_acked # Whether to delete acknowledged messages + self.status_tracker = status_tracker # Consumer state self._is_listening = False @@ -89,6 +93,10 @@ def __init__( self._refill_lock = threading.Lock() self._refill_thread: ContextThread | None = None + # Track empty streams first-seen time to avoid zombie keys + self._empty_stream_seen_times: dict[str, float] = {} + self._empty_stream_seen_lock = threading.Lock() + logger.info( f"[REDIS_QUEUE] Initialized with stream_prefix='{self.stream_key_prefix}', " f"consumer_group='{self.consumer_group}', consumer_name='{self.consumer_name}'" @@ -104,7 +112,6 @@ def __init__( self.message_pack_cache = deque() self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator - self.status_tracker = status_tracker # Cached stream keys and refresh control self._stream_keys_cache: list[str] = [] @@ -120,6 +127,11 @@ def __init__( os.getenv("MEMSCHEDULER_REDIS_INITIAL_SCAN_TIME_LIMIT_SEC", "1.0") or 1.0 ) + # Pipeline chunk size for XREVRANGE pipelined calls + self._pipeline_chunk_size = int( + os.getenv("MEMSCHEDULER_REDIS_PIPELINE_CHUNK_SIZE", "200") or 200 + ) + # Start background stream keys refresher if connected if self._is_connected: try: @@ -150,36 +162,53 @@ def _refresh_stream_keys( stream_key_prefix = self.stream_key_prefix try: - redis_pattern = f"{stream_key_prefix}:*" - collected: list[str] = [] - cursor: int | str = 0 - start_ts = time.time() if time_limit_sec else None - count_hint = 200 - while True: - if ( - start_ts is not None - and time_limit_sec is not None - and time.time() - start_ts > time_limit_sec - ): - break - cursor, keys = self._redis_conn.scan( - cursor=cursor, 
match=redis_pattern, count=count_hint + candidate_keys = self._scan_candidate_stream_keys( + stream_key_prefix=stream_key_prefix, + max_keys=max_keys, + time_limit_sec=time_limit_sec, + ) + chunked_results = self._pipeline_last_entries(candidate_keys) + # Only process successful chunks to maintain 1:1 key-result mapping + processed_keys: list[str] = [] + last_entries_results: list[list[tuple[str, dict]]] = [] + + total_key_count = 0 + for chunk_keys, chunk_res, success in chunked_results: + if success: + processed_keys.extend(chunk_keys) + last_entries_results.extend(chunk_res) + total_key_count += len(chunk_keys) + + # Abort refresh if any chunk failed, indicated by processed count mismatch + if len(candidate_keys) != total_key_count: + logger.error( + f"[REDIS_QUEUE] Last entries processed mismatch: " + f"candidates={len(candidate_keys)}, processed={len(processed_keys)}; aborting refresh" ) - collected.extend(keys) - if max_keys is not None and len(collected) >= max_keys: - break - if cursor == 0 or cursor == "0": - break - - escaped_prefix = re.escape(stream_key_prefix) - regex_pattern = f"^{escaped_prefix}:" - stream_keys = [key for key in collected if re.match(regex_pattern, key)] - - if stream_key_prefix == self.stream_key_prefix: - with self._stream_keys_lock: - self._stream_keys_cache = stream_keys - self._stream_keys_last_refresh = time.time() - return stream_keys + return [] + + now_sec = time.time() + keys_to_delete = self._collect_inactive_keys( + candidate_keys=processed_keys, + last_entries_results=last_entries_results, + inactivity_seconds=DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS, + now_sec=now_sec, + ) + active_stream_keys = self._filter_active_keys( + candidate_keys=processed_keys, + last_entries_results=last_entries_results, + recent_seconds=DEFAULT_STREAM_RECENT_ACTIVE_SECONDS, + now_sec=now_sec, + ) + deleted_count = self._delete_streams(keys_to_delete) + self._update_stream_cache_with_log( + stream_key_prefix=stream_key_prefix, + candidate_keys=processed_keys, + active_stream_keys=active_stream_keys, + deleted_count=deleted_count, + active_threshold_sec=DEFAULT_STREAM_RECENT_ACTIVE_SECONDS, + ) + return active_stream_keys except Exception as e: logger.warning(f"Failed to refresh stream keys: {e}") return [] @@ -384,11 +413,6 @@ def ack_message( try: self._redis_conn.xack(stream_key, self.consumer_group, redis_message_id) - - if message: - logger.info( - f"Message {message.item_id} | {message.label} | {message.content} has been acknowledged." - ) except Exception as e: logger.warning( f"xack failed for stream '{stream_key}', msg_id='{redis_message_id}': {e}" @@ -411,136 +435,159 @@ def get( if not self._redis_conn: raise ConnectionError("Not connected to Redis. 
Redis connection not available.") + redis_timeout = self._compute_redis_timeout(block=block, timeout=timeout) + + # Step 1: read new messages first + new_messages = self._read_new_messages( + stream_key=stream_key, batch_size=batch_size, redis_timeout=redis_timeout + ) + + # Step 2: determine how many pending messages we need + need_pending_count = self._compute_pending_need( + new_messages=new_messages, batch_size=batch_size + ) + + # Step 3: claim eligible pending messages + pending_messages: list[tuple[str, list[tuple[str, dict]]]] = [] + if need_pending_count: + task_label = stream_key.rsplit(":", 1)[1] + pending_messages = self._claim_pending_messages( + stream_key=stream_key, + need_pending_count=need_pending_count, + task_label=task_label, + ) + + # Step 4: assemble and convert to ScheduleMessageItem + messages = [] + if new_messages: + messages.extend(new_messages) + if pending_messages: + messages.extend(pending_messages) + + result_messages = self._convert_messages(messages) + + if not result_messages: + if not block: + return [] + else: + from queue import Empty + + raise Empty("No messages available in Redis queue") + + return result_messages + + def _compute_redis_timeout(self, block: bool, timeout: float | None) -> int | None: + """Compute Redis block timeout in milliseconds for xreadgroup.""" + if block and timeout is not None: + return int(timeout * 1000) + return None + + def _read_new_messages( + self, stream_key: str, batch_size: int | None, redis_timeout: int | None + ) -> list[tuple[str, list[tuple[str, dict]]]]: + """Read new messages for the consumer group, handling missing group/stream.""" try: - # Calculate timeout for Redis - redis_timeout = None - if block and timeout is not None: - redis_timeout = int(timeout * 1000) - elif not block: - redis_timeout = None # Non-blocking - - # Read messages from the consumer group - # 1) Read remaining/new messages first (not yet delivered to any consumer) - new_messages: list[tuple[str, list[tuple[str, dict]]]] = [] - try: - new_messages = self._redis_conn.xreadgroup( + return self._redis_conn.xreadgroup( + self.consumer_group, + self.consumer_name, + {stream_key: ">"}, + count=batch_size, + block=redis_timeout, + ) + except Exception as read_err: + err_msg = str(read_err).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + logger.warning( + f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (new)." + ) + self._ensure_consumer_group(stream_key=stream_key) + return self._redis_conn.xreadgroup( self.consumer_group, self.consumer_name, {stream_key: ">"}, count=batch_size, block=redis_timeout, ) - except Exception as read_err: - # Handle missing group/stream by creating and retrying once - err_msg = str(read_err).lower() - if "nogroup" in err_msg or "no such key" in err_msg: - logger.warning( - f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (new)." 
- ) - self._ensure_consumer_group(stream_key=stream_key) - new_messages = self._redis_conn.xreadgroup( - self.consumer_group, - self.consumer_name, - {stream_key: ">"}, - count=batch_size, - block=redis_timeout, - ) - else: - raise - - # 2) If needed, read pending messages for THIS consumer only - pending_messages: list[tuple[str, list[tuple[str, dict]]]] = [] - need_pending_count = None - if batch_size is None: - # No batch_size: prefer returning a single new message; if none, fetch one pending - if not new_messages: - need_pending_count = 1 - else: - # With batch_size: fill from pending if new insufficient - new_count = sum(len(sm) for _s, sm in new_messages) if new_messages else 0 - need_pending = max(0, batch_size - new_count) - need_pending_count = need_pending if need_pending > 0 else 0 - - task_label = stream_key.rsplit(":", 1)[1] - if need_pending_count: - # Claim only pending messages whose idle time exceeds configured threshold - try: - # Ensure group exists before claiming - self._ensure_consumer_group(stream_key=stream_key) - # XAUTOCLAIM returns (next_start_id, [(id, fields), ...]) - next_id, claimed = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - # Derive task_label from stream_key suffix: {prefix}:{user_id}:{mem_cube_id}:{task_label} - min_idle_time=self.orchestrator.get_task_idle_min(task_label=task_label), - start_id="0-0", - count=need_pending_count, - justid=False, - ) - pending_messages = [(stream_key, claimed)] if claimed else [] - except Exception as read_err: - # Handle missing group/stream by creating and retrying once - err_msg = str(read_err).lower() - if "nogroup" in err_msg or "no such key" in err_msg: - logger.warning( - f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (xautoclaim)." 
- ) - self._ensure_consumer_group(stream_key=stream_key) - next_id, claimed = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - min_idle_time=self.orchestrator.get_task_idle_min( - task_label=task_label - ), - start_id="0-0", - count=need_pending_count, - justid=False, - ) - pending_messages = [(stream_key, claimed)] if claimed else [] - else: - pending_messages = [] - - # Combine: new first, then pending - messages = [] - if new_messages: - messages.extend(new_messages) - if pending_messages: - messages.extend(pending_messages) - - result_messages = [] - for _stream, stream_messages in messages: - for message_id, fields in stream_messages: - try: - # Convert Redis message back to SchedulerMessageItem - message = ScheduleMessageItem.from_dict(fields) - # Preserve stream key and redis message id for monitoring/ack - message.stream_key = _stream - message.redis_message_id = message_id - - result_messages.append(message) - - except Exception as e: - logger.error(f"Failed to parse message {message_id}: {e}", stack_info=True) + logger.error(f"{read_err}", stack_info=True) + raise - # Always return a list for consistency - if not result_messages: - if not block: - return [] # Return empty list for non-blocking calls + def _compute_pending_need( + self, new_messages: list[tuple[str, list[tuple[str, dict]]]] | None, batch_size: int | None + ) -> int: + """Compute how many pending messages are needed to fill the batch.""" + if batch_size is None: + return 1 if not new_messages else 0 + new_count = sum(len(sm) for _s, sm in new_messages) if new_messages else 0 + need_pending = max(0, batch_size - new_count) + return need_pending if need_pending > 0 else 0 + + def _claim_pending_messages( + self, stream_key: str, need_pending_count: int, task_label: str + ) -> list[tuple[str, list[tuple[str, dict]]]]: + """Claim pending messages exceeding idle threshold, with group existence handling.""" + try: + claimed_result = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=task_label), + start_id="0-0", + count=need_pending_count, + justid=False, + ) + if len(claimed_result) == 2: + next_id, claimed = claimed_result + deleted_ids = [] + elif len(claimed_result) == 3: + next_id, claimed, deleted_ids = claimed_result + else: + raise ValueError(f"Unexpected xautoclaim response length: {len(claimed_result)}") + + return [(stream_key, claimed)] if claimed else [] + except Exception as read_err: + err_msg = str(read_err).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + logger.warning( + f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (xautoclaim)." 
+ ) + self._ensure_consumer_group(stream_key=stream_key) + claimed_result = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=task_label), + start_id="0-0", + count=need_pending_count, + justid=False, + ) + if len(claimed_result) == 2: + next_id, claimed = claimed_result + deleted_ids = [] + elif len(claimed_result) == 3: + next_id, claimed, deleted_ids = claimed_result else: - # If no messages were found, raise Empty exception - from queue import Empty - - raise Empty("No messages available in Redis queue") + raise ValueError( + f"Unexpected xautoclaim response length: {len(claimed_result)}" + ) from read_err - return result_messages + return [(stream_key, claimed)] if claimed else [] + return [] - except Exception as e: - if "Empty" in str(type(e).__name__): - raise - logger.error(f"Failed to get message from Redis queue: {e}") - raise + def _convert_messages( + self, messages: list[tuple[str, list[tuple[str, dict]]]] + ) -> list[ScheduleMessageItem]: + """Convert raw Redis messages into ScheduleMessageItem with metadata.""" + result: list[ScheduleMessageItem] = [] + for _stream, stream_messages in messages or []: + for message_id, fields in stream_messages: + try: + message = ScheduleMessageItem.from_dict(fields) + message.stream_key = _stream + message.redis_message_id = message_id + result.append(message) + except Exception as e: + logger.error(f"Failed to parse message {message_id}: {e}", stack_info=True) + return result def qsize(self) -> dict: """ @@ -742,3 +789,187 @@ def __del__(self): @property def unfinished_tasks(self) -> int: return self.qsize() + + def _scan_candidate_stream_keys( + self, + stream_key_prefix: str, + max_keys: int | None = None, + time_limit_sec: float | None = None, + count_hint: int = 200, + ) -> list[str]: + """Return stream keys matching the given prefix via SCAN with optional limits. + + Uses a cursor-based SCAN to collect keys matching the prefix, honoring + optional `max_keys` and `time_limit_sec` constraints. Filters results + with a precompiled regex when scanning the configured prefix. + """ + redis_pattern = f"{stream_key_prefix}:*" + collected = [] + cursor = 0 + start_ts = time.time() if time_limit_sec else None + while True: + if ( + start_ts is not None + and time_limit_sec is not None + and (time.time() - start_ts) > time_limit_sec + ): + break + cursor, keys = self._redis_conn.scan( + cursor=cursor, match=redis_pattern, count=count_hint + ) + collected.extend(keys) + if max_keys is not None and len(collected) >= max_keys: + break + if cursor == 0 or cursor == "0": + break + + if stream_key_prefix == self.stream_key_prefix: + pattern = self.stream_prefix_regex_pattern + else: + escaped_prefix = re.escape(stream_key_prefix) + pattern = re.compile(f"^{escaped_prefix}:") + return [key for key in collected if pattern.match(key)] + + def _pipeline_last_entries( + self, candidate_keys: list[str] + ) -> list[tuple[list[str], list[list[tuple[str, dict]]], bool]]: + """Fetch last entries for keys using pipelined XREVRANGE COUNT 1, per-chunk success. + + Returns a list of tuples: (chunk_keys, chunk_results, success_bool). + Only successful chunks should be processed by the caller to preserve + a 1:1 mapping between keys and results. 
+ """ + if not candidate_keys: + return [] + + results_chunks: list[tuple[list[str], list[list[tuple[str, dict]]], bool]] = [] + chunk_size = max(1, int(self._pipeline_chunk_size)) + + for start in range(0, len(candidate_keys), chunk_size): + chunk_keys = candidate_keys[start : start + chunk_size] + try: + pipe = self._redis_conn.pipeline(transaction=False) + for key in chunk_keys: + pipe.xrevrange(key, count=1) + chunk_res = pipe.execute() + results_chunks.append((chunk_keys, chunk_res, True)) + except Exception as e: + logger.warning( + f"[REDIS_QUEUE] Pipeline execute failed for last entries chunk: " + f"offset={start}, size={len(chunk_keys)}, error={e}" + ) + results_chunks.append((chunk_keys, [], False)) + + return results_chunks + + def _parse_last_ms_from_entries(self, entries: list[tuple[str, dict]]) -> int | None: + """Parse millisecond timestamp from the last entry ID.""" + if not entries: + return None + try: + last_id = entries[0][0] + return int(str(last_id).split("-")[0]) + except Exception: + return None + + def _collect_inactive_keys( + self, + candidate_keys: list[str], + last_entries_results: list[list[tuple[str, dict]]], + inactivity_seconds: float, + now_sec: float | None = None, + ) -> list[str]: + """Collect keys whose last entry time is older than inactivity threshold.""" + keys_to_delete: list[str] = [] + now = time.time() if now_sec is None else now_sec + for key, entries in zip(candidate_keys, last_entries_results or [], strict=False): + last_ms = self._parse_last_ms_from_entries(entries) + if last_ms is None: + # Empty stream (no entries). Track first-seen time and delete if past threshold + with self._empty_stream_seen_lock: + first_seen = self._empty_stream_seen_times.get(key) + if first_seen is None: + # Record when we first observed this empty stream + self._empty_stream_seen_times[key] = now + else: + if (now - first_seen) > inactivity_seconds: + keys_to_delete.append(key) + continue + # Stream has entries; clear any empty-tracking state + with self._empty_stream_seen_lock: + if key in self._empty_stream_seen_times: + self._empty_stream_seen_times.pop(key, None) + if (now - (last_ms / 1000.0)) > inactivity_seconds: + keys_to_delete.append(key) + return keys_to_delete + + def _filter_active_keys( + self, + candidate_keys: list[str], + last_entries_results: list[list[tuple[str, dict]]], + recent_seconds: float, + now_sec: float | None = None, + ) -> list[str]: + """Return keys whose last entry time is within the recent window.""" + active: list[str] = [] + now = time.time() if now_sec is None else now_sec + for key, entries in zip(candidate_keys, last_entries_results or [], strict=False): + last_ms = self._parse_last_ms_from_entries(entries) + if last_ms is None: + continue + # Stream has entries; clear any empty-tracking state + with self._empty_stream_seen_lock: + if key in self._empty_stream_seen_times: + self._empty_stream_seen_times.pop(key, None) + # Active if last message is no older than recent_seconds + if (now - (last_ms / 1000.0)) <= recent_seconds: + active.append(key) + return active + + def _delete_streams(self, keys_to_delete: list[str]) -> int: + """Delete the given stream keys in batch, return deleted count.""" + if not keys_to_delete: + return 0 + deleted_count = 0 + try: + del_pipe = self._redis_conn.pipeline(transaction=False) + for key in keys_to_delete: + del_pipe.delete(key) + del_pipe.execute() + deleted_count = len(keys_to_delete) + # Clean up empty-tracking state for deleted keys + with self._empty_stream_seen_lock: + for key in 
keys_to_delete:
+                    self._empty_stream_seen_times.pop(key, None)
+        except Exception:
+            for key in keys_to_delete:
+                try:
+                    self._redis_conn.delete(key)
+                    deleted_count += 1
+                    with self._empty_stream_seen_lock:
+                        self._empty_stream_seen_times.pop(key, None)
+                except Exception:
+                    pass
+        return deleted_count
+
+    def _update_stream_cache_with_log(
+        self,
+        stream_key_prefix: str,
+        candidate_keys: list[str],
+        active_stream_keys: list[str],
+        deleted_count: int,
+        active_threshold_sec: float,
+    ) -> None:
+        """Update cache and emit an info log summarizing refresh statistics."""
+        if stream_key_prefix != self.stream_key_prefix:
+            return
+        with self._stream_keys_lock:
+            self._stream_keys_cache = active_stream_keys
+            self._stream_keys_last_refresh = time.time()
+        cache_count = len(self._stream_keys_cache)
+        logger.info(
+            f"[REDIS_QUEUE] Stream keys refresh: prefix='{stream_key_prefix}', "
+            f"total={len(candidate_keys)}, active={len(active_stream_keys)}, cached={cache_count}, "
+            f"active_threshold_sec={int(active_threshold_sec)}, deleted={deleted_count}, "
+            f"inactive_threshold_sec={int(DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS)}"
+        )
diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py
index ffe6db2d0..8f9810cf1 100644
--- a/src/memos/templates/mem_reader_prompts.py
+++ b/src/memos/templates/mem_reader_prompts.py
@@ -420,36 +420,41 @@
 SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """
-You are a precise memory consistency auditor.
+You are a strict memory validator.
 
-# GOAL
-Given user messages and an extracted memory list, identify and fix inconsistencies for each memory.
+# TASK
+Validate each memory entry against the user's current messages (ground truth).
+Memories that hallucinate unsupported facts or contradict the user must be corrected or marked for deletion.
 
 # RULES
-- Use ONLY information present in the user messages; do not invent.
-- Preserve explicit facts: names, timestamps, quantities, locations.
-- For each memory, keep the language identical to that memory's original language.
-- Output only JSON. No extra commentary.
+- Use ONLY facts explicitly stated in the user messages.
+- Do NOT invent, assume, or retain unsupported specifics.
+- Preserve the original language of each memory when rewriting.
+- Output ONLY a JSON object with no extra text.
 
 # INPUTS
-User messages:
+User messages (ground truth):
 {user_messages_inline}
 
-Current memory list (JSON):
+Memory list (to validate, in indexed JSON format):
 {memories_inline}
 
 # OUTPUT FORMAT
-Return a JSON object where keys are the 0-based indices of the input memories (string keys allowed), and each value is an object:
-{
-  "0": {"delete_flag": false, "rewritten memory content": "..."},
-  "1": {"delete_flag": true, "rewritten memory content": ""},
-  "2": {"delete_flag": false, "rewritten memory content": "..."}
-}
+Return a JSON object where:
+- Keys are the same stringified indices as in the input memory list (e.g., "0", "1").
+- Each value is: {{"delete": boolean, "rewritten": string, "reason": string}}
+- If "delete" is true, "rewritten" must be an empty string.
+- "reason" must briefly explain the decision (delete or rewrite) based on user messages.
+- The number of output entries MUST exactly match the number of input memories.
+
+# DECISION GUIDE
+- Contradicted? → rewrite to match user message, "delete"=false, "rewritten"=corrected memory content.
+- Hallucinated (specific fact not in user messages)? → "delete"=true, "rewritten"="" (empty string).
+- Consistent or non-factual (opinion, emotion)? → keep as-is, "delete"=false, "rewritten"=the original memory content.
+
+Additionally, include a concise "reason" for each item explaining your decision.
 
-Notes:
-- If a memory is entirely hallucinated or contradicted by user messages, set `if_delete` to true and leave `rewritten memory content` empty.
-- If a memory conflicts but can be corrected, set `if_delete` to false and provide the corrected content in `"rewritten memory content"` using the memory's original language.
-- If a memory is valid, set `if_delete` to false and return the original content.
+Final Output:
 """
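
Note on the new hallucination-filter contract: the prompt above asks the model for an index-keyed JSON verdict, and `_parse_hallucination_filter_response` / `filter_hallucination_in_memories` consume it by index. The following is a minimal, self-contained sketch of that keep/rewrite/delete flow; the sample response and the `Memory` stand-in class are illustrative assumptions, not the repository's actual types or API.

import json
from dataclasses import dataclass


@dataclass
class Memory:
    # Stand-in for TextualMemoryItem; only the text field matters for this sketch.
    memory: str


def apply_filter_verdicts(memories: list[Memory], raw_response: str) -> list[Memory]:
    """Keep, rewrite, or drop memories based on an index-keyed LLM verdict."""
    parsed = json.loads(raw_response)
    keep = [True] * len(memories)
    for key, verdict in parsed.items():
        idx = int(key)
        if idx < 0 or idx >= len(memories):
            continue  # ignore out-of-range indices, mirroring the bounds check above
        if verdict.get("delete") is True:
            keep[idx] = False
        elif isinstance(verdict.get("rewritten"), str) and verdict["rewritten"]:
            memories[idx].memory = verdict["rewritten"]
    return [m for i, m in enumerate(memories) if keep[i]]


if __name__ == "__main__":
    mems = [Memory("User lives in Paris."), Memory("User owns three cars.")]
    # Example verdict in the shape the prompt requests: keep the consistent entry
    # with its original content, delete the hallucinated one.
    response = json.dumps(
        {
            "0": {"delete": False, "rewritten": "User lives in Paris.", "reason": "Stated by user."},
            "1": {"delete": True, "rewritten": "", "reason": "Not mentioned in user messages."},
        }
    )
    print([m.memory for m in apply_filter_verdicts(mems, response)])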
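
Note on the stream-cleanup thresholds in task_schemas.py and redis_queue.py: Redis stream entry IDs have the form "<millisecond-timestamp>-<sequence>", which is what `_parse_last_ms_from_entries` relies on when deciding whether a stream is recently active or stale enough to delete. The sketch below illustrates that classification against the two new constants; the constant values are copied from the diff, while the helper name and three-way labels are hypothetical, not part of the queue's API.

import time

# Values mirrored from task_schemas.py
DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS = 86_400.0  # 1 day
DEFAULT_STREAM_RECENT_ACTIVE_SECONDS = 1_800.0       # 30 minutes


def classify_stream(last_entry_id: str | None, now_sec: float | None = None) -> str:
    """Classify a stream as 'active', 'idle', or 'delete' from its last entry ID."""
    now = time.time() if now_sec is None else now_sec
    if last_entry_id is None:
        # Empty stream: the real queue tracks a first-seen time before deleting it.
        return "idle"
    last_ms = int(str(last_entry_id).split("-")[0])  # IDs look like "1717000000000-0"
    age_sec = now - last_ms / 1000.0
    if age_sec > DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS:
        return "delete"
    if age_sec <= DEFAULT_STREAM_RECENT_ACTIVE_SECONDS:
        return "active"
    return "idle"


if __name__ == "__main__":
    now = time.time()
    recent_id = f"{int((now - 60) * 1000)}-0"         # 1 minute old -> active
    stale_id = f"{int((now - 2 * 86_400) * 1000)}-0"  # 2 days old -> delete
    print(classify_stream(recent_id), classify_stream(stale_id))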