diff --git a/openviking/session/memory/memory_type_registry.py b/openviking/session/memory/memory_type_registry.py index cdebdcfd9..51427663b 100644 --- a/openviking/session/memory/memory_type_registry.py +++ b/openviking/session/memory/memory_type_registry.py @@ -39,12 +39,12 @@ def __init__(self, load_schemas: bool = True): self._load_schemas() def _load_schemas(self) -> None: - """Load schemas from the resolved memory templates directory and custom directory.""" + """Load schemas from built-in templates, then custom/configured overrides.""" import os from openviking_cli.utils.config import get_openviking_config - memory_templates_dir = str(resolve_memory_templates_dir()) + memory_templates_dir = str(PromptManager._get_bundled_templates_dir() / "memory") config = get_openviking_config() custom_dir = config.memory.custom_templates_dir @@ -57,7 +57,6 @@ def _load_schemas(self) -> None: ) logger.info(f"Loaded {loaded} memory schemas from templates: {memory_templates_dir}") - # Load from custom directory (if configured) - use replace to allow overriding built-in templates if custom_dir: custom_dir_expanded = os.path.expanduser(custom_dir) if os.path.exists(custom_dir_expanded): @@ -65,6 +64,17 @@ def _load_schemas(self) -> None: logger.info( f"Loaded {custom_loaded} memory schemas from custom: {custom_dir_expanded}" ) + else: + memory_templates_dir = str(resolve_memory_templates_dir()) + if memory_templates_dir != str( + PromptManager._get_bundled_templates_dir() / "memory" + ) and os.path.exists(memory_templates_dir): + loaded = self.load_from_directory(memory_templates_dir, replace=True) + logger.info( + "Loaded %s memory schemas from configured prompt templates: %s", + loaded, + memory_templates_dir, + ) def register(self, memory_type: MemoryTypeSchema) -> None: """Register a memory type. Raises error if already exists.""" diff --git a/openviking/session/memory/session_extract_context_provider.py b/openviking/session/memory/session_extract_context_provider.py index e0e19d730..c296636a2 100644 --- a/openviking/session/memory/session_extract_context_provider.py +++ b/openviking/session/memory/session_extract_context_provider.py @@ -12,6 +12,7 @@ from openviking.core.namespace import to_agent_space, to_user_space from openviking.message.part import ToolPart +from openviking.prompts.manager import PromptManager from openviking.server.identity import RequestContext, ToolContext from openviking.session.memory.core import ExtractContextProvider from openviking.session.memory.dataclass import MemoryFileContent @@ -473,7 +474,7 @@ def get_memory_schemas(self, ctx: RequestContext) -> List[Any]: def get_schema_directories(self) -> List[str]: """返回需要加载的 schema 目录""" if self._schema_directories is None: - memory_templates_dir = str(resolve_memory_templates_dir()) + memory_templates_dir = str(PromptManager._get_bundled_templates_dir() / "memory") config = get_openviking_config() custom_dir = config.memory.custom_templates_dir self._schema_directories = [memory_templates_dir] @@ -481,6 +482,12 @@ def get_schema_directories(self) -> List[str]: custom_dir_expanded = os.path.expanduser(custom_dir) if os.path.exists(custom_dir_expanded): self._schema_directories.append(custom_dir_expanded) + else: + memory_templates_dir = str(resolve_memory_templates_dir()) + if memory_templates_dir != str( + PromptManager._get_bundled_templates_dir() / "memory" + ) and os.path.exists(memory_templates_dir): + self._schema_directories.append(memory_templates_dir) return self._schema_directories def _get_registry(self) -> MemoryTypeRegistry: diff --git a/tests/test_prompt_manager.py b/tests/test_prompt_manager.py index 1162c8ca7..b1ba8e1ab 100644 --- a/tests/test_prompt_manager.py +++ b/tests/test_prompt_manager.py @@ -167,6 +167,57 @@ def test_memory_type_registry_loads_schemas_from_prompt_manager_resolved_templat assert registry.get("custom_memory") is not None +def test_memory_type_registry_prefers_custom_memory_dir_over_prompt_manager_templates_root( + tmp_path, monkeypatch +): + resolved_templates_dir = tmp_path / "resolved-prompts" + resolved_memory_dir = resolved_templates_dir / "memory" + custom_memory_dir = tmp_path / "custom-memory" + resolved_memory_dir.mkdir(parents=True) + custom_memory_dir.mkdir(parents=True) + (resolved_memory_dir / "prompt_root.yaml").write_text( + json.dumps( + { + "memory_type": "prompt_root_memory", + "description": "schema from prompt manager root", + "directory": "viking://user/{{ user_space }}/memories/prompt-root", + "filename_template": "prompt-root.md", + "fields": [], + } + ), + encoding="utf-8", + ) + (custom_memory_dir / "custom.yaml").write_text( + json.dumps( + { + "memory_type": "custom_memory", + "description": "schema from custom memory dir", + "directory": "viking://user/{{ user_space }}/memories/custom", + "filename_template": "custom.md", + "fields": [], + } + ), + encoding="utf-8", + ) + + monkeypatch.setattr( + PromptManager, + "_resolve_templates_dir", + classmethod(lambda cls, templates_dir=None: resolved_templates_dir), + ) + monkeypatch.setattr( + "openviking_cli.utils.config.get_openviking_config", + lambda: SimpleNamespace( + memory=SimpleNamespace(custom_templates_dir=str(custom_memory_dir)) + ), + ) + + registry = MemoryTypeRegistry(load_schemas=True) + + assert registry.get("custom_memory") is not None + assert registry.get("prompt_root_memory") is None + + def test_context_provider_schema_directories_use_prompt_manager_resolved_templates_root( tmp_path, monkeypatch ): @@ -188,3 +239,37 @@ def test_context_provider_schema_directories_use_prompt_manager_resolved_templat provider = SessionExtractContextProvider(messages=[]) assert provider.get_schema_directories() == [str(expected_memory_dir)] + + +def test_context_provider_schema_directories_prefer_custom_memory_dir_over_prompt_manager_root( + tmp_path, monkeypatch +): + resolved_templates_dir = tmp_path / "resolved-prompts" + custom_memory_dir = tmp_path / "custom-memory" + + monkeypatch.setattr( + PromptManager, + "_resolve_templates_dir", + classmethod(lambda cls, templates_dir=None: resolved_templates_dir), + ) + monkeypatch.setattr( + "openviking.session.memory.session_extract_context_provider.get_openviking_config", + lambda: SimpleNamespace( + memory=SimpleNamespace( + custom_templates_dir=str(custom_memory_dir), + eager_prefetch=False, + ) + ), + ) + monkeypatch.setattr( + "os.path.exists", + lambda path: path == str(custom_memory_dir) + or path == str(PromptManager._get_bundled_templates_dir() / "memory"), + ) + + provider = SessionExtractContextProvider(messages=[]) + + assert provider.get_schema_directories() == [ + str(PromptManager._get_bundled_templates_dir() / "memory"), + str(custom_memory_dir), + ]