diff --git a/src/adapters/graph.py b/src/adapters/graph.py index 4221f0f..b3c1e8e 100644 --- a/src/adapters/graph.py +++ b/src/adapters/graph.py @@ -500,6 +500,30 @@ def _average_embeddings(self, vectors: list[Vector]) -> list[float]: for i in range(dimension) ] + def _node_description(self, node: Optional[Node]) -> Optional[str]: + if node is None: + return None + return node.description + + def _vectors_with_descriptions( + self, + neighbors: List[Tuple[Predicate, Node]], + vectors_by_uuid: Dict[str, Vector], + ) -> list[dict]: + vectors_with_desc = [] + for _, node in neighbors: + vector = vectors_by_uuid.get(node.uuid) + if vector is None: + continue + vectors_with_desc.append( + { + "embeddings": vector.embeddings, + "metadata": vector.metadata, + "description": self._node_description(node), + } + ) + return vectors_with_desc + def get_2nd_degree_hops( self, from_uuids: List[str], @@ -588,20 +612,7 @@ def flatten_pred(p): filtered_fd_by_origin: Dict[str, List[Tuple[Predicate, Node]]] = {} for node_uuid, fd_list in all_fd_nodes.items(): - fd_nodes_by_uuid = {fd[1].uuid: fd[1] for fd in fd_list} - fd_vs_with_desc = [ - { - "embeddings": fd_vs_by_uuid[fd[1].uuid].embeddings, - "metadata": fd_vs_by_uuid[fd[1].uuid].metadata, - "description": ( - fd_nodes_by_uuid.get(fd[1].uuid, {}).description - if fd_nodes_by_uuid.get(fd[1].uuid) - else None - ), - } - for fd in fd_list - if fd[1].uuid in fd_vs_by_uuid - ] + fd_vs_with_desc = self._vectors_with_descriptions(fd_list, fd_vs_by_uuid) from_node = nodes_by_uuid[node_uuid] filtered_uuids = self._reduce_neighbor_vectors( vectors_with_desc=fd_vs_with_desc, @@ -646,21 +657,7 @@ def flatten_pred(p): for fd_pred, fd_node in filtered_fd_by_origin.get(from_uuid, []): sd_list = all_sd_nodes.get(fd_node.uuid, []) - sd_nodes_by_uuid = {sd[1].uuid: sd[1] for sd in sd_list} - - sd_vs_with_desc = [ - { - "embeddings": sd_vs_by_uuid[sd[1].uuid].embeddings, - "metadata": sd_vs_by_uuid[sd[1].uuid].metadata, - "description": ( - sd_nodes_by_uuid.get(sd[1].uuid, {}).description - if sd_nodes_by_uuid.get(sd[1].uuid) - else None - ), - } - for sd in sd_list - if sd[1].uuid in sd_vs_by_uuid - ] + sd_vs_with_desc = self._vectors_with_descriptions(sd_list, sd_vs_by_uuid) reduced_uuids = self._reduce_neighbor_vectors( vectors_with_desc=sd_vs_with_desc, diff --git a/src/adapters/graph_operation_result_serializer.py b/src/adapters/graph_operation_result_serializer.py index 73bf06c..459fa07 100644 --- a/src/adapters/graph_operation_result_serializer.py +++ b/src/adapters/graph_operation_result_serializer.py @@ -1,7 +1,10 @@ import json +import logging from abc import ABC, abstractmethod from typing import Any, Iterable +logger = logging.getLogger(__name__) + class GraphOperationResultSerializer(ABC): @abstractmethod @@ -33,26 +36,39 @@ class Neo4jResultSerializer(GraphOperationResultSerializer): def can_handle(self, result: Any) -> bool: return hasattr(result, "records") + def _serialize_record(self, record: Any) -> dict | str: + record_data = getattr(record, "data", None) + if callable(record_data): + return record_data() + try: + return dict(record) + except (TypeError, ValueError) as exc: + logger.warning("Failed to map neo4j record, using string fallback: %s", exc) + return str(record) + + def _extract_keys(self, result: Any) -> list | None: + keys_accessor = getattr(result, "keys", None) + if callable(keys_accessor): + try: + return list(keys_accessor()) + except (TypeError, ValueError) as exc: + logger.warning( + "Failed to extract neo4j result keys from callable: %s", exc + ) + return None + if isinstance(keys_accessor, Iterable) and not isinstance( + keys_accessor, (str, bytes) + ): + return list(keys_accessor) + return None + def serialize(self, result: Any) -> str: records = result.records or [] limited_records = records[:20] serialized_records = [] for record in limited_records: - if hasattr(record, "data"): - serialized_records.append(record.data()) - continue - try: - serialized_records.append(dict(record)) - except Exception: - serialized_records.append(str(record)) - keys = None - try: - keys = list(result.keys()) - except Exception: - try: - keys = list(result.keys) - except Exception: - keys = None + serialized_records.append(self._serialize_record(record)) + keys = self._extract_keys(result) payload = { "records": serialized_records, "truncated": len(records) > 20, diff --git a/src/core/ingestion/__init__.py b/src/core/ingestion/__init__.py new file mode 100644 index 0000000..6d36e3b --- /dev/null +++ b/src/core/ingestion/__init__.py @@ -0,0 +1,6 @@ +from src.core.ingestion.text_strategy import ( + IngestionTextStrategyFactory, + extract_ingestion_text, +) + +__all__ = ["IngestionTextStrategyFactory", "extract_ingestion_text"] diff --git a/src/core/ingestion/text_strategy.py b/src/core/ingestion/text_strategy.py new file mode 100644 index 0000000..fbeb49f --- /dev/null +++ b/src/core/ingestion/text_strategy.py @@ -0,0 +1,58 @@ +import json +from abc import ABC, abstractmethod + +from src.constants.tasks.ingestion import IngestionTaskJsonArgs, IngestionTaskTextArgs + + +class IngestionTextStrategy(ABC): + @abstractmethod + def can_handle(self, payload_data: object) -> bool: + raise NotImplementedError + + @abstractmethod + def extract(self, payload_data: object) -> str: + raise NotImplementedError + + +class RawTextIngestionStrategy(IngestionTextStrategy): + def can_handle(self, payload_data: object) -> bool: + return isinstance(payload_data, IngestionTaskTextArgs) + + def extract(self, payload_data: object) -> str: + if not isinstance(payload_data, IngestionTaskTextArgs): + raise ValueError("Invalid payload type for text ingestion strategy") + return payload_data.text_data + + +class JsonIngestionStrategy(IngestionTextStrategy): + def can_handle(self, payload_data: object) -> bool: + return isinstance(payload_data, IngestionTaskJsonArgs) + + def extract(self, payload_data: object) -> str: + if not isinstance(payload_data, IngestionTaskJsonArgs): + raise ValueError("Invalid payload type for json ingestion strategy") + return json.dumps(payload_data.json_data) + + +class IngestionTextStrategyFactory: + def __init__(self, strategies: list[IngestionTextStrategy] | None = None): + self._strategies = strategies or [ + RawTextIngestionStrategy(), + JsonIngestionStrategy(), + ] + + def create(self, payload_data: object) -> IngestionTextStrategy: + for strategy in self._strategies: + if strategy.can_handle(payload_data): + return strategy + raise ValueError( + f"Unsupported ingestion payload type: {type(payload_data).__name__}" + ) + + +_default_ingestion_text_strategy_factory = IngestionTextStrategyFactory() + + +def extract_ingestion_text(payload_data: object) -> str: + strategy = _default_ingestion_text_strategy_factory.create(payload_data) + return strategy.extract(payload_data) diff --git a/src/workers/tasks/ingestion.py b/src/workers/tasks/ingestion.py index aa4e395..0dda6fb 100644 --- a/src/workers/tasks/ingestion.py +++ b/src/workers/tasks/ingestion.py @@ -38,6 +38,7 @@ from src.constants.agents import ArchitectAgentRelationship from src.constants.prompts.misc import NODE_DESCRIPTION_PROMPT from src.core.plugins.prompts import prompt_registry +from src.core.ingestion.text_strategy import extract_ingestion_text from src.core.saving.auto_kg import enrich_kg_from_input from src.core.saving.ingestion_manager import IngestionManager from src.services.api.constants.requests import IngestionStructuredRequestBody @@ -52,7 +53,6 @@ from src.workers.app import ingestion_app from src.constants.tasks.ingestion import ( IngestionTaskArgs, - IngestionTaskDataType, IngestionTaskTextArgs, ) from src.services.kg_agent.main import cache_adapter @@ -121,13 +121,11 @@ def ingest_data(self, args: dict): # ================================================ # --------------- Data Saving -------------------- # ================================================ + payload_text = extract_ingestion_text(payload.data) + text_chunk = data_adapter.save_text_chunk( TextChunk( - text=( - payload.data.text_data - if payload.data.data_type == IngestionTaskDataType.TEXT.value - else json.dumps(payload.data.json_data) - ), + text=payload_text, metadata=payload.meta_keys, brain_version=BRAIN_VERSION, ), @@ -152,11 +150,7 @@ def ingest_data(self, args: dict): # --------------- Observations ------------------- # ================================================ observations = observations_agent.observe( - text=( - payload.data.text_data - if payload.data.data_type == IngestionTaskDataType.TEXT.value - else json.dumps(payload.data.json_data) - ), + text=payload_text, observate_for=payload.observate_for, ) @@ -174,7 +168,7 @@ def ingest_data(self, args: dict): # ================================================ # ------------ Triplet Extraction ---------------- # ================================================ - enrich_kg_from_input(payload.data.text_data, brain_id=payload.brain_id) + enrich_kg_from_input(payload_text, brain_id=payload.brain_id) cache_adapter.set( key=f"task:{self.request.id}", diff --git a/tests/test_architecture_refactors.py b/tests/test_architecture_refactors.py index e96ff10..a7972e2 100644 --- a/tests/test_architecture_refactors.py +++ b/tests/test_architecture_refactors.py @@ -62,11 +62,20 @@ def get_by_uuids(self, *_args, **_kwargs): def get_neighbors(self, *_args, **_kwargs): return {} + def get_nodes_by_uuid(self, *_args, **_kwargs): + return [] + class _StubDataAdapter: def get_observations_list(self, *_args, **_kwargs): return [] + def search(self, *_args, **_kwargs): + return type("SearchResult", (), {"text_chunks": [], "observations": []})() + + def get_text_chunks_by_ids(self, *_args, **_kwargs): + return ([], []) + _stub_input_agents = types.ModuleType("src.services.input.agents") _stub_input_agents.embeddings_adapter = _StubEmbeddingsAdapter() @@ -76,12 +85,21 @@ def get_observations_list(self, *_args, **_kwargs): _stub_kg_main.graph_adapter = _StubGraphAdapter() _stub_kg_main.vector_store_adapter = _StubVectorStoreAdapter() _stub_kg_main.embeddings_adapter = _StubEmbeddingsAdapter() +_stub_kg_main.kg_agent = types.SimpleNamespace( + retrieve_neighbors=lambda *_args, **_kwargs: [] +) sys.modules.setdefault("src.services.kg_agent.main", _stub_kg_main) _stub_data_main = types.ModuleType("src.services.data.main") _stub_data_main.data_adapter = _StubDataAdapter() sys.modules.setdefault("src.services.data.main", _stub_data_main) +_stub_ner = types.ModuleType("src.utils.nlp.ner") +_stub_ner._entity_extractor = types.SimpleNamespace( + extract_entities=lambda *_args, **_kwargs: [] +) +sys.modules.setdefault("src.utils.nlp.ner", _stub_ner) + from src.services.api.controllers.entities import get_entity_status from src.services.api.controllers.retrieve import retrieve_data from src.utils.vector_search import VectorSearchFacade diff --git a/tests/test_reliability_refactors.py b/tests/test_reliability_refactors.py new file mode 100644 index 0000000..e688c3b --- /dev/null +++ b/tests/test_reliability_refactors.py @@ -0,0 +1,154 @@ +import json +import os +import unittest +from pathlib import Path + +ENV_DEFAULTS = { + "BRAINPAT_TOKEN": "test-token", + "MODELS_MODE": "local", + "EMBEDDINGS_LOCAL_MODEL": "local-model", + "EMBEDDINGS_SMALL_MODEL": "small-model", + "EMBEDDING_NODES_DIMENSION": "3", + "EMBEDDING_TRIPLETS_DIMENSION": "3", + "EMBEDDING_OBSERVATIONS_DIMENSION": "3", + "EMBEDDING_DATA_DIMENSION": "3", + "EMBEDDING_RELATIONSHIPS_DIMENSION": "3", + "REDIS_HOST": "localhost", + "REDIS_PORT": "6379", + "NEO4J_HOST": "localhost", + "NEO4J_PORT": "7687", + "NEO4J_USERNAME": "neo4j", + "NEO4J_PASSWORD": "password", + "MILVUS_HOST": "localhost", + "MILVUS_PORT": "19530", + "MONGO_CONNECTION_STRING": "mongodb://localhost:27017", + "CELERY_WORKER_CONCURRENCY": "1", + "OLLAMA_HOST": "localhost", + "OLLAMA_PORT": "11434", + "OLLAMA_LLM_SMALL_MODEL": "small", + "OLLAMA_LLM_LARGE_MODEL": "large", +} +for key, value in ENV_DEFAULTS.items(): + os.environ.setdefault(key, value) + +from src.adapters.graph import GraphAdapter +from src.adapters.graph_operation_result_serializer import Neo4jResultSerializer +from src.constants.kg import Node, Predicate +from src.constants.tasks.ingestion import IngestionTaskJsonArgs, IngestionTaskTextArgs +from src.core.ingestion.text_strategy import ( + IngestionTextStrategyFactory, + extract_ingestion_text, +) + + +class FakeVector: + def __init__(self, embeddings, metadata): + self.embeddings = embeddings + self.metadata = metadata + + +class IngestionTextStrategyTests(unittest.TestCase): + def test_extract_ingestion_text_uses_raw_text_payload(self): + payload = IngestionTaskTextArgs(text_data="plain text") + self.assertEqual(extract_ingestion_text(payload), "plain text") + + def test_extract_ingestion_text_uses_json_payload(self): + payload = IngestionTaskJsonArgs(json_data={"k": "v"}) + self.assertEqual(extract_ingestion_text(payload), json.dumps({"k": "v"})) + + def test_factory_raises_for_unknown_payload_type(self): + factory = IngestionTextStrategyFactory() + with self.assertRaises(ValueError): + factory.create(object()) + + +class IngestionTaskSourceRefactorTests(unittest.TestCase): + def test_ingest_data_uses_shared_payload_text_for_enrichment(self): + source = ( + Path(__file__).resolve().parent.parent / "src/workers/tasks/ingestion.py" + ).read_text(encoding="utf-8") + self.assertIn("payload_text = extract_ingestion_text(payload.data)", source) + self.assertIn( + "enrich_kg_from_input(payload_text, brain_id=payload.brain_id)", source + ) + self.assertNotIn( + "enrich_kg_from_input(payload.data.text_data, brain_id=payload.brain_id)", + source, + ) + + +class GraphAdapterNeighborAssemblyTests(unittest.TestCase): + def test_vectors_with_descriptions_uses_node_descriptions(self): + adapter = GraphAdapter() + predicate = Predicate(name="RELATES_TO", description="desc") + neighbors = [ + ( + predicate, + Node(uuid="node-a", labels=["Person"], name="Alice", description="A"), + ), + ( + predicate, + Node(uuid="node-b", labels=["Person"], name="Bob", description=None), + ), + ( + predicate, + Node( + uuid="node-c", + labels=["Person"], + name="Charlie", + description="C", + ), + ), + ] + vectors = { + "node-a": FakeVector([1.0, 0.0], {"uuid": "node-a"}), + "node-b": FakeVector([0.0, 1.0], {"uuid": "node-b"}), + } + result = adapter._vectors_with_descriptions(neighbors, vectors) + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["description"], "A") + self.assertIsNone(result[1]["description"]) + + def test_node_description_returns_none_for_missing_node(self): + adapter = GraphAdapter() + self.assertIsNone(adapter._node_description(None)) + + +class Neo4jResultSerializerTests(unittest.TestCase): + def test_serializer_falls_back_to_string_for_non_mapping_records(self): + class NonMappingRecord: + def __str__(self): + return "non-mapping-record" + + class FakeResult: + def __init__(self): + self.records = [NonMappingRecord()] + + def keys(self): + return ["value"] + + serializer = Neo4jResultSerializer() + with self.assertLogs( + "src.adapters.graph_operation_result_serializer", level="WARNING" + ) as captured_logs: + payload = json.loads(serializer.serialize(FakeResult())) + self.assertEqual(payload["records"], ["non-mapping-record"]) + self.assertEqual(payload["keys"], ["value"]) + self.assertTrue( + any("Failed to map neo4j record" in message for message in captured_logs.output) + ) + + def test_serializer_supports_iterable_keys_attribute(self): + class FakeResult: + def __init__(self): + self.records = [] + self.keys = ("a", "b") + + serializer = Neo4jResultSerializer() + payload = json.loads(serializer.serialize(FakeResult())) + self.assertEqual(payload["keys"], ["a", "b"]) + self.assertFalse(payload["truncated"]) + + +if __name__ == "__main__": + unittest.main()