Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 105 additions & 76 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,124 +1707,116 @@ def delete_node_by_prams(
) -> int:
"""
Delete nodes by memory_ids, file_ids, or filter.
Supports three scenarios:
1. Delete by memory_ids (standalone)
2. Delete by writable_cube_ids + file_ids (combined)
3. Delete by filter (standalone, no writable_cube_ids needed)

Args:
writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
If not provided, no user_name filter will be applied.
Only used with file_ids scenario. If not provided, no user_name filter will be applied.
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
filter (dict, optional): Filter dictionary to query matching nodes for deletion.
file_ids (list[str], optional): List of file node IDs to delete. Must be used with writable_cube_ids.
filter (dict, optional): Filter dictionary for metadata filtering.
Filter conditions are directly used in DELETE WHERE clause without pre-querying.
Does not require writable_cube_ids.

Returns:
int: Number of nodes deleted.
"""
batch_start_time = time.time()
logger.info(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)
print(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)

# Build WHERE conditions separately for memory_ids and file_ids
where_clauses = []
params = {}

# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
# Only add user_name filter if writable_cube_ids is provided
# Only add user_name filter if writable_cube_ids is provided (for file_ids scenario)
user_name_conditions = []
params = {}
if writable_cube_ids and len(writable_cube_ids) > 0:
for idx, cube_id in enumerate(writable_cube_ids):
param_name = f"cube_id_{idx}"
user_name_conditions.append(f"n.user_name = ${param_name}")
params[param_name] = cube_id

# Handle memory_ids: query n.id
if memory_ids and len(memory_ids) > 0:
# Build filter conditions using common method (no query, direct use in WHERE clause)
filter_conditions = []
filter_params = {}
if filter:
filter_conditions, filter_params = self._build_filter_conditions_cypher(
filter, param_counter_start=0, node_alias="n"
)
logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}")
params.update(filter_params)

# If no conditions to delete, return 0
if not memory_ids and not file_ids and not filter_conditions:
logger.warning(
"[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
)
return 0

# Build WHERE conditions list
where_clauses = []

# Scenario 1: memory_ids (standalone)
if memory_ids:
logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids")
where_clauses.append("n.id IN $memory_ids")
params["memory_ids"] = memory_ids

# Handle file_ids: query n.file_ids field
# All file_ids must be present in the array field (AND relationship)
if file_ids and len(file_ids) > 0:
file_id_and_conditions = []
# Scenario 2: file_ids + writable_cube_ids (combined)
if file_ids:
logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids")
file_id_conditions = []
for idx, file_id in enumerate(file_ids):
param_name = f"file_id_{idx}"
params[param_name] = file_id
# Check if this file_id is in the file_ids array field
file_id_and_conditions.append(f"${param_name} IN n.file_ids")
if file_id_and_conditions:
# Use AND to require all file_ids to be present
where_clauses.append(f"({' OR '.join(file_id_and_conditions)})")

# Query nodes by filter if provided
filter_ids = []
if filter:
# Use get_by_metadata with empty filters list and filter
filter_ids = self.get_by_metadata(
filters=[],
user_name=None,
filter=filter,
knowledgebase_ids=writable_cube_ids if writable_cube_ids else None,
)
file_id_conditions.append(f"${param_name} IN n.file_ids")
if file_id_conditions:
where_clauses.append(f"({' OR '.join(file_id_conditions)})")

# If filter returned IDs, add condition for them
if filter_ids:
where_clauses.append("n.id IN $filter_ids")
params["filter_ids"] = filter_ids
# Scenario 3: filter (standalone, no writable_cube_ids needed)
if filter_conditions:
logger.info("[delete_node_by_prams] Processing filter conditions")
# Combine filter conditions with AND
filter_where = " AND ".join(filter_conditions)
where_clauses.append(f"({filter_where})")

# If no conditions (except user_name), return 0
# Build final WHERE clause
if not where_clauses:
logger.warning(
"[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
)
logger.warning("[delete_node_by_prams] No WHERE conditions to delete")
return 0

# Build WHERE clause
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])
# Combine all conditions with AND
data_conditions = " AND ".join([f"({clause})" for clause in where_clauses])

# Build final WHERE clause
# If user_name_conditions exist, combine with data_conditions using AND
# Otherwise, use only data_conditions
# Add user_name filter if provided (for file_ids scenario)
if user_name_conditions:
user_name_where = " OR ".join(user_name_conditions)
ids_where = f"({user_name_where}) AND ({data_conditions})"
final_where = f"({user_name_where}) AND ({data_conditions})"
else:
ids_where = f"({data_conditions})"

