From 4a01dea59d3acee501b58f0911fe957ecd4ba0bc Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Mon, 16 Mar 2026 18:45:37 +0800 Subject: [PATCH 1/3] feat: add backend for openai --- src/memos/api/config.py | 47 +++++++++++++++------- src/memos/configs/llm.py | 24 ++++++++--- src/memos/llms/openai.py | 83 ++++++++++++++++++++++++++++----------- tests/configs/test_llm.py | 5 +++ 4 files changed, 114 insertions(+), 45 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 06aa50c65..f24e28559 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -321,23 +321,40 @@ def get_activation_config() -> dict[str, Any]: @staticmethod def get_memreader_config() -> dict[str, Any]: - """Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model).""" - return { - "backend": "openai", - "config": { - "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"), - "temperature": 0.6, - "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")), - "top_p": 0.95, - "top_k": 20, - "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"), - # Default to OpenAI base URL when env var is not provided to satisfy pydantic - # validation requirements during tests/import. - "api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"), - "remove_think_prefix": True, - }, + """Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model). + + When MEMREADER_GENERAL_MODEL is configured (i.e. a separate stable LLM exists) and + MEMREADER_ENABLE_BACKUP=true, the backup client is enabled so that primary failures + (self-deployed model) fall back to the general LLM. 
+ """ + config = { + "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"), + "temperature": 0.6, + "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")), + "top_p": 0.95, + "top_k": 20, + "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"), + # Default to OpenAI base URL when env var is not provided to satisfy pydantic + # validation requirements during tests/import. + "api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"), + "remove_think_prefix": True, } + general_model = os.getenv("MEMREADER_GENERAL_MODEL") + enable_backup = os.getenv("MEMREADER_ENABLE_BACKUP", "false").lower() == "true" + if general_model and enable_backup: + config["backup_client"] = True + config["backup_model_name_or_path"] = general_model + config["backup_api_key"] = os.getenv( + "MEMREADER_GENERAL_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY") + ) + config["backup_api_base"] = os.getenv( + "MEMREADER_GENERAL_API_BASE", + os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + ) + + return {"backend": "openai", "config": config} + @staticmethod def get_memreader_general_llm_config() -> dict[str, Any]: """Get general LLM configuration for non-chat/doc tasks. 
diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py index 5487d117c..11c39b33c 100644 --- a/src/memos/configs/llm.py +++ b/src/memos/configs/llm.py @@ -28,6 +28,22 @@ class OpenAILLMConfig(BaseLLMConfig): default="https://api.openai.com/v1", description="Base URL for OpenAI API" ) extra_body: Any = Field(default=None, description="extra body") + backup_client: bool = Field( + default=False, + description="Whether to enable backup client for fallback on primary failure", + ) + backup_api_key: str | None = Field( + default=None, description="API key for backup OpenAI-compatible endpoint" + ) + backup_api_base: str | None = Field( + default=None, description="Base URL for backup OpenAI-compatible endpoint" + ) + backup_model_name_or_path: str | None = Field( + default=None, description="Model name for backup endpoint" + ) + backup_headers: dict[str, Any] | None = Field( + default=None, description="Default headers for backup client requests" + ) class OpenAIResponsesLLMConfig(BaseLLMConfig): @@ -42,22 +58,18 @@ class OpenAIResponsesLLMConfig(BaseLLMConfig): ) -class QwenLLMConfig(BaseLLMConfig): - api_key: str = Field(..., description="API key for DashScope (Qwen)") +class QwenLLMConfig(OpenAILLMConfig): api_base: str = Field( default="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", description="Base URL for Qwen OpenAI-compatible API", ) - extra_body: Any = Field(default=None, description="extra body") -class DeepSeekLLMConfig(BaseLLMConfig): - api_key: str = Field(..., description="API key for DeepSeek") +class DeepSeekLLMConfig(OpenAILLMConfig): api_base: str = Field( default="https://api.deepseek.com", description="Base URL for DeepSeek OpenAI-compatible API", ) - extra_body: Any = Field(default=None, description="Extra options for API") class AzureLLMConfig(BaseLLMConfig): diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 93dac42fb..f6bb4efc1 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -27,7 
+27,39 @@ def __init__(self, config: OpenAILLMConfig): self.client = openai.Client( api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers ) - logger.info("OpenAI LLM instance initialized") + self.use_backup_client = config.backup_client + if self.use_backup_client: + self.backup_client = openai.Client( + api_key=config.backup_api_key, + base_url=config.backup_api_base, + default_headers=config.backup_headers, + ) + logger.info( + f"OpenAI LLM instance initialized with backup " + f"(model={config.backup_model_name_or_path})" + ) + else: + self.backup_client = None + logger.info("OpenAI LLM instance initialized") + + def _parse_response(self, response) -> str: + """Extract text content from a chat completion response.""" + if not response.choices: + logger.warning("OpenAI response has no choices") + return "" + + tool_calls = getattr(response.choices[0].message, "tool_calls", None) + if isinstance(tool_calls, list) and len(tool_calls) > 0: + return self.tool_call_parser(tool_calls) + response_content = response.choices[0].message.content + reasoning_content = getattr(response.choices[0].message, "reasoning_content", None) + if isinstance(reasoning_content, str) and reasoning_content: + reasoning_content = f"{reasoning_content}" + if self.config.remove_think_prefix: + return remove_thinking_tags(response_content) + if reasoning_content: + return reasoning_content + (response_content or "") + return response_content or "" @timed_with_status( log_prefix="OpenAI LLM", @@ -50,29 +82,32 @@ def generate(self, messages: MessageList, **kwargs) -> str: start_time = time.perf_counter() logger.info(f"OpenAI LLM Request body: {request_body}") - response = self.client.chat.completions.create(**request_body) - - cost_time = time.perf_counter() - start_time - logger.info( - f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}" - ) - - if not response.choices: - logger.warning("OpenAI response 
has no choices") - return "" - - tool_calls = getattr(response.choices[0].message, "tool_calls", None) - if isinstance(tool_calls, list) and len(tool_calls) > 0: - return self.tool_call_parser(tool_calls) - response_content = response.choices[0].message.content - reasoning_content = getattr(response.choices[0].message, "reasoning_content", None) - if isinstance(reasoning_content, str) and reasoning_content: - reasoning_content = f"{reasoning_content}" - if self.config.remove_think_prefix: - return remove_thinking_tags(response_content) - if reasoning_content: - return reasoning_content + (response_content or "") - return response_content or "" + try: + response = self.client.chat.completions.create(**request_body) + cost_time = time.perf_counter() - start_time + logger.info( + f"Request body: {request_body}, Response from OpenAI: " + f"{response.model_dump_json()}, Cost time: {cost_time}" + ) + return self._parse_response(response) + except Exception as e: + if not self.use_backup_client: + raise + logger.warning( + f"Primary LLM request failed with {type(e).__name__}: {e}, " + f"falling back to backup client" + ) + backup_body = { + **request_body, + "model": self.config.backup_model_name_or_path or request_body["model"], + } + backup_response = self.backup_client.chat.completions.create(**backup_body) + cost_time = time.perf_counter() - start_time + logger.info( + f"Backup LLM request succeeded, Response: " + f"{backup_response.model_dump_json()}, Cost time: {cost_time}" + ) + return self._parse_response(backup_response) @timed_with_status( log_prefix="OpenAI LLM Stream", diff --git a/tests/configs/test_llm.py b/tests/configs/test_llm.py index 6562c9a95..f3d4549b5 100644 --- a/tests/configs/test_llm.py +++ b/tests/configs/test_llm.py @@ -56,6 +56,11 @@ def test_openai_llm_config(): "remove_think_prefix", "extra_body", "default_headers", + "backup_client", + "backup_api_key", + "backup_api_base", + "backup_model_name_or_path", + "backup_headers", ], ) From 
9336027e8ddbbb503786af9f9c5204f30b32c621 Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Tue, 17 Mar 2026 11:31:22 +0800 Subject: [PATCH 2/3] feat: change memreader for scheduler --- src/memos/api/handlers/component_init.py | 4 ++-- src/memos/mem_os/core.py | 2 +- .../general_modules/init_components_for_scheduler.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index aa2525878..fc8ce311f 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -233,7 +233,7 @@ def init_server() -> dict[str, Any]: searcher: Searcher = tree_mem.get_searcher( manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", moscube=False, - process_llm=mem_reader.llm, + process_llm=mem_reader.general_llm, ) logger.debug("Searcher created") @@ -260,7 +260,7 @@ def init_server() -> dict[str, Any]: mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config) mem_scheduler.initialize_modules( chat_llm=llm, - process_llm=mem_reader.llm, + process_llm=mem_reader.general_llm, db_engine=BaseDBManager.create_default_sqlite_engine(), mem_reader=mem_reader, redis_client=redis_client, diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 22cd0e9cb..54f8f01e0 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -132,7 +132,7 @@ def _initialize_mem_scheduler(self) -> GeneralScheduler: # Configure scheduler general_modules self._mem_scheduler.initialize_modules( chat_llm=self.chat_llm, - process_llm=self.mem_reader.llm, + process_llm=self.mem_reader.general_llm, db_engine=self.user_manager.engine, ) self._mem_scheduler.start() diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index 8777b9f2e..ec431c253 100644 --- 
a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -305,7 +305,7 @@ def init_components() -> dict[str, Any]: searcher: Searcher = tree_mem.get_searcher( manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", moscube=False, - process_llm=mem_reader.llm, + process_llm=mem_reader.general_llm, ) # Initialize feedback server feedback_server = SimpleMemFeedback( From 9c4448804ef1c2c438d53b393eb20adf5946cfc8 Mon Sep 17 00:00:00 2001 From: "chunyu li (fridayL)" Date: Tue, 17 Mar 2026 14:22:14 +0800 Subject: [PATCH 3/3] feat: change singlecube llm --- src/memos/multi_mem_cube/single_cube.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 6df410c19..6a91f436f 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -617,7 +617,7 @@ def add_before_search( # 3. Call LLM try: - raw = self.mem_reader.llm.generate([{"role": "user", "content": prompt}]) + raw = self.mem_reader.general_llm.generate([{"role": "user", "content": prompt}]) success, parsed_result = parse_keep_filter_response(raw) if not success: