diff --git a/src/config.py b/src/config.py index 4473283..00744a1 100644 --- a/src/config.py +++ b/src/config.py @@ -12,7 +12,7 @@ import logging import os -from typing import Literal +from typing import Literal, cast import dotenv from pathlib import Path @@ -268,7 +268,17 @@ def __init__(self): ) -_MODES = ("local", "remote") +MODEL_MODES = ("local", "remote") +PIPELINE_MODES = ("lightweight", "accurate") +OCR_MODES = ("docling", "docparser") +AGENTIC_ARCHITECTURES = ("custom", "langchain") + + +def _read_mode(name: str, allowed_values: tuple[str, ...], default: str | None = None) -> str: + value = os.getenv(name, default) + if value not in allowed_values: + raise ValueError(f"Invalid {name}: {value}") + return value class Config: @@ -289,9 +299,9 @@ def __init__(self): if not self.brainpat_token: raise ValueError("BrainPAT token is not set") - self.models_mode = os.getenv("MODELS_MODE") - if self.models_mode not in _MODES: - raise ValueError(f"Invalid MODELS_MODE: {self.models_mode}") + self.models_mode = cast( + Literal["local", "remote"], _read_mode("MODELS_MODE", MODEL_MODES) + ) if self.models_mode == "local": self.azure = None @@ -317,14 +327,16 @@ def __init__(self): self.docparser_endpoint = os.getenv("DOCPARSER_ENDPOINT") self.docparser_token = os.getenv("DOCPARSER_TOKEN") self.app_host = os.getenv("APP_HOST") - self.pipeline_mode: Literal["lightweight", "accurate"] = os.getenv( - "PIPELINE_MODE" + self.pipeline_mode = cast( + Literal["lightweight", "accurate"], + _read_mode("PIPELINE_MODE", PIPELINE_MODES, "accurate"), ) - self.ocr_mode: Literal["docling", "docparser"] = os.getenv( - "OCR_MODE", "docling" + self.ocr_mode = cast( + Literal["docling", "docparser"], _read_mode("OCR_MODE", OCR_MODES, "docling") ) - self.agentic_architecture: Literal["custom", "langchain"] = os.getenv( - "AGENTIC_ARCHITECTURE", "custom" + self.agentic_architecture = cast( + Literal["custom", "langchain"], + _read_mode("AGENTIC_ARCHITECTURE", AGENTIC_ARCHITECTURES, "custom"), ) diff --git a/src/core/instances.py b/src/core/instances.py index a751a9c..4279493 100644 --- a/src/core/instances.py +++ b/src/core/instances.py @@ -3,17 +3,14 @@ from src.adapters.embeddings import EmbeddingsAdapter, VectorStoreAdapter from src.adapters.graph import GraphAdapter from src.adapters.llm import LLMAdapter -from src.config import config +from src.config import MODEL_MODES, config from src.lib.milvus.client import _milvus_client from src.lib.mongo.client import _mongo_client from src.lib.neo4j.client import _neo4j_client from src.lib.redis.client import _redis_client -_MODES = ("local", "remote") - - def _require_mode(): - if config.models_mode not in _MODES: + if config.models_mode not in MODEL_MODES: raise ValueError(f"Invalid models mode: {config.models_mode}") return config.models_mode == "local" diff --git a/src/core/pipeline/__init__.py b/src/core/pipeline/__init__.py new file mode 100644 index 0000000..f32c68a --- /dev/null +++ b/src/core/pipeline/__init__.py @@ -0,0 +1,15 @@ +from src.core.pipeline.mode_strategy import ( + AccuratePipelineModeStrategy, + LightweightPipelineModeStrategy, + PipelineModeStrategy, + PipelineModeStrategyFactory, + resolve_pipeline_mode_strategy, +) + +__all__ = [ + "PipelineModeStrategy", + "LightweightPipelineModeStrategy", + "AccuratePipelineModeStrategy", + "PipelineModeStrategyFactory", + "resolve_pipeline_mode_strategy", +] diff --git a/src/core/pipeline/mode_strategy.py b/src/core/pipeline/mode_strategy.py new file mode 100644 index 0000000..e8d6408 --- /dev/null +++ b/src/core/pipeline/mode_strategy.py @@ -0,0 +1,54 @@ +from abc import ABC, abstractmethod + + +class PipelineModeStrategy(ABC): + mode: str + + @abstractmethod + def should_extract_observations(self) -> bool: + raise NotImplementedError + + @abstractmethod + def scout_mode(self) -> str | None: + raise NotImplementedError + + +class LightweightPipelineModeStrategy(PipelineModeStrategy): + mode = "lightweight" + + def should_extract_observations(self) -> bool: + return False + + def scout_mode(self) -> str | None: + return "coarse" + + +class AccuratePipelineModeStrategy(PipelineModeStrategy): + mode = "accurate" + + def should_extract_observations(self) -> bool: + return True + + def scout_mode(self) -> str | None: + return None + + +class PipelineModeStrategyFactory: + def __init__(self): + self._strategies = { + "lightweight": LightweightPipelineModeStrategy(), + "accurate": AccuratePipelineModeStrategy(), + } + + def create(self, mode: str) -> PipelineModeStrategy: + strategy = self._strategies.get(mode) + if strategy is None: + raise ValueError(f"Invalid PIPELINE_MODE: {mode}") + return strategy + + +_pipeline_mode_strategy_factory = PipelineModeStrategyFactory() + + +def resolve_pipeline_mode_strategy(mode: str) -> PipelineModeStrategy: + return _pipeline_mode_strategy_factory.create(mode) diff --git a/src/core/saving/auto_kg.py b/src/core/saving/auto_kg.py index f7eef30..578bf9b 100644 --- a/src/core/saving/auto_kg.py +++ b/src/core/saving/auto_kg.py @@ -17,6 +17,7 @@ from src.core.agents.scout_agent import ScoutAgent from src.core.agents.architect_agent import ArchitectAgent from src.core.layers.graph_consolidation.graph_consolidation import consolidate_graph +from src.core.pipeline import resolve_pipeline_mode_strategy from src.core.saving.ingestion_manager import IngestionManager from src.services.input.agents import ( cache_adapter, @@ -81,7 +82,9 @@ def _enrich_kg_impl(input: str, targeting, brain_id: str, ingestion_session_id: ingestion_manager=ingestion_manager, ) - if config.pipeline_mode == "lightweight": + pipeline_strategy = resolve_pipeline_mode_strategy(config.pipeline_mode) + scout_mode = pipeline_strategy.scout_mode() + if scout_mode is not None: print("[DEBUG (enrich_kg_from_input)]: Lightweight pipeline mode selected") entities = scout_agent.run( @@ -89,7 +92,7 @@ def _enrich_kg_impl(input: str, targeting, brain_id: str, ingestion_session_id: targeting=targeting, brain_id=brain_id, ingestion_session_id=ingestion_session_id, - mode="coarse", + mode=scout_mode, ) print("[DEBUG (initial_scout_entities)]: ", entities.entities) architect_agent.run_tooler( @@ -99,12 +102,12 @@ def _enrich_kg_impl(input: str, targeting, brain_id: str, ingestion_session_id: brain_id=brain_id, timeout=20000, ingestion_session_id=ingestion_session_id, - mode="coarse", + mode=scout_mode, ) return - if config.pipeline_mode == "accurate": + if pipeline_strategy.should_extract_observations(): token_details = [] print(f"[DEBUG (ingestion_session_id)]: {ingestion_session_id}") diff --git a/src/services/api/controllers/retrieve.py b/src/services/api/controllers/retrieve.py index 8fd2e0d..e6831a7 100644 --- a/src/services/api/controllers/retrieve.py +++ b/src/services/api/controllers/retrieve.py @@ -28,11 +28,10 @@ RetrieveNeighborsRequestResponse, RetrievedNeighborNode, ) -from src.services.kg_agent.main import graph_adapter, kg_agent +from src.services.kg_agent.main import graph_adapter from src.services.data.main import data_adapter from src.services.kg_agent.main import embeddings_adapter, vector_store_adapter from src.utils.similarity.vectors import cosine_similarity -from src.utils.nlp.ner import _entity_extractor vector_search = VectorSearchFacade(vector_store_adapter) @@ -272,6 +271,8 @@ async def retrieve_neighbors_ai_mode( """ def _get_neighbors(): + from src.services.kg_agent.main import kg_agent + node = graph_adapter.get_by_identification_params( identification_params, brain_id=brain_id, @@ -394,6 +395,8 @@ async def get_context(request: GetContextRequestBody) -> GetContextResponse: GetContextResponse: Response containing the context information. """ + from src.utils.nlp.ner import _entity_extractor + embeddings_map = {} futures = [] diff --git a/src/workers/tasks/ingestion.py b/src/workers/tasks/ingestion.py index aa4e395..fa4685c 100644 --- a/src/workers/tasks/ingestion.py +++ b/src/workers/tasks/ingestion.py @@ -38,6 +38,7 @@ from src.constants.agents import ArchitectAgentRelationship from src.constants.prompts.misc import NODE_DESCRIPTION_PROMPT from src.core.plugins.prompts import prompt_registry +from src.core.pipeline import resolve_pipeline_mode_strategy from src.core.saving.auto_kg import enrich_kg_from_input from src.core.saving.ingestion_manager import IngestionManager from src.services.api.constants.requests import IngestionStructuredRequestBody @@ -144,10 +145,11 @@ def ingest_data(self, args: dict): "data", ) - if config.pipeline_mode == "lightweight": + pipeline_strategy = resolve_pipeline_mode_strategy(config.pipeline_mode) + if not pipeline_strategy.should_extract_observations(): print("[DEBUG (ingest_data)]: Lightweight pipeline mode selected") - if config.pipeline_mode == "accurate": + if pipeline_strategy.should_extract_observations(): # ================================================ # --------------- Observations ------------------- # ================================================ diff --git a/tests/test_architecture_refactors.py b/tests/test_architecture_refactors.py index e96ff10..4992c95 100644 --- a/tests/test_architecture_refactors.py +++ b/tests/test_architecture_refactors.py @@ -62,11 +62,20 @@ def get_by_uuids(self, *_args, **_kwargs): def get_neighbors(self, *_args, **_kwargs): return {} + def get_nodes_by_uuid(self, *_args, **_kwargs): + return [] + class _StubDataAdapter: def get_observations_list(self, *_args, **_kwargs): return [] + def search(self, *_args, **_kwargs): + return type("SearchResult", (), {"text_chunks": [], "observations": []})() + + def get_text_chunks_by_ids(self, *_args, **_kwargs): + return [], [] + _stub_input_agents = types.ModuleType("src.services.input.agents") _stub_input_agents.embeddings_adapter = _StubEmbeddingsAdapter() diff --git a/tests/test_config_pipeline_mode_refactors.py b/tests/test_config_pipeline_mode_refactors.py new file mode 100644 index 0000000..945ceed --- /dev/null +++ b/tests/test_config_pipeline_mode_refactors.py @@ -0,0 +1,117 @@ +import os +from pathlib import Path +import unittest +from unittest.mock import patch + + +ENV_DEFAULTS = { + "BRAINPAT_TOKEN": "test-token", + "MODELS_MODE": "local", + "EMBEDDINGS_LOCAL_MODEL": "local-model", + "EMBEDDINGS_SMALL_MODEL": "small-model", + "EMBEDDING_NODES_DIMENSION": "3", + "EMBEDDING_TRIPLETS_DIMENSION": "3", + "EMBEDDING_OBSERVATIONS_DIMENSION": "3", + "EMBEDDING_DATA_DIMENSION": "3", + "EMBEDDING_RELATIONSHIPS_DIMENSION": "3", + "REDIS_HOST": "localhost", + "REDIS_PORT": "6379", + "NEO4J_HOST": "localhost", + "NEO4J_PORT": "7687", + "NEO4J_USERNAME": "neo4j", + "NEO4J_PASSWORD": "password", + "MILVUS_HOST": "localhost", + "MILVUS_PORT": "19530", + "MONGO_CONNECTION_STRING": "mongodb://localhost:27017", + "CELERY_WORKER_CONCURRENCY": "1", + "OLLAMA_HOST": "localhost", + "OLLAMA_PORT": "11434", + "OLLAMA_LLM_SMALL_MODEL": "small", + "OLLAMA_LLM_LARGE_MODEL": "large", + "PIPELINE_MODE": "accurate", + "OCR_MODE": "docling", + "AGENTIC_ARCHITECTURE": "custom", +} +for key, value in ENV_DEFAULTS.items(): + os.environ.setdefault(key, value) + +from src.config import Config +from src.core.pipeline import ( + AccuratePipelineModeStrategy, + LightweightPipelineModeStrategy, + PipelineModeStrategyFactory, + resolve_pipeline_mode_strategy, +) + + +ROOT = Path(__file__).resolve().parent.parent + + +def read_source(relative_path: str) -> str: + return (ROOT / relative_path).read_text(encoding="utf-8") + + +class ConfigModeValidationTests(unittest.TestCase): + def test_pipeline_mode_defaults_to_accurate(self): + with patch.dict(os.environ, ENV_DEFAULTS, clear=False): + os.environ.pop("PIPELINE_MODE", None) + config = Config() + self.assertEqual(config.pipeline_mode, "accurate") + + def test_invalid_pipeline_mode_raises_value_error(self): + with patch.dict(os.environ, ENV_DEFAULTS, clear=False): + os.environ["PIPELINE_MODE"] = "invalid-mode" + with self.assertRaises(ValueError): + Config() + + def test_invalid_ocr_mode_raises_value_error(self): + with patch.dict(os.environ, ENV_DEFAULTS, clear=False): + os.environ["OCR_MODE"] = "invalid-ocr" + with self.assertRaises(ValueError): + Config() + + def test_invalid_agentic_architecture_raises_value_error(self): + with patch.dict(os.environ, ENV_DEFAULTS, clear=False): + os.environ["AGENTIC_ARCHITECTURE"] = "invalid-arch" + with self.assertRaises(ValueError): + Config() + + +class PipelineModeStrategyFactoryTests(unittest.TestCase): + def test_factory_builds_lightweight_strategy(self): + strategy = PipelineModeStrategyFactory().create("lightweight") + self.assertIsInstance(strategy, LightweightPipelineModeStrategy) + self.assertFalse(strategy.should_extract_observations()) + self.assertEqual(strategy.scout_mode(), "coarse") + + def test_factory_builds_accurate_strategy(self): + strategy = PipelineModeStrategyFactory().create("accurate") + self.assertIsInstance(strategy, AccuratePipelineModeStrategy) + self.assertTrue(strategy.should_extract_observations()) + self.assertIsNone(strategy.scout_mode()) + + def test_resolver_rejects_unknown_mode(self): + with self.assertRaises(ValueError): + resolve_pipeline_mode_strategy("broken") + + +class PipelineStrategyIntegrationTests(unittest.TestCase): + def test_ingestion_uses_pipeline_strategy_dispatch(self): + source = read_source("src/workers/tasks/ingestion.py") + self.assertIn( + "pipeline_strategy = resolve_pipeline_mode_strategy(config.pipeline_mode)", + source, + ) + self.assertIn("pipeline_strategy.should_extract_observations()", source) + + def test_auto_kg_uses_pipeline_strategy_dispatch(self): + source = read_source("src/core/saving/auto_kg.py") + self.assertIn( + "pipeline_strategy = resolve_pipeline_mode_strategy(config.pipeline_mode)", + source, + ) + self.assertIn("scout_mode = pipeline_strategy.scout_mode()", source) + + +if __name__ == "__main__": + unittest.main()