logger.info(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)
print(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)
final_where = data_conditions

# First count matching nodes to get accurate count
count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count"
logger.info(f"[delete_node_by_prams] count_query: {count_query}")
print(f"[delete_node_by_prams] count_query: {count_query}")

# Then delete nodes
delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n"
# Delete directly without pre-counting
delete_query = f"MATCH (n:Memory) WHERE {final_where} DETACH DELETE n"
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
print(f"[delete_node_by_prams] delete_query: {delete_query}")
print(f"[delete_node_by_prams] params: {params}")

deleted_count = 0
try:
with self.driver.session(database=self.db_name) as session:
# Count nodes before deletion
count_result = session.run(count_query, **params)
count_record = count_result.single()
expected_count = 0
if count_record:
expected_count = count_record["node_count"] or 0

# Delete nodes
session.run(delete_query, **params)
# Use the count from before deletion as the actual deleted count
deleted_count = expected_count

# Execute delete query
result = session.run(delete_query, **params)
# Consume the result to ensure deletion completes and get the summary
summary = result.consume()
# Get the count from the result summary
deleted_count = summary.counters.nodes_deleted if summary.counters else 0

elapsed_time = time.time() - batch_start_time
logger.info(
f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {deleted_count} nodes"
)
except Exception as e:
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
raise
Expand Down Expand Up @@ -1884,3 +1876,40 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str |
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
)
raise

def exist_user_name(self, user_name: str) -> dict[str, bool]:
    """Check whether any Memory node belongs to the given user.

    Args:
        user_name: User name to check.

    Returns:
        dict[str, bool]: Maps ``user_name`` to ``True`` when at least one
        Memory node with that ``user_name`` exists, ``False`` otherwise.
    """
    logger.info(f"[exist_user_name] Querying user_name {user_name}")
    # An empty/falsy user_name can never match a node; short-circuit.
    if not user_name:
        return {user_name: False}

    try:
        with self.driver.session(database=self.db_name) as session:
            # Count Memory nodes carrying this user_name.
            query = """
            MATCH (n:Memory)
            WHERE n.user_name = $user_name
            RETURN COUNT(n) AS count
            """
            logger.info(f"[exist_user_name] query: {query}")

            record = session.run(query, user_name=user_name).single()
            exists = record["count"] > 0
            result_dict = {user_name: exists}

            logger.info(
                f"[exist_user_name] user_name {user_name} exists: {result_dict[user_name]}"
            )
            return result_dict
    except Exception as e:
        logger.error(
            f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True
        )
        raise
49 changes: 49 additions & 0 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5316,3 +5316,52 @@ def escape_memory_id(mid: str) -> str:
raise
finally:
self._return_connection(conn)

def exist_user_name(self, user_name: str) -> dict[str, bool]:
    """Check if user name exists in the graph.

    Args:
        user_name: User name to check.

    Returns:
        dict[str, bool]: Dictionary with user_name as key and bool as value
        indicating existence.

    Raises:
        Exception: Re-raises any database error after logging it.
    """
    import json

    logger.info(f"[exist_user_name] Querying user_name {user_name}")
    # An empty/falsy user_name can never match a row; short-circuit.
    if not user_name:
        return {user_name: False}

    # SECURITY: bind user_name as a query parameter instead of interpolating
    # it into the SQL text. The previous f-string version escaped only
    # backslashes and double quotes, so a single quote in user_name could
    # terminate the SQL string literal (SQL injection). json.dumps produces
    # a double-quoted, fully escaped JSON string, which agtype accepts as a
    # string literal when cast.
    query = f"""
    SELECT COUNT(*)
    FROM "{self.db_name}_graph"."Memory"
    WHERE ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype
    """
    logger.info(f"[exist_user_name] query: {query}")
    conn = None
    try:
        conn = self._get_connection()
        with conn.cursor() as cursor:
            cursor.execute(query, (json.dumps(user_name),))
            count = cursor.fetchone()[0]
            return {user_name: count > 0}
    except Exception as e:
        logger.error(
            f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True
        )
        raise
    finally:
        # Only return a connection that was actually acquired; the original
        # passed None to _return_connection when _get_connection failed.
        if conn is not None:
            self._return_connection(conn)