Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import logging
import os
from typing import Literal
from typing import Literal, cast
import dotenv
from pathlib import Path

Expand Down Expand Up @@ -268,7 +268,17 @@ def __init__(self):
)


# Accepted values for the mode-style environment variables. Each tuple is the
# whitelist that the corresponding raw value is checked against.
_MODES = ("local", "remote")  # compared against MODELS_MODE
MODEL_MODES = ("local", "remote")  # MODELS_MODE
PIPELINE_MODES = ("lightweight", "accurate")  # PIPELINE_MODE
OCR_MODES = ("docling", "docparser")  # OCR_MODE
AGENTIC_ARCHITECTURES = ("custom", "langchain")  # AGENTIC_ARCHITECTURE


def _read_mode(name: str, allowed_values: tuple[str, ...], default: str | None = None) -> str:
value = os.getenv(name, default)
if value not in allowed_values:
raise ValueError(f"Invalid {name}: {value}")
return value


class Config:
Expand All @@ -289,9 +299,9 @@ def __init__(self):
if not self.brainpat_token:
raise ValueError("BrainPAT token is not set")

self.models_mode = os.getenv("MODELS_MODE")
if self.models_mode not in _MODES:
raise ValueError(f"Invalid MODELS_MODE: {self.models_mode}")
self.models_mode = cast(
Literal["local", "remote"], _read_mode("MODELS_MODE", MODEL_MODES)
)

