diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml
index de300c193..58c06ff18 100644
--- a/.github/workflows/python-tests.yml
+++ b/.github/workflows/python-tests.yml
@@ -102,7 +102,10 @@ jobs:
         if: ${{ !startsWith(matrix.os, 'macos-13') }}
         run: |
           poetry install --no-interaction --extras all
-      - name: PyTest unit tests
+      - name: PyTest unit tests with coverage
         if: ${{ !startsWith(matrix.os, 'macos-13') }}
         run: |
-          poetry run pytest tests -vv --durations=10
+          poetry run pytest tests -vv --durations=10 \
+            --cov=src/memos \
+            --cov-report=term-missing \
+            --cov-fail-under=28
diff --git a/.gitignore b/.gitignore
index ece7e45ba..7d1be5a25 100644
--- a/.gitignore
+++ b/.gitignore
@@ -63,6 +63,8 @@ pip-delete-this-directory.txt
 
 # Unit test / coverage reports
 htmlcov/
+report/
+cov-report/
 .tox/
 .nox/
 .coverage
diff --git a/Makefile b/Makefile
index 57ede5838..788504a73 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-.PHONY: test
+.PHONY: test test-report test-cov
 
 install:
 	poetry install --extras all --with dev --with test
@@ -9,10 +9,25 @@ clean:
 	rm -rf .pytest_cache
 	rm -rf .ruff_cache
 	rm -rf tmp
+	rm -rf report cov-report
+	rm -f .coverage .coverage.*
 
 test:
 	poetry run pytest tests
 
+test-report:
+	poetry run pytest tests -vv --durations=10 \
+		--html=report/index.html \
+		--cov=src/memos \
+		--cov-report=term-missing \
+		--cov-report=html:cov-report/src
+
+test-cov:
+	poetry run pytest tests \
+		--cov=src/memos \
+		--cov-report=term-missing \
+		--cov-report=html:cov-report/src
+
 format:
 	poetry run ruff check --fix
 	poetry run ruff format
diff --git a/pyproject.toml b/pyproject.toml
index 9f17c0000..ff7c9699a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -170,6 +170,8 @@ optional = true
 [tool.poetry.group.test.dependencies]
 pytest = "^8.3.5"
 pytest-asyncio = "^0.23.5"
+pytest-cov = "^6.1"
+pytest-html = "^4.2"
 ruff = "^0.11.8"
 
 [tool.poetry.group.eval]
@@ -208,6 +210,23 @@ filterwarnings = [
 ]
 
 
+[tool.coverage.run]
+source = ["src/memos"]
+branch = true
+
+[tool.coverage.report]
+show_missing = true
+skip_empty = true
+exclude_lines = [
+    "pragma: no cover",
+    "if TYPE_CHECKING:",
+    "if __name__ == .__main__.",
+]
+
+[tool.coverage.html]
+directory = "cov-report"
+
+
 [tool.ruff]
 ##############################################################################
 # Ruff is a fast Python linter and formatter.
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index 06aa50c65..87f1efd8e 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -321,23 +321,40 @@ def get_activation_config() -> dict[str, Any]:
     @staticmethod
     def get_memreader_config() -> dict[str, Any]:
-        """Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model)."""
-        return {
-            "backend": "openai",
-            "config": {
-                "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
-                "temperature": 0.6,
-                "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
-                "top_p": 0.95,
-                "top_k": 20,
-                "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
-                # Default to OpenAI base URL when env var is not provided to satisfy pydantic
-                # validation requirements during tests/import.
-                "api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"),
-                "remove_think_prefix": True,
-            },
+        """Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model).
+
+        When MEMREADER_GENERAL_MODEL is configured (i.e. a separate stable LLM exists)
+        and MEMREADER_ENABLE_BACKUP is set to "true", the backup client is enabled so
+        that primary failures (self-deployed model) fall back to the general LLM.
+        """
+        config = {
+            "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
+            "temperature": 0.6,
+            "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
+            "top_p": 0.95,
+            "top_k": 20,
+            "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
+            # Default to OpenAI base URL when env var is not provided to satisfy pydantic
+            # validation requirements during tests/import.
+            "api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"),
+            "remove_think_prefix": True,
         }
+        general_model = os.getenv("MEMREADER_GENERAL_MODEL")
+        enable_backup = os.getenv("MEMREADER_ENABLE_BACKUP", "false").lower() == "true"
+        if general_model and enable_backup:
+            config["backup_client"] = True
+            config["backup_model_name_or_path"] = general_model
+            config["backup_api_key"] = os.getenv(
+                "MEMREADER_GENERAL_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY")
+            )
+            config["backup_api_base"] = os.getenv(
+                "MEMREADER_GENERAL_API_BASE",
+                os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            )
+
+        return {"backend": "openai", "config": config}
+
     @staticmethod
     def get_memreader_general_llm_config() -> dict[str, Any]:
         """Get general LLM configuration for non-chat/doc tasks.
 
@@ -837,7 +854,7 @@ def get_scheduler_config() -> dict[str, Any]:
             ),
             "context_window_size": int(os.getenv("MOS_SCHEDULER_CONTEXT_WINDOW_SIZE", "5")),
             "thread_pool_max_workers": int(
-                os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10000")
+                os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "200")
             ),
             "consume_interval_seconds": float(
                 os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01")
@@ -850,6 +867,8 @@
                 "MOS_SCHEDULER_ENABLE_ACTIVATION_MEMORY", "false"
             ).lower()
             == "true",
+            "use_redis_queue": os.getenv("MEMSCHEDULER_USE_REDIS_QUEUE", "False").lower()
+            == "true",
         },
     }
 
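Note on the gating above: get_memreader_config() only attaches the backup_* keys when both MEMREADER_GENERAL_MODEL and MEMREADER_ENABLE_BACKUP=true are present. A minimal sketch of the resulting shape (illustrative values; assumes the memos package is importable):

    import os

    from memos.api.config import APIConfig

    os.environ["MEMREADER_GENERAL_MODEL"] = "gpt-4o-mini"  # hypothetical model
    os.environ["MEMREADER_ENABLE_BACKUP"] = "true"

    cfg = APIConfig.get_memreader_config()
    assert cfg["backend"] == "openai"
    assert cfg["config"]["backup_client"] is True
    assert cfg["config"]["backup_model_name_or_path"] == "gpt-4o-mini"
    # Without MEMREADER_ENABLE_BACKUP=true, the backup_* keys are absent entirely.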
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
index aa2525878..7894ff7dc 100644
--- a/src/memos/api/handlers/component_init.py
+++ b/src/memos/api/handlers/component_init.py
@@ -233,7 +233,7 @@ def init_server() -> dict[str, Any]:
     searcher: Searcher = tree_mem.get_searcher(
         manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
         moscube=False,
-        process_llm=mem_reader.llm,
+        process_llm=mem_reader.general_llm,
     )
     logger.debug("Searcher created")
 
@@ -255,12 +255,13 @@
     # Initialize Scheduler
     scheduler_config_dict = APIConfig.get_scheduler_config()
     scheduler_config = SchedulerConfigFactory(
-        backend="optimized_scheduler", config=scheduler_config_dict
+        backend=scheduler_config_dict["backend"],
+        config=scheduler_config_dict["config"],
     )
     mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config)
     mem_scheduler.initialize_modules(
         chat_llm=llm,
-        process_llm=mem_reader.llm,
+        process_llm=mem_reader.general_llm,
         db_engine=BaseDBManager.create_default_sqlite_engine(),
         mem_reader=mem_reader,
         redis_client=redis_client,
diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py
index 5487d117c..11c39b33c 100644
--- a/src/memos/configs/llm.py
+++ b/src/memos/configs/llm.py
@@ -28,6 +28,22 @@ class OpenAILLMConfig(BaseLLMConfig):
         default="https://api.openai.com/v1", description="Base URL for OpenAI API"
     )
     extra_body: Any = Field(default=None, description="extra body")
+    backup_client: bool = Field(
+        default=False,
+        description="Whether to enable backup client for fallback on primary failure",
+    )
+    backup_api_key: str | None = Field(
+        default=None, description="API key for backup OpenAI-compatible endpoint"
+    )
+    backup_api_base: str | None = Field(
+        default=None, description="Base URL for backup OpenAI-compatible endpoint"
+    )
+    backup_model_name_or_path: str | None = Field(
+        default=None, description="Model name for backup endpoint"
+    )
+    backup_headers: dict[str, Any] | None = Field(
+        default=None, description="Default headers for backup client requests"
+    )
 
 
 class OpenAIResponsesLLMConfig(BaseLLMConfig):
@@ -42,22 +58,18 @@ class OpenAIResponsesLLMConfig(BaseLLMConfig):
     )
 
 
-class QwenLLMConfig(BaseLLMConfig):
-    api_key: str = Field(..., description="API key for DashScope (Qwen)")
+class QwenLLMConfig(OpenAILLMConfig):
     api_base: str = Field(
         default="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
         description="Base URL for Qwen OpenAI-compatible API",
     )
-    extra_body: Any = Field(default=None, description="extra body")
 
 
-class DeepSeekLLMConfig(BaseLLMConfig):
-    api_key: str = Field(..., description="API key for DeepSeek")
+class DeepSeekLLMConfig(OpenAILLMConfig):
     api_base: str = Field(
         default="https://api.deepseek.com",
         description="Base URL for DeepSeek OpenAI-compatible API",
     )
-    extra_body: Any = Field(default=None, description="Extra options for API")
 
 
 class AzureLLMConfig(BaseLLMConfig):
diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py
index 9807f42c3..f76ddecc4 100644
--- a/src/memos/configs/mem_scheduler.py
+++ b/src/memos/configs/mem_scheduler.py
@@ -155,7 +155,10 @@ def validate_backend(cls, backend: str) -> str:
     @model_validator(mode="after")
     def create_config(self) -> "SchedulerConfigFactory":
         config_class = self.backend_to_class[self.backend]
-        self.config = config_class(**self.config)
+        raw = self.config
+        if isinstance(raw, dict) and "config" in raw and "use_redis_queue" not in raw:
+            raw = raw["config"]
+        self.config = config_class(**raw)
         return self
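The relaxed create_config validator above means SchedulerConfigFactory now accepts either the bare scheduler dict or the full {"backend": ..., "config": {...}} shape returned by APIConfig.get_scheduler_config(). A sketch, assuming `inner` stands in for an otherwise-valid scheduler config dict (the key values here are placeholders taken from the config shown earlier):

    from memos.configs.mem_scheduler import SchedulerConfigFactory

    inner = {"use_redis_queue": False, "context_window_size": 5}  # placeholder values

    # Bare dict: the original calling convention.
    a = SchedulerConfigFactory(backend="optimized_scheduler", config=inner)
    # Nested dict, as produced by get_scheduler_config(); the validator
    # unwraps raw["config"] because "use_redis_queue" is not a top-level key.
    b = SchedulerConfigFactory(backend="optimized_scheduler", config={"config": inner})

One caveat worth noting: the unwrap heuristic keys off "use_redis_queue", so a bare dict that omits that key but happens to contain a "config" field would be unwrapped by mistake.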
""" logger.info( - f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}" ) print( - f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}" + f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids},status: {status}" ) user_name = user_name if user_name else self.config.user_name where_clauses = [] params = {} + # Add status filter if provided + if status: + where_clauses.append("n.status = $status") + params["status"] = status + for i, f in enumerate(filters): field = f["field"] op = f.get("op", "=") diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index d740ad1d2..6db31990d 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -4667,6 +4667,36 @@ def build_filter_condition(condition_dict: dict) -> str: condition_parts.append( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text LIKE '%{op_value}%'" ) + elif op == "nolike": + if key.startswith("info."): + info_field = key[5:] + if isinstance(op_value, str): + escaped_value = ( + escape_sql_string(op_value) + .replace("%", "\\%") + .replace("_", "\\_") + ) + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text NOT LIKE '%{escaped_value}%'" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text NOT LIKE '%{op_value}%'" + ) + else: + if isinstance(op_value, str): + escaped_value = ( + escape_sql_string(op_value) + .replace("%", "\\%") + .replace("_", "\\_") + ) + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text NOT LIKE '%{escaped_value}%'" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text NOT LIKE '%{op_value}%'" + ) # Check if key starts with "info." prefix (for simple equality) elif key.startswith("info."): # Extract the field name after "info." 
diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index 93dac42fb..f6bb4efc1 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -27,7 +27,39 @@ def __init__(self, config: OpenAILLMConfig):
         self.client = openai.Client(
             api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers
         )
-        logger.info("OpenAI LLM instance initialized")
+        self.use_backup_client = config.backup_client
+        if self.use_backup_client:
+            self.backup_client = openai.Client(
+                api_key=config.backup_api_key,
+                base_url=config.backup_api_base,
+                default_headers=config.backup_headers,
+            )
+            logger.info(
+                f"OpenAI LLM instance initialized with backup "
+                f"(model={config.backup_model_name_or_path})"
+            )
+        else:
+            self.backup_client = None
+            logger.info("OpenAI LLM instance initialized")
+
+    def _parse_response(self, response) -> str:
+        """Extract text content from a chat completion response."""
+        if not response.choices:
+            logger.warning("OpenAI response has no choices")
+            return ""
+
+        tool_calls = getattr(response.choices[0].message, "tool_calls", None)
+        if isinstance(tool_calls, list) and len(tool_calls) > 0:
+            return self.tool_call_parser(tool_calls)
+        response_content = response.choices[0].message.content
+        reasoning_content = getattr(response.choices[0].message, "reasoning_content", None)
+        if isinstance(reasoning_content, str) and reasoning_content:
+            reasoning_content = f"<think>{reasoning_content}</think>"
+        if self.config.remove_think_prefix:
+            return remove_thinking_tags(response_content)
+        if reasoning_content:
+            return reasoning_content + (response_content or "")
+        return response_content or ""
 
     @timed_with_status(
         log_prefix="OpenAI LLM",
@@ -50,29 +82,32 @@ def generate(self, messages: MessageList, **kwargs) -> str:
         start_time = time.perf_counter()
         logger.info(f"OpenAI LLM Request body: {request_body}")
 
-        response = self.client.chat.completions.create(**request_body)
-
-        cost_time = time.perf_counter() - start_time
-        logger.info(
-            f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}"
-        )
-
-        if not response.choices:
-            logger.warning("OpenAI response has no choices")
-            return ""
-
-        tool_calls = getattr(response.choices[0].message, "tool_calls", None)
-        if isinstance(tool_calls, list) and len(tool_calls) > 0:
-            return self.tool_call_parser(tool_calls)
-        response_content = response.choices[0].message.content
-        reasoning_content = getattr(response.choices[0].message, "reasoning_content", None)
-        if isinstance(reasoning_content, str) and reasoning_content:
-            reasoning_content = f"<think>{reasoning_content}</think>"
-        if self.config.remove_think_prefix:
-            return remove_thinking_tags(response_content)
-        if reasoning_content:
-            return reasoning_content + (response_content or "")
-        return response_content or ""
+        try:
+            response = self.client.chat.completions.create(**request_body)
+            cost_time = time.perf_counter() - start_time
+            logger.info(
+                f"Request body: {request_body}, Response from OpenAI: "
+                f"{response.model_dump_json()}, Cost time: {cost_time}"
+            )
+            return self._parse_response(response)
+        except Exception as e:
+            if not self.use_backup_client:
+                raise
+            logger.warning(
+                f"Primary LLM request failed with {type(e).__name__}: {e}, "
+                f"falling back to backup client"
+            )
+            backup_body = {
+                **request_body,
+                "model": self.config.backup_model_name_or_path or request_body["model"],
+            }
+            backup_response = self.backup_client.chat.completions.create(**backup_body)
+            cost_time = time.perf_counter() - start_time
+            logger.info(
+                f"Backup LLM request succeeded, Response: "
+                f"{backup_response.model_dump_json()}, Cost time: {cost_time}"
+            )
+            return self._parse_response(backup_response)
 
     @timed_with_status(
         log_prefix="OpenAI LLM Stream",
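A sketch of the fallback contract in generate() above, using mocks; `llm` stands for an instance of the wrapper class in this module constructed with backup_client=True (the class name itself is not visible in this hunk, so this is illustrative only):

    from unittest.mock import MagicMock

    fake_msg = MagicMock(content="pong", tool_calls=None, reasoning_content=None)
    fake_response = MagicMock(choices=[MagicMock(message=fake_msg)])
    fake_response.model_dump_json.return_value = "{}"

    # Primary client fails; backup client answers.
    llm.client.chat.completions.create = MagicMock(side_effect=RuntimeError("boom"))
    llm.backup_client.chat.completions.create = MagicMock(return_value=fake_response)

    assert llm.generate([{"role": "user", "content": "ping"}]) == "pong"
    # The retried request body is identical except that "model" is replaced
    # with backup_model_name_or_path; without a backup client the original
    # exception propagates unchanged.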
diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py
index b8019004d..135058a7d 100644
--- a/src/memos/mem_feedback/feedback.py
+++ b/src/memos/mem_feedback/feedback.py
@@ -243,7 +243,6 @@ def _single_add_operation(
                 datetime.now().isoformat()
             )
             to_add_memory.metadata.background = new_memory_item.metadata.background
-            to_add_memory.metadata.sources = []
 
             added_ids = self._retry_db_operation(
                 lambda: self.memory_manager.add([to_add_memory], user_name=user_name, use_batch=False)
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py
index 22cd0e9cb..54f8f01e0 100644
--- a/src/memos/mem_os/core.py
+++ b/src/memos/mem_os/core.py
@@ -132,7 +132,7 @@ def _initialize_mem_scheduler(self) -> GeneralScheduler:
         # Configure scheduler general_modules
         self._mem_scheduler.initialize_modules(
             chat_llm=self.chat_llm,
-            process_llm=self.mem_reader.llm,
+            process_llm=self.mem_reader.general_llm,
             db_engine=self.user_manager.engine,
         )
         self._mem_scheduler.start()
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 4c0d4dcd0..2745a1bee 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -982,6 +982,9 @@
         # Use MultiModalParser to parse the scene data
         # If it's a list, parse each item; otherwise parse as single message
        if isinstance(scene_data_info, list):
+            # Pre-expand multimodal messages
+            expanded_messages = self._expand_multimodal_messages(scene_data_info)
+
             # Parse each message in the list
             all_memory_items = []
             # Use thread pool to parse each message in parallel, but keep the original order
@@ -996,7 +999,7 @@
                         need_emb=False,
                         **kwargs,
                     )
-                    for msg in scene_data_info
+                    for msg in expanded_messages
                 ]
                 # collect results in original order
                 for future in futures:
@@ -1014,20 +1017,23 @@
 
         if mode == "fast":
             return fast_memory_items
         else:
+            non_file_url_fast_items = [
+                item for item in fast_memory_items if not self._is_file_url_only_item(item)
+            ]
+
             # Part A: call llm in parallel using thread pool
             fine_memory_items = []
             with ContextThreadPoolExecutor(max_workers=4) as executor:
                 future_string = executor.submit(
-                    self._process_string_fine, fast_memory_items, info, custom_tags, **kwargs
+                    self._process_string_fine, non_file_url_fast_items, info, custom_tags, **kwargs
                 )
                 future_tool = executor.submit(
-                    self._process_tool_trajectory_fine, fast_memory_items, info, **kwargs
+                    self._process_tool_trajectory_fine, non_file_url_fast_items, info, **kwargs
                 )
-                # Use general_llm for skill memory extraction (not fine-tuned for this task)
                 future_skill = executor.submit(
                     process_skill_memory_fine,
-                    fast_memory_items=fast_memory_items,
+                    fast_memory_items=non_file_url_fast_items,
                     info=info,
                     searcher=self.searcher,
                     graph_db=self.graph_db,
@@ -1039,7 +1045,7 @@
                 )
                 future_pref = executor.submit(
                     process_preference_fine,
-                    fast_memory_items,
+                    non_file_url_fast_items,
                     info,
                     self.llm,
                     self.embedder,
@@ -1094,19 +1100,21 @@ def _process_transfer_multi_modal_data(
             **(raw_nodes[0].metadata.info or {}),
         }
 
+        # Filter out file-URL-only items for Part A fine processing (same as _process_multi_modal_data)
+        non_file_url_nodes = [node for node in raw_nodes if not self._is_file_url_only_item(node)]
+
         fine_memory_items = []
         # Part A: call llm in parallel using thread pool
         with ContextThreadPoolExecutor(max_workers=4) as executor:
             future_string = executor.submit(
-                self._process_string_fine, raw_nodes, info, custom_tags, **kwargs
+                self._process_string_fine, non_file_url_nodes, info, custom_tags, **kwargs
             )
             future_tool = executor.submit(
-                self._process_tool_trajectory_fine, raw_nodes, info, **kwargs
+                self._process_tool_trajectory_fine, non_file_url_nodes, info, **kwargs
             )
-            # Use general_llm for skill memory extraction (not fine-tuned for this task)
             future_skill = executor.submit(
                 process_skill_memory_fine,
-                raw_nodes,
+                non_file_url_nodes,
                 info,
                 searcher=self.searcher,
                 llm=self.general_llm,
@@ -1118,7 +1126,7 @@
             )
             # Add preference memory extraction
             future_pref = executor.submit(
-                process_preference_fine, raw_nodes, info, self.general_llm, self.embedder, **kwargs
+                process_preference_fine, non_file_url_nodes, info, self.llm, self.embedder, **kwargs
             )
 
             # Collect results
@@ -1148,6 +1156,90 @@
                 fine_memory_items.extend(items)
         return fine_memory_items
 
+    @staticmethod
+    def _expand_multimodal_messages(messages: list) -> list:
+        """
+        Expand messages whose ``content`` is a list into individual
+        sub-messages so that each modality is routed to its specialised
+        parser during fast-mode parsing.
+
+        For a message like::
+
+            {
+                "content": [
+                    {"type": "text", "text": "Analyze this file"},
+                    {"type": "file", "file": {"file_data": "https://...", ...}},
+                    {"type": "image_url", "image_url": {"url": "https://..."}},
+                ],
+                "role": "user",
+                "chat_time": "03:14 PM on 13 March, 2026",
+            }
+
+        The result will be::
+
+            [
+                {"content": "Analyze this file", "role": "user", "chat_time": "..."},
+                {"type": "file", "file": {"file_data": "https://...", ...}},
+                {"type": "image_url", "image_url": {"url": "https://..."}},
+            ]
+
+        Messages whose ``content`` is already a plain string (or that are
+        not dicts) are passed through unchanged.
+        """
+        expanded: list = []
+        for msg in messages:
+            if not isinstance(msg, dict):
+                expanded.append(msg)
+                continue
+
+            content = msg.get("content")
+            if not isinstance(content, list):
+                expanded.append(msg)
+                continue
+
+            # ---- content is a list: split by modality ----
+            text_parts: list[str] = []
+            for part in content:
+                if not isinstance(part, dict):
+                    text_parts.append(str(part))
+                    continue
+
+                part_type = part.get("type", "")
+                if part_type == "text":
+                    text_parts.append(part.get("text", ""))
+                elif part_type in ("file", "image", "image_url"):
+                    # Extract as a standalone message for its specialised parser
+                    expanded.append(part)
+                else:
+                    text_parts.append(f"[{part_type}]")
+
+            # Reconstruct a text-only version of the original message
+            # (preserving role, chat_time, message_id, etc.)
+            text_content = "\n".join(t for t in text_parts if t.strip())
+            if text_content.strip():
+                text_msg = {k: v for k, v in msg.items() if k != "content"}
+                text_msg["content"] = text_content
+                expanded.append(text_msg)
+
+        return expanded
+
+    @staticmethod
+    def _is_file_url_only_item(item: TextualMemoryItem) -> bool:
+        """
+        Check if a fast memory item contains only file-URL sources.
+
+        Args:
+            item: TextualMemoryItem to check
+
+        Returns:
+            True if all sources are file-type with URL info (metadata only)
+        """
+        sources = item.metadata.sources or []
+        if not sources:
+            return False
+        return all(
+            getattr(s, "type", None) == "file" and getattr(s, "file_info", None) for s in sources
+        )
+
     def get_scene_data_info(self, scene_data: list, type: str) -> list[list[Any]]:
         """
         Convert normalized MessagesType scenes into scene data info.
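A concrete pass through _expand_multimodal_messages(), mirroring the docstring's example (Reader is a stand-in name for the actual reader class defined in this module):

    msg = {
        "role": "user",
        "chat_time": "03:14 PM on 13 March, 2026",
        "content": [
            {"type": "text", "text": "Analyze this file"},
            {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
        ],
    }
    out = Reader._expand_multimodal_messages([msg])  # static method
    # The image part is extracted first (during the part loop), then the
    # text-only copy of the message is appended with plain-string content.
    assert out[0]["type"] == "image_url"
    assert out[1]["content"] == "Analyze this file" and out[1]["role"] == "user"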
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index 8777b9f2e..ec431c253 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -305,7 +305,7 @@ def init_components() -> dict[str, Any]:
     searcher: Searcher = tree_mem.get_searcher(
         manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
         moscube=False,
-        process_llm=mem_reader.llm,
+        process_llm=mem_reader.general_llm,
     )
     # Initialize feedback server
     feedback_server = SimpleMemFeedback(
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index 6df410c19..6a91f436f 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -617,7 +617,7 @@ def add_before_search(
 
         # 3. Call LLM
         try:
-            raw = self.mem_reader.llm.generate([{"role": "user", "content": prompt}])
+            raw = self.mem_reader.general_llm.generate([{"role": "user", "content": prompt}])
 
             success, parsed_result = parse_keep_filter_response(raw)
             if not success:
diff --git a/tests/configs/test_llm.py b/tests/configs/test_llm.py
index 6562c9a95..f3d4549b5 100644
--- a/tests/configs/test_llm.py
+++ b/tests/configs/test_llm.py
@@ -56,6 +56,11 @@ def test_openai_llm_config():
         "remove_think_prefix",
         "extra_body",
         "default_headers",
+        "backup_client",
+        "backup_api_key",
+        "backup_api_base",
+        "backup_model_name_or_path",
+        "backup_headers",
     ],
 )
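Compatibility note for the config fields covered by this test: all five backup_* fields default to off/None, so existing OpenAILLMConfig payloads validate unchanged. A sketch (assuming the remaining BaseLLMConfig fields carry defaults):

    from memos.configs.llm import OpenAILLMConfig

    cfg = OpenAILLMConfig(model_name_or_path="gpt-4o-mini", api_key="sk-test")
    assert cfg.backup_client is False
    assert cfg.backup_api_key is None and cfg.backup_headers is None
    # QwenLLMConfig and DeepSeekLLMConfig inherit these fields as well now,
    # since both subclass OpenAILLMConfig after this change.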