diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index c2dc4a629..64aedc8f4 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -1707,124 +1707,116 @@ def delete_node_by_prams(
     ) -> int:
         """
         Delete nodes by memory_ids, file_ids, or filter.
+        Supports three scenarios:
+        1. Delete by memory_ids (standalone)
+        2. Delete by writable_cube_ids + file_ids (combined)
+        3. Delete by filter (standalone, no writable_cube_ids needed)
 
         Args:
             writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
-                If not provided, no user_name filter will be applied.
+                Only used with file_ids scenario. If not provided, no user_name filter will be applied.
             memory_ids (list[str], optional): List of memory node IDs to delete.
-            file_ids (list[str], optional): List of file node IDs to delete.
-            filter (dict, optional): Filter dictionary to query matching nodes for deletion.
+            file_ids (list[str], optional): List of file node IDs to delete. Must be used with writable_cube_ids.
+            filter (dict, optional): Filter dictionary for metadata filtering.
+                Filter conditions are directly used in DELETE WHERE clause without pre-querying.
+                Does not require writable_cube_ids.
 
         Returns:
             int: Number of nodes deleted.
         """
+        batch_start_time = time.time()
         logger.info(
             f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
         )
-        print(
-            f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
-        )
-
-        # Build WHERE conditions separately for memory_ids and file_ids
-        where_clauses = []
-        params = {}
 
         # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
-        # Only add user_name filter if writable_cube_ids is provided
+        # Only add user_name filter if writable_cube_ids is provided (for file_ids scenario)
         user_name_conditions = []
+        params = {}
         if writable_cube_ids and len(writable_cube_ids) > 0:
             for idx, cube_id in enumerate(writable_cube_ids):
                 param_name = f"cube_id_{idx}"
                 user_name_conditions.append(f"n.user_name = ${param_name}")
                 params[param_name] = cube_id
 
-        # Handle memory_ids: query n.id
-        if memory_ids and len(memory_ids) > 0:
+        # Build filter conditions using common method (no query, direct use in WHERE clause)
+        filter_conditions = []
+        filter_params = {}
+        if filter:
+            filter_conditions, filter_params = self._build_filter_conditions_cypher(
+                filter, param_counter_start=0, node_alias="n"
+            )
+            logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}")
+            params.update(filter_params)
+
+        # If no conditions to delete, return 0
+        if not memory_ids and not file_ids and not filter_conditions:
+            logger.warning(
+                "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
+            )
+            return 0
+
+        # Build WHERE conditions list
+        where_clauses = []
+
+        # Scenario 1: memory_ids (standalone)
+        if memory_ids:
+            logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids")
             where_clauses.append("n.id IN $memory_ids")
             params["memory_ids"] = memory_ids
 
-        # Handle file_ids: query n.file_ids field
-        # All file_ids must be present in the array field (AND relationship)
-        if file_ids and len(file_ids) > 0:
-            file_id_and_conditions = []
+        # Scenario 2: file_ids + writable_cube_ids (combined)
+        if file_ids:
+            logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids")
+            file_id_conditions = []
             for idx, file_id in enumerate(file_ids):
                 param_name = f"file_id_{idx}"
                 params[param_name] = file_id
                 # Check if this file_id is in the file_ids array field
-                file_id_and_conditions.append(f"${param_name} IN n.file_ids")
-            if file_id_and_conditions:
-                # Use AND to require all file_ids to be present
-                where_clauses.append(f"({' OR '.join(file_id_and_conditions)})")
-
-        # Query nodes by filter if provided
-        filter_ids = []
-        if filter:
-            # Use get_by_metadata with empty filters list and filter
-            filter_ids = self.get_by_metadata(
-                filters=[],
-                user_name=None,
-                filter=filter,
-                knowledgebase_ids=writable_cube_ids if writable_cube_ids else None,
-            )
+                file_id_conditions.append(f"${param_name} IN n.file_ids")
+            if file_id_conditions:
+                where_clauses.append(f"({' OR '.join(file_id_conditions)})")
 
-        # If filter returned IDs, add condition for them
-        if filter_ids:
-            where_clauses.append("n.id IN $filter_ids")
-            params["filter_ids"] = filter_ids
+        # Scenario 3: filter (standalone, no writable_cube_ids needed)
+        if filter_conditions:
+            logger.info("[delete_node_by_prams] Processing filter conditions")
+            # Combine filter conditions with AND
+            filter_where = " AND ".join(filter_conditions)
+            where_clauses.append(f"({filter_where})")
 
-        # If no conditions (except user_name), return 0
+        # Build final WHERE clause
         if not where_clauses:
-            logger.warning(
-                "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
-            )
+            logger.warning("[delete_node_by_prams] No WHERE conditions to delete")
             return 0
 