if self.models_mode == "local":
self.azure = None
Expand All @@ -317,14 +327,16 @@ def __init__(self):
self.docparser_endpoint = os.getenv("DOCPARSER_ENDPOINT")
self.docparser_token = os.getenv("DOCPARSER_TOKEN")
self.app_host = os.getenv("APP_HOST")
self.pipeline_mode: Literal["lightweight", "accurate"] = os.getenv(
"PIPELINE_MODE"
self.pipeline_mode = cast(
Literal["lightweight", "accurate"],
_read_mode("PIPELINE_MODE", PIPELINE_MODES, "accurate"),
)
self.ocr_mode: Literal["docling", "docparser"] = os.getenv(
"OCR_MODE", "docling"
self.ocr_mode = cast(
Literal["docling", "docparser"], _read_mode("OCR_MODE", OCR_MODES, "docling")
)
self.agentic_architecture: Literal["custom", "langchain"] = os.getenv(
"AGENTIC_ARCHITECTURE", "custom"
self.agentic_architecture = cast(
Literal["custom", "langchain"],
_read_mode("AGENTIC_ARCHITECTURE", AGENTIC_ARCHITECTURES, "custom"),
)


Expand Down
7 changes: 2 additions & 5 deletions src/core/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
from src.adapters.embeddings import EmbeddingsAdapter, VectorStoreAdapter
from src.adapters.graph import GraphAdapter
from src.adapters.llm import LLMAdapter
from src.config import config
from src.config import MODEL_MODES, config
from src.lib.milvus.client import _milvus_client
from src.lib.mongo.client import _mongo_client
from src.lib.neo4j.client import _neo4j_client
from src.lib.redis.client import _redis_client

_MODES = ("local", "remote")


def _require_mode() -> bool:
    """Validate the configured models mode and report whether it is local.

    Returns:
        ``True`` when ``config.models_mode`` equals ``"local"``, ``False`` for
        any other accepted mode.

    Raises:
        ValueError: If ``config.models_mode`` is not one of ``MODEL_MODES``.
    """
    if config.models_mode not in MODEL_MODES:
        raise ValueError(f"Invalid models mode: {config.models_mode}")
    return config.models_mode == "local"

Expand Down
15 changes: 15 additions & 0 deletions src/core/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from src.core.pipeline.mode_strategy import (
AccuratePipelineModeStrategy,
LightweightPipelineModeStrategy,
PipelineModeStrategy,
PipelineModeStrategyFactory,
resolve_pipeline_mode_strategy,
)

# Names re-exported from mode_strategy as the package's public interface.
__all__ = [
"PipelineModeStrategy",
"LightweightPipelineModeStrategy",
"AccuratePipelineModeStrategy",
"PipelineModeStrategyFactory",
"resolve_pipeline_mode_strategy",
]
54 changes: 54 additions & 0 deletions src/core/pipeline/mode_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from abc import ABC, abstractmethod


class PipelineModeStrategy(ABC):
mode: str

@abstractmethod
def should_extract_observations(self) -> bool:
raise NotImplementedError

@abstractmethod
def scout_mode(self) -> str | None:
raise NotImplementedError


class LightweightPipelineModeStrategy(PipelineModeStrategy):
    """Strategy for the ``lightweight`` mode: coarse scouting, no observations."""

    mode = "lightweight"

    def scout_mode(self) -> str | None:
        """Return ``"coarse"``, the scout mode used by lightweight runs."""
        return "coarse"

    def should_extract_observations(self) -> bool:
        """Return ``False``: observations are not extracted in this mode."""
        return False


class AccuratePipelineModeStrategy(PipelineModeStrategy):
    """Strategy for the ``accurate`` mode: observations on, no scout mode."""

    mode = "accurate"

    def scout_mode(self) -> str | None:
        """Return ``None``: this mode defines no scout mode."""
        return None

    def should_extract_observations(self) -> bool:
        """Return ``True``: observations are extracted in this mode."""
        return True


class PipelineModeStrategyFactory:
    """Look up the strategy object registered for a pipeline mode name."""

    def __init__(self):
        # One instance per mode is built up front; create() hands back these
        # same objects on every call.
        lightweight = LightweightPipelineModeStrategy()
        accurate = AccuratePipelineModeStrategy()
        self._strategies = {"lightweight": lightweight, "accurate": accurate}

    def create(self, mode: str) -> PipelineModeStrategy:
        """Return the strategy registered under ``mode``.

        Raises:
            ValueError: If no strategy is registered for ``mode``.
        """
        if mode in self._strategies:
            return self._strategies[mode]
        raise ValueError(f"Invalid PIPELINE_MODE: {mode}")


# Single factory instance shared by every call to the resolver below.
_pipeline_mode_strategy_factory = PipelineModeStrategyFactory()


def resolve_pipeline_mode_strategy(mode: str) -> PipelineModeStrategy:
    """Return the strategy for ``mode`` using the module's shared factory.

    Raises:
        ValueError: Propagated from the factory when ``mode`` is unknown.
    """
    strategy = _pipeline_mode_strategy_factory.create(mode)
    return strategy
11 changes: 7 additions & 4 deletions src/core/saving/auto_kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from src.core.agents.scout_agent import ScoutAgent
from src.core.agents.architect_agent import ArchitectAgent
from src.core.layers.graph_consolidation.graph_consolidation import consolidate_graph
from src.core.pipeline import resolve_pipeline_mode_strategy
from src.core.saving.ingestion_manager import IngestionManager
from src.services.input.agents import (
cache_adapter,
Expand Down Expand Up @@ -81,15 +82,17 @@ def _enrich_kg_impl(input: str, targeting, brain_id: str, ingestion_session_id:
ingestion_manager=ingestion_manager,
)

if config.pipeline_mode == "lightweight":
pipeline_strategy = resolve_pipeline_mode_strategy(config.pipeline_mode)
scout_mode = pipeline_strategy.scout_mode()
if scout_mode is not None:
print("[DEBUG (enrich_kg_from_input)]: Lightweight pipeline mode selected")

entities = scout_agent.run(
input,
targeting=targeting,
brain_id=brain_id,
ingestion_session_id=ingestion_session_id,
mode="coarse",
mode=scout_mode,
)
print("[DEBUG (initial_scout_entities)]: ", entities.entities)
architect_agent.run_tooler(
Expand All @@ -99,12 +102,12 @@ def _enrich_kg_impl(input: str, targeting, brain_id: str, ingestion_session_id:
brain_id=brain_id,
timeout=20000,
ingestion_session_id=ingestion_session_id,
mode="coarse",
mode=scout_mode,
)

return

if config.pipeline_mode == "accurate":
if pipeline_strategy.should_extract_observations():
token_details = []
print(f"[DEBUG (ingestion_session_id)]: {ingestion_session_id}")

Expand Down
7 changes: 5 additions & 2 deletions src/services/api/controllers/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@
RetrieveNeighborsRequestResponse,
RetrievedNeighborNode,
)
from src.services.kg_agent.main import graph_adapter, kg_agent
from src.services.kg_agent.main import graph_adapter
from src.services.data.main import data_adapter
from src.services.kg_agent.main import embeddings_adapter, vector_store_adapter
from src.utils.similarity.vectors import cosine_similarity
from src.utils.nlp.ner import _entity_extractor

vector_search = VectorSearchFacade(vector_store_adapter)

Expand Down Expand Up @@ -272,6 +271,8 @@ async def retrieve_neighbors_ai_mode(
"""

def _get_neighbors():
from src.services.kg_agent.main import kg_agent

node = graph_adapter.get_by_identification_params(
identification_params,
brain_id=brain_id,
Expand Down Expand Up @@ -394,6 +395,8 @@ async def get_context(request: GetContextRequestBody) -> GetContextResponse:
GetContextResponse: Response containing the context information.
"""

from src.utils.nlp.ner import _entity_extractor

embeddings_map = {}
futures = []

Expand Down
6 changes: 4 additions & 2 deletions src/workers/tasks/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from src.constants.agents import ArchitectAgentRelationship
from src.constants.prompts.misc import NODE_DESCRIPTION_PROMPT
from src.core.plugins.prompts import prompt_registry
from src.core.pipeline import resolve_pipeline_mode_strategy
from src.core.saving.auto_kg import enrich_kg_from_input
from src.core.saving.ingestion_manager import IngestionManager
from src.services.api.constants.requests import IngestionStructuredRequestBody
Expand Down Expand Up @@ -144,10 +145,11 @@ def ingest_data(self, args: dict):
"data",
)

if config.pipeline_mode == "lightweight":
pipeline_strategy = resolve_pipeline_mode_strategy(config.pipeline_mode)
if not pipeline_strategy.should_extract_observations():
print("[DEBUG (ingest_data)]: Lightweight pipeline mode selected")

if config.pipeline_mode == "accurate":
if pipeline_strategy.should_extract_observations():
# ================================================
# --------------- Observations -------------------
# ================================================
Expand Down
9 changes: 9 additions & 0 deletions tests/test_architecture_refactors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,20 @@ def get_by_uuids(self, *_args, **_kwargs):
def get_neighbors(self, *_args, **_kwargs):
return {}

def get_nodes_by_uuid(self, *_args, **_kwargs):
return []


class _StubDataAdapter:
    """Stub data adapter whose methods ignore their arguments and return empty results."""

    def get_observations_list(self, *_args, **_kwargs):
        """Return a new empty list."""
        return list()

    def search(self, *_args, **_kwargs):
        """Return an object exposing empty ``text_chunks`` and ``observations``."""
        attributes = {"text_chunks": [], "observations": []}
        result_cls = type("SearchResult", (), attributes)
        return result_cls()

    def get_text_chunks_by_ids(self, *_args, **_kwargs):
        """Return a pair of new empty lists."""
        first, second = [], []
        return first, second


_stub_input_agents = types.ModuleType("src.services.input.agents")
_stub_input_agents.embeddings_adapter = _StubEmbeddingsAdapter()
Expand Down
117 changes: 117 additions & 0 deletions tests/test_config_pipeline_mode_refactors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os
from pathlib import Path
import unittest
from unittest.mock import patch


# Environment values the tests in this module use when constructing Config().
# They are installed with setdefault, so variables that are already present in
# the process environment keep their existing values. This runs before the
# `from src.config import Config` line that follows -- presumably because that
# import already needs a valid environment (TODO confirm).
ENV_DEFAULTS = {
"BRAINPAT_TOKEN": "test-token",
"MODELS_MODE": "local",
"EMBEDDINGS_LOCAL_MODEL": "local-model",
"EMBEDDINGS_SMALL_MODEL": "small-model",
"EMBEDDING_NODES_DIMENSION": "3",
"EMBEDDING_TRIPLETS_DIMENSION": "3",
"EMBEDDING_OBSERVATIONS_DIMENSION": "3",
"EMBEDDING_DATA_DIMENSION": "3",
"EMBEDDING_RELATIONSHIPS_DIMENSION": "3",
"REDIS_HOST": "localhost",
"REDIS_PORT": "6379",
"NEO4J_HOST": "localhost",
"NEO4J_PORT": "7687",
"NEO4J_USERNAME": "neo4j",
"NEO4J_PASSWORD": "password",
"MILVUS_HOST": "localhost",
"MILVUS_PORT": "19530",
"MONGO_CONNECTION_STRING": "mongodb://localhost:27017",
"CELERY_WORKER_CONCURRENCY": "1",
"OLLAMA_HOST": "localhost",
"OLLAMA_PORT": "11434",
"OLLAMA_LLM_SMALL_MODEL": "small",
"OLLAMA_LLM_LARGE_MODEL": "large",
"PIPELINE_MODE": "accurate",
"OCR_MODE": "docling",
"AGENTIC_ARCHITECTURE": "custom",
}
# Seed only the variables that are missing from the environment.
for key, value in ENV_DEFAULTS.items():
os.environ.setdefault(key, value)

from src.config import Config
from src.core.pipeline import (
AccuratePipelineModeStrategy,
LightweightPipelineModeStrategy,
PipelineModeStrategyFactory,
resolve_pipeline_mode_strategy,
)


# Directory two levels above this file; read_source() resolves paths against it.
ROOT = Path(__file__).resolve().parent.parent


def read_source(relative_path: str) -> str:
    """Return the UTF-8 text of the file at ``relative_path`` under ``ROOT``."""
    with open(ROOT / relative_path, encoding="utf-8") as handle:
        return handle.read()


class ConfigModeValidationTests(unittest.TestCase):
    """Checks how Config treats its mode environment variables."""

    def _assert_config_rejects(self, variable: str, bad_value: str) -> None:
        """Expect ``Config()`` to raise ValueError when ``variable`` is ``bad_value``."""
        overrides = {**ENV_DEFAULTS, variable: bad_value}
        with patch.dict(os.environ, overrides, clear=False):
            self.assertRaises(ValueError, Config)

    def test_pipeline_mode_defaults_to_accurate(self):
        with patch.dict(os.environ, ENV_DEFAULTS, clear=False):
            # patch.dict restores the environment on exit, so removing the
            # variable here only affects this test.
            del os.environ["PIPELINE_MODE"]
            pipeline_mode = Config().pipeline_mode
        self.assertEqual(pipeline_mode, "accurate")

    def test_invalid_pipeline_mode_raises_value_error(self):
        self._assert_config_rejects("PIPELINE_MODE", "invalid-mode")

    def test_invalid_ocr_mode_raises_value_error(self):
        self._assert_config_rejects("OCR_MODE", "invalid-ocr")

    def test_invalid_agentic_architecture_raises_value_error(self):
        self._assert_config_rejects("AGENTIC_ARCHITECTURE", "invalid-arch")


class PipelineModeStrategyFactoryTests(unittest.TestCase):
    """Exercises the strategy factory and the module-level resolver."""

    def test_factory_builds_lightweight_strategy(self):
        factory = PipelineModeStrategyFactory()
        built = factory.create("lightweight")
        self.assertIsInstance(built, LightweightPipelineModeStrategy)
        self.assertFalse(built.should_extract_observations())
        self.assertEqual(built.scout_mode(), "coarse")

    def test_factory_builds_accurate_strategy(self):
        factory = PipelineModeStrategyFactory()
        built = factory.create("accurate")
        self.assertIsInstance(built, AccuratePipelineModeStrategy)
        self.assertTrue(built.should_extract_observations())
        self.assertIsNone(built.scout_mode())

    def test_resolver_rejects_unknown_mode(self):
        self.assertRaises(ValueError, resolve_pipeline_mode_strategy, "broken")


class PipelineStrategyIntegrationTests(unittest.TestCase):
    """Source-text checks that call sites dispatch through the pipeline strategy."""

    def _assert_source_contains(self, relative_path: str, *snippets: str) -> None:
        """Assert that every snippet occurs in the text of ``relative_path``."""
        source = read_source(relative_path)
        for snippet in snippets:
            self.assertIn(snippet, source)

    def test_ingestion_uses_pipeline_strategy_dispatch(self):
        self._assert_source_contains(
            "src/workers/tasks/ingestion.py",
            "pipeline_strategy = resolve_pipeline_mode_strategy(config.pipeline_mode)",
            "pipeline_strategy.should_extract_observations()",
        )

    def test_auto_kg_uses_pipeline_strategy_dispatch(self):
        self._assert_source_contains(
            "src/core/saving/auto_kg.py",
            "pipeline_strategy = resolve_pipeline_mode_strategy(config.pipeline_mode)",
            "scout_mode = pipeline_strategy.scout_mode()",
        )


if __name__ == "__main__":
unittest.main()