Commit cca4745

feat: optimize search_by_embedding && logs (#1284)
* fix: optimize search_by_embedding
* fix: optimize search_by_embedding
* feat: optimize get_grouped_counts
* feat: optimize get_grouped_counts
* feat: optimize get_by_metadata
* fix: remove self._refresh_memory_size
1 parent 48f7320 commit cca4745

3 files changed

Lines changed: 54 additions & 81 deletions

src/memos/graph_dbs/polardb.py

Lines changed: 53 additions & 79 deletions
@@ -432,14 +432,13 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
     def remove_oldest_memory(
         self, memory_type: str, keep_latest: int, user_name: str | None = None
     ) -> None:
-        """
-        Remove all WorkingMemory nodes except the latest `keep_latest` entries.
-
-        Args:
-            memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
-            keep_latest (int): Number of latest WorkingMemory entries to keep.
-            user_name (str, optional): User name for filtering in non-multi-db mode
-        """
+        start_time = time.perf_counter()
+        logger.info(
+            "remove_oldest_memory by memory_type:%s,keep_latest: %s,user_name:%s",
+            memory_type,
+            keep_latest,
+            user_name,
+        )
         user_name = user_name if user_name else self._get_config_value("user_name")
 
         # Use actual OFFSET logic, consistent with nebular.py
@@ -456,6 +455,9 @@ def remove_oldest_memory(
             self.format_param_value(user_name),
             keep_latest,
         ]
+        logger.info(
+            f"remove_oldest_memory by select_query:{select_query},select_params:{select_params}"
+        )
         try:
             with self._get_connection() as conn, conn.cursor() as cursor:
                 # Execute query to get IDs to delete
@@ -482,6 +484,8 @@ def remove_oldest_memory(
                     f"keeping {keep_latest} latest for user {user_name}, "
                     f"removed ids: {ids_to_delete}"
                 )
+                elapsed = (time.perf_counter() - start_time) * 1000.0
+                logger.info("remove_oldest_memory internal took %.1f ms", elapsed)
         except Exception as e:
             logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True)
             raise
@@ -1840,9 +1844,8 @@ def search_by_embedding(
         **kwargs,
     ) -> list[dict]:
         logger.info(
-            "search_by_embedding user_name:%s,filter: %s, knowledgebase_ids: %s,scope:%s,status:%s,search_filter:%s,filter:%s,knowledgebase_ids:%s,return_fields:%s",
+            "search_by_embedding by user_name:%s,knowledgebase_ids: %s,scope:%s,status:%s,search_filter:%s,filter:%s,knowledgebase_ids:%s,return_fields:%s",
             user_name,
-            filter,
             knowledgebase_ids,
             scope,
             status,
@@ -1895,20 +1898,21 @@ def search_by_embedding(
         where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
 
         query = f"""
+            set hnsw.ef_search = 100;set hnsw.iterative_scan = relaxed_order;
             WITH t AS (
                 SELECT id,
                     properties,
                     timeline,
                     ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
-                    (1 - (embedding <=> %s::vector(1024))) AS scope
+                    (embedding <=> %s::vector(1024)) AS scope_distance
                 FROM "{self.db_name}_graph"."Memory"
                 {where_clause}
-                ORDER BY scope DESC
+                ORDER BY scope_distance ASC
                 LIMIT {top_k}
             )
-            SELECT *
+            SELECT *,(1 - scope_distance) AS scope
            FROM t
-            WHERE scope > 0.1;
+            WHERE scope_distance < 0.9;
         """
         vector_str = convert_to_vector(vector)
         query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)")
@@ -1953,7 +1957,7 @@ def search_by_embedding(
                 output.append(item)
         elapsed_time = (time.perf_counter() - start_time) * 1000.0
         logger.info(
-            "search_by_embedding query embedding completed time took %.1f ms", elapsed_time
+            "search_by_embedding query by embedding completed time took %.1f ms", elapsed_time
        )
         return output[:top_k]
 
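The core search_by_embedding change above swaps a derived cosine-similarity ordering (1 - (embedding <=> v), sorted DESC) for the raw <=> cosine distance sorted ASC, which is the access pattern a pgvector-style HNSW index can serve directly; the new cut-off scope_distance < 0.9 is algebraically the same filter as the old scope > 0.1, since cosine similarity = 1 - cosine distance. Below is a minimal standalone sketch of the same pattern, assuming a pgvector-compatible build with HNSW support and a hypothetical memory(id uuid, embedding vector(3)) table — the table, column, index, and vector values are illustrative, not taken from the repo.

-- Sketch only; assumes something like: CREATE INDEX ON memory USING hnsw (embedding vector_cosine_ops);
SET hnsw.ef_search = 100;                 -- widen the HNSW candidate list for better recall
SET hnsw.iterative_scan = relaxed_order;  -- keep scanning past filtered-out rows (pgvector >= 0.8.0)

WITH t AS (
    SELECT id,
           embedding <=> '[0.1, 0.2, 0.3]'::vector AS scope_distance
    FROM memory
    ORDER BY embedding <=> '[0.1, 0.2, 0.3]'::vector   -- raw distance, ascending: an index-servable sort
    LIMIT 10
)
SELECT id,
       1 - scope_distance AS scope          -- convert distance back to the similarity callers expect
FROM t
WHERE scope_distance < 0.9;                 -- equivalent to similarity > 0.1

Ordering by the raw distance expression, rather than by a similarity derived from it, is what keeps the ANN index in play; the outer SELECT then restores the similarity score for downstream consumers.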
@@ -1966,57 +1970,34 @@ def get_by_metadata(
         knowledgebase_ids: list | None = None,
         user_name_flag: bool = True,
     ) -> list[str]:
-        """
-        Retrieve node IDs that match given metadata filters.
-        Supports exact match.
-
-        Args:
-            filters: List of filter dicts like:
-                [
-                    {"field": "key", "op": "in", "value": ["A", "B"]},
-                    {"field": "confidence", "op": ">=", "value": 80},
-                    {"field": "tags", "op": "contains", "value": "AI"},
-                    ...
-                ]
-            user_name (str, optional): User name for filtering in non-multi-db mode
-
-        Returns:
-            list[str]: Node IDs whose metadata match the filter conditions. (AND logic).
-        """
+        start_time = time.perf_counter()
         logger.info(
             f" get_by_metadata user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},filters:{filters}"
         )
 
         user_name = user_name if user_name else self._get_config_value("user_name")
 
-        # Build WHERE conditions for cypher query
         where_conditions = []
 
         for f in filters:
             field = f["field"]
             op = f.get("op", "=")
             value = f["value"]
 
-            # Format value
             if isinstance(value, str):
-                # Escape single quotes using backslash when inside $$ dollar-quoted strings
-                # In $$ delimiters, Cypher string literals can use \' to escape single quotes
                 escaped_str = value.replace("'", "\\'")
                 escaped_value = f"'{escaped_str}'"
             elif isinstance(value, list):
-                # Handle list values - use double quotes for Cypher arrays
                 list_items = []
                 for v in value:
                     if isinstance(v, str):
-                        # Escape double quotes in string values for Cypher
                         escaped_str = v.replace('"', '\\"')
                         list_items.append(f'"{escaped_str}"')
                     else:
                         list_items.append(str(v))
                 escaped_value = f"[{', '.join(list_items)}]"
             else:
                 escaped_value = f"'{value}'" if isinstance(value, str) else str(value)
-            # Build WHERE conditions
             if op == "=":
                 where_conditions.append(f"n.{field} = {escaped_value}")
             elif op == "in":
@@ -2045,22 +2026,19 @@ def get_by_metadata(
             knowledgebase_ids=knowledgebase_ids,
             default_user_name=self._get_config_value("user_name"),
         )
-        logger.info(f"[get_by_metadata] user_name_conditions: {user_name_conditions}")
+        logger.info(f"get_by_metadata user_name_conditions: {user_name_conditions}")
 
-        # Add user_name WHERE clause
         if user_name_conditions:
             if len(user_name_conditions) == 1:
                 where_conditions.append(user_name_conditions[0])
             else:
                 where_conditions.append(f"({' OR '.join(user_name_conditions)})")
 
-        # Build filter conditions using common method
         filter_where_clause = self._build_filter_conditions_cypher(filter)
-        logger.info(f"[get_by_metadata] filter_where_clause: {filter_where_clause}")
+        logger.info(f"get_by_metadata filter_where_clause: {filter_where_clause}")
 
         where_str = " AND ".join(where_conditions) + filter_where_clause
 
-        # Use cypher query
         cypher_query = f"""
             SELECT * FROM cypher('{self.db_name}_graph', $$
                 MATCH (n:Memory)
@@ -2070,15 +2048,16 @@ def get_by_metadata(
         """
 
         ids = []
-        logger.info(f"[get_by_metadata] cypher_query: {cypher_query}")
+        logger.info(f"get_by_metadata cypher_query: {cypher_query}")
         try:
             with self._get_connection() as conn, conn.cursor() as cursor:
                 cursor.execute(cypher_query)
                 results = cursor.fetchall()
                 ids = [str(item[0]).strip('"') for item in results]
         except Exception as e:
             logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}")
-
+        elapsed = (time.perf_counter() - start_time) * 1000.0
+        logger.info("get_by_metadata internal took %.1f ms", elapsed)
         return ids
 
     @timed
@@ -2165,25 +2144,19 @@ def get_grouped_counts(
         params: dict[str, Any] | None = None,
         user_name: str | None = None,
     ) -> list[dict[str, Any]]:
-        """
-        Count nodes grouped by any fields.
-
-        Args:
-            group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
-            where_clause (str, optional): Extra WHERE condition. E.g.,
-                "WHERE n.status = 'activated'"
-            params (dict, optional): Parameters for WHERE clause.
-            user_name (str, optional): User name for filtering in non-multi-db mode
-
-        Returns:
-            list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...]
-        """
+        start_time = time.perf_counter()
+        logger.info(
+            "get_grouped_counts by group_fields:%s,where_clause: %s,params:%s,user_name:%s",
+            group_fields,
+            where_clause,
+            params,
+            user_name,
+        )
         if not group_fields:
             raise ValueError("group_fields cannot be empty")
 
         user_name = user_name if user_name else self._get_config_value("user_name")
 
-        # Build user clause
         user_clause = f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype"
         if where_clause:
             where_clause = where_clause.strip()
@@ -2194,44 +2167,43 @@ def get_grouped_counts(
         else:
             where_clause = f"WHERE {user_clause}"
 
-        # Inline parameters if provided
         if params and isinstance(params, dict):
             for key, value in params.items():
-                # Handle different value types appropriately
                 if isinstance(value, str):
                     value = f"'{value}'"
                 where_clause = where_clause.replace(f"${key}", str(value))
 
-        # Handle user_name parameter in where_clause
         if "user_name = %s" in where_clause:
             where_clause = where_clause.replace(
                 "user_name = %s",
                 f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype",
             )
 
-        # Build return fields and group by fields
-        return_fields = []
-        group_by_fields = []
-
+        cte_select_list = []
+        aliases = []
         for field in group_fields:
             alias = field.replace(".", "_")
-            return_fields.append(
-                f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype)::text AS {alias}"
-            )
-            group_by_fields.append(
-                f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype)::text"
+            aliases.append(alias)
+            cte_select_list.append(
+                f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype) AS {alias}"
            )
-
-        # Full SQL query construction
+        outer_select = ", ".join(f"{a}::text" for a in aliases)
+        outer_group_by = ", ".join(aliases)
         query = f"""
-            SELECT {", ".join(return_fields)}, COUNT(*) AS count
-            FROM "{self.db_name}_graph"."Memory"
-            {where_clause}
-            GROUP BY {", ".join(group_by_fields)}
+            WITH t AS (
+                SELECT {", ".join(cte_select_list)}
+                FROM "{self.db_name}_graph"."Memory"
+                {where_clause}
+                LIMIT 1000
+            )
+            SELECT {outer_select}, count(*) AS count
+            FROM t
+            GROUP BY {outer_group_by}
         """
+        logger.info(f"get_grouped_counts query:{query},params:{params}")
+
         try:
             with self._get_connection() as conn, conn.cursor() as cursor:
-                # Handle parameterized query
                 if params and isinstance(params, list):
                     cursor.execute(query, params)
                 else:
@@ -2250,6 +2222,8 @@ def get_grouped_counts(
                     count_value = row[-1]  # Last column is count
                     output.append({**group_values, "count": int(count_value)})
 
+                elapsed = (time.perf_counter() - start_time) * 1000.0
+                logger.info("get_grouped_counts internal took %.1f ms", elapsed)
                 return output
 
         except Exception as e:
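The get_grouped_counts rewrite moves the agtype property extraction into a CTE capped at 1000 rows and groups over that bounded set, so the aggregation no longer scans the whole Memory table and the reported counts are correspondingly capped. Rendered for a hypothetical call with group_fields=["memory_type"] on a graph named memos_graph and a user "alice" (all illustrative values, not taken from the repo), the generated SQL would look roughly like:

-- Illustrative rendering of the new query template; the graph name, user filter, and row cap come from runtime config.
WITH t AS (
    SELECT ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) AS memory_type
    FROM "memos_graph"."Memory"
    WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '"alice"'::agtype
    LIMIT 1000
)
SELECT memory_type::text, count(*) AS count
FROM t
GROUP BY memory_type;

Compared with the old template, the ::text cast now happens only in the outer SELECT, once per grouped row rather than once per scanned row, and the LIMIT trades exact totals for a bounded, predictable scan on large tables.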

src/memos/mem_scheduler/schemas/general_schemas.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300
 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2
 DEFAULT_STUCK_THREAD_TOLERANCE = 10
-DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1
+DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 200
 DEFAULT_TOP_K = 5
 DEFAULT_CONTEXT_WINDOW_SIZE = 5
 DEFAULT_USE_REDIS_QUEUE = os.getenv("MEMSCHEDULER_USE_REDIS_QUEUE", "False").lower() == "true"

src/memos/memories/textual/tree_text_memory/organize/manager.py

Lines changed: 0 additions & 1 deletion
@@ -114,7 +114,6 @@ def add(
 
         if mode == "sync":
             self._cleanup_working_memory(user_name)
-            self._refresh_memory_size(user_name=user_name)
 
         return added_ids
 