-        # Build WHERE clause
-        # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
-        data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])
+        # Combine all conditions with AND
+        data_conditions = " AND ".join([f"({clause})" for clause in where_clauses])
 
-        # Build final WHERE clause
-        # If user_name_conditions exist, combine with data_conditions using AND
-        # Otherwise, use only data_conditions
+        # Add user_name filter if provided (for file_ids scenario)
         if user_name_conditions:
             user_name_where = " OR ".join(user_name_conditions)
-            ids_where = f"({user_name_where}) AND ({data_conditions})"
+            final_where = f"({user_name_where}) AND ({data_conditions})"
         else:
-            ids_where = f"({data_conditions})"
-
-        logger.info(
-            f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
-        )
-        print(
-            f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
-        )
+            final_where = data_conditions
 
-        # First count matching nodes to get accurate count
-        count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count"
-        logger.info(f"[delete_node_by_prams] count_query: {count_query}")
-        print(f"[delete_node_by_prams] count_query: {count_query}")
-
-        # Then delete nodes
-        delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n"
+        # Delete directly without pre-counting
+        delete_query = f"MATCH (n:Memory) WHERE {final_where} DETACH DELETE n"
         logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
-        print(f"[delete_node_by_prams] delete_query: {delete_query}")
-        print(f"[delete_node_by_prams] params: {params}")
 
         deleted_count = 0
         try:
            with self.driver.session(database=self.db_name) as session:
-                # Count nodes before deletion
-                count_result = session.run(count_query, **params)
-                count_record = count_result.single()
-                expected_count = 0
-                if count_record:
-                    expected_count = count_record["node_count"] or 0
-
-                # Delete nodes
-                session.run(delete_query, **params)
-                # Use the count from before deletion as the actual deleted count
-                deleted_count = expected_count
-
+                # Execute delete query
+                result = session.run(delete_query, **params)
+                # Consume the result to ensure deletion completes and get the summary
+                summary = result.consume()
+                # Get the count from the result summary
+                deleted_count = summary.counters.nodes_deleted if summary.counters else 0
+
+            elapsed_time = time.time() - batch_start_time
+            logger.info(
+                f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {deleted_count} nodes"
+            )
         except Exception as e:
             logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
             raise
@@ -1884,3 +1876,40 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str |
                 f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
             )
             raise
+
+    def exist_user_name(self, user_name: str) -> dict[str, bool]:
+        """Check if user name exists in the graph.
+
+        Args:
+            user_name: User name to check.
+
+        Returns:
+            dict[str, bool]: Dictionary with user_name as key and bool as value indicating existence.
+        """
+        logger.info(f"[exist_user_name] Querying user_name {user_name}")
+        if not user_name:
+            return {user_name: False}
+
+        try:
+            with self.driver.session(database=self.db_name) as session:
+                # Query to check if user_name exists
+                query = """
+                MATCH (n:Memory)
+                WHERE n.user_name = $user_name
+                RETURN COUNT(n) AS count
+                """
+                logger.info(f"[exist_user_name] query: {query}")
+
+                result = session.run(query, user_name=user_name)
+                count = result.single()["count"]
+                result_dict = {user_name: count > 0}
+
+                logger.info(
+                    f"[exist_user_name] user_name {user_name} exists: {result_dict[user_name]}"
+                )
+                return result_dict
+        except Exception as e:
+            logger.error(
+                f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True
+            )
+            raise
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index b0a8bc4be..d1c2716c8 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -5316,3 +5316,52 @@ def escape_memory_id(mid: str) -> str:
             raise
         finally:
             self._return_connection(conn)
+
+    def exist_user_name(self, user_name: str) -> dict[str, bool]:
+        """Check if user name exists in the graph.
+
+        Args:
+            user_name: User name to check.
+
+        Returns:
+            dict[str, bool]: Dictionary with user_name as key and bool as value indicating existence.
+ """ + logger.info(f"[exist_user_name] Querying user_name {user_name}") + if not user_name: + return {user_name: False} + + # Escape special characters for JSON string format in agtype + def escape_user_name(un: str) -> str: + """Escape special characters in user_name for JSON string format.""" + # Escape backslashes first, then double quotes + un_str = un.replace("\\", "\\\\") + un_str = un_str.replace('"', '\\"') + return un_str + + # Escape special characters + escaped_un = escape_user_name(user_name) + + # Query to check if user_name exists + query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{escaped_un}\"'::agtype + """ + logger.info(f"[exist_user_name] query: {query}") + result_dict = {} + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + count = cursor.fetchone()[0] + result = count > 0 + result_dict[user_name] = result + return result_dict + except Exception as e: + logger.error( + f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True + ) + raise + finally: + self._return_connection(conn)