Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 26 additions & 29 deletions src/adapters/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,30 @@ def _average_embeddings(self, vectors: list[Vector]) -> list[float]:
for i in range(dimension)
]

def _node_description(self, node: Optional[Node]) -> Optional[str]:
if node is None:
return None
return node.description

def _vectors_with_descriptions(
self,
neighbors: List[Tuple[Predicate, Node]],
vectors_by_uuid: Dict[str, Vector],
) -> list[dict]:
vectors_with_desc = []
for _, node in neighbors:
vector = vectors_by_uuid.get(node.uuid)
if vector is None:
continue
vectors_with_desc.append(
{
"embeddings": vector.embeddings,
"metadata": vector.metadata,
"description": self._node_description(node),
}
)
return vectors_with_desc

def get_2nd_degree_hops(
self,
from_uuids: List[str],
Expand Down Expand Up @@ -588,20 +612,7 @@ def flatten_pred(p):
filtered_fd_by_origin: Dict[str, List[Tuple[Predicate, Node]]] = {}

for node_uuid, fd_list in all_fd_nodes.items():
fd_nodes_by_uuid = {fd[1].uuid: fd[1] for fd in fd_list}
fd_vs_with_desc = [
{
"embeddings": fd_vs_by_uuid[fd[1].uuid].embeddings,
"metadata": fd_vs_by_uuid[fd[1].uuid].metadata,
"description": (
fd_nodes_by_uuid.get(fd[1].uuid, {}).description
if fd_nodes_by_uuid.get(fd[1].uuid)
else None
),
}
for fd in fd_list
if fd[1].uuid in fd_vs_by_uuid
]
fd_vs_with_desc = self._vectors_with_descriptions(fd_list, fd_vs_by_uuid)
from_node = nodes_by_uuid[node_uuid]
filtered_uuids = self._reduce_neighbor_vectors(
vectors_with_desc=fd_vs_with_desc,
Expand Down Expand Up @@ -646,21 +657,7 @@ def flatten_pred(p):

for fd_pred, fd_node in filtered_fd_by_origin.get(from_uuid, []):
sd_list = all_sd_nodes.get(fd_node.uuid, [])
sd_nodes_by_uuid = {sd[1].uuid: sd[1] for sd in sd_list}

sd_vs_with_desc = [
{
"embeddings": sd_vs_by_uuid[sd[1].uuid].embeddings,
"metadata": sd_vs_by_uuid[sd[1].uuid].metadata,
"description": (
sd_nodes_by_uuid.get(sd[1].uuid, {}).description
if sd_nodes_by_uuid.get(sd[1].uuid)
else None
),
}
for sd in sd_list
if sd[1].uuid in sd_vs_by_uuid
]
sd_vs_with_desc = self._vectors_with_descriptions(sd_list, sd_vs_by_uuid)

reduced_uuids = self._reduce_neighbor_vectors(
vectors_with_desc=sd_vs_with_desc,
Expand Down
46 changes: 31 additions & 15 deletions src/adapters/graph_operation_result_serializer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Iterable

logger = logging.getLogger(__name__)


class GraphOperationResultSerializer(ABC):
@abstractmethod
Expand Down Expand Up @@ -33,26 +36,39 @@ class Neo4jResultSerializer(GraphOperationResultSerializer):
def can_handle(self, result: Any) -> bool:
    """Report whether *result* looks like a neo4j result (exposes `.records`)."""
    try:
        result.records
    except AttributeError:
        return False
    return True

def _serialize_record(self, record: Any) -> dict | str:
record_data = getattr(record, "data", None)
if callable(record_data):
return record_data()
try:
return dict(record)
except (TypeError, ValueError) as exc:
logger.warning("Failed to map neo4j record, using string fallback: %s", exc)
return str(record)

def _extract_keys(self, result: Any) -> list | None:
keys_accessor = getattr(result, "keys", None)
if callable(keys_accessor):
try:
return list(keys_accessor())
except (TypeError, ValueError) as exc:
logger.warning(
"Failed to extract neo4j result keys from callable: %s", exc
)
return None
if isinstance(keys_accessor, Iterable) and not isinstance(
keys_accessor, (str, bytes)
):
return list(keys_accessor)
return None

def serialize(self, result: Any) -> str:
records = result.records or []
limited_records = records[:20]
serialized_records = []
for record in limited_records:
if hasattr(record, "data"):
serialized_records.append(record.data())
continue
try:
serialized_records.append(dict(record))
except Exception:
serialized_records.append(str(record))
keys = None
try:
keys = list(result.keys())
except Exception:
try:
keys = list(result.keys)
except Exception:
keys = None
serialized_records.append(self._serialize_record(record))
keys = self._extract_keys(result)
payload = {
"records": serialized_records,
"truncated": len(records) > 20,
Expand Down
6 changes: 6 additions & 0 deletions src/core/ingestion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from src.core.ingestion.text_strategy import (
IngestionTextStrategyFactory,
extract_ingestion_text,
)

__all__ = ["IngestionTextStrategyFactory", "extract_ingestion_text"]
58 changes: 58 additions & 0 deletions src/core/ingestion/text_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import json
from abc import ABC, abstractmethod

from src.constants.tasks.ingestion import IngestionTaskJsonArgs, IngestionTaskTextArgs


class IngestionTextStrategy(ABC):
    """Interface for turning one kind of ingestion payload into plain text."""

    @abstractmethod
    def can_handle(self, payload_data: object) -> bool:
        """Return True when this strategy knows how to extract *payload_data*."""
        raise NotImplementedError

    @abstractmethod
    def extract(self, payload_data: object) -> str:
        """Return the text content extracted from *payload_data*."""
        raise NotImplementedError


class RawTextIngestionStrategy(IngestionTextStrategy):
    """Handles ingestion payloads that already carry raw text."""

    def can_handle(self, payload_data: object) -> bool:
        return isinstance(payload_data, IngestionTaskTextArgs)

    def extract(self, payload_data: object) -> str:
        # Defensive re-check so a mismatched payload fails loudly even when
        # extract() is called without consulting can_handle() first.
        if isinstance(payload_data, IngestionTaskTextArgs):
            return payload_data.text_data
        raise ValueError("Invalid payload type for text ingestion strategy")


class JsonIngestionStrategy(IngestionTextStrategy):
    """Handles structured JSON payloads by dumping them to a JSON string."""

    def can_handle(self, payload_data: object) -> bool:
        return isinstance(payload_data, IngestionTaskJsonArgs)

    def extract(self, payload_data: object) -> str:
        # Defensive re-check so a mismatched payload fails loudly even when
        # extract() is called without consulting can_handle() first.
        if isinstance(payload_data, IngestionTaskJsonArgs):
            return json.dumps(payload_data.json_data)
        raise ValueError("Invalid payload type for json ingestion strategy")


class IngestionTextStrategyFactory:
    """Picks the first registered strategy able to handle a given payload."""

    def __init__(self, strategies: list[IngestionTextStrategy] | None = None):
        # Falsy input (None or empty list) falls back to the built-in
        # strategies; `or` short-circuits, so defaults are not constructed
        # when a non-empty list is supplied.
        self._strategies = strategies or [
            RawTextIngestionStrategy(),
            JsonIngestionStrategy(),
        ]

    def create(self, payload_data: object) -> IngestionTextStrategy:
        """Return a matching strategy or raise ValueError for unknown payloads."""
        match = next(
            (s for s in self._strategies if s.can_handle(payload_data)), None
        )
        if match is None:
            raise ValueError(
                f"Unsupported ingestion payload type: {type(payload_data).__name__}"
            )
        return match


_default_ingestion_text_strategy_factory = IngestionTextStrategyFactory()


def extract_ingestion_text(payload_data: object) -> str:
    """Extract ingestion text from *payload_data* via the module-default factory.

    Raises ValueError when no registered strategy supports the payload type.
    """
    chosen = _default_ingestion_text_strategy_factory.create(payload_data)
    return chosen.extract(payload_data)
18 changes: 6 additions & 12 deletions src/workers/tasks/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from src.constants.agents import ArchitectAgentRelationship
from src.constants.prompts.misc import NODE_DESCRIPTION_PROMPT
from src.core.plugins.prompts import prompt_registry
from src.core.ingestion.text_strategy import extract_ingestion_text
from src.core.saving.auto_kg import enrich_kg_from_input
from src.core.saving.ingestion_manager import IngestionManager
from src.services.api.constants.requests import IngestionStructuredRequestBody
Expand All @@ -52,7 +53,6 @@
from src.workers.app import ingestion_app
from src.constants.tasks.ingestion import (
IngestionTaskArgs,
IngestionTaskDataType,
IngestionTaskTextArgs,
)
from src.services.kg_agent.main import cache_adapter
Expand Down Expand Up @@ -121,13 +121,11 @@ def ingest_data(self, args: dict):
# ================================================
# --------------- Data Saving --------------------
# ================================================
payload_text = extract_ingestion_text(payload.data)

text_chunk = data_adapter.save_text_chunk(
TextChunk(
text=(
payload.data.text_data
if payload.data.data_type == IngestionTaskDataType.TEXT.value
else json.dumps(payload.data.json_data)
),
text=payload_text,
metadata=payload.meta_keys,
brain_version=BRAIN_VERSION,
),
Expand All @@ -152,11 +150,7 @@ def ingest_data(self, args: dict):
# --------------- Observations -------------------
# ================================================
observations = observations_agent.observe(
text=(
payload.data.text_data
if payload.data.data_type == IngestionTaskDataType.TEXT.value
else json.dumps(payload.data.json_data)
),
text=payload_text,
observate_for=payload.observate_for,
)

Expand All @@ -174,7 +168,7 @@ def ingest_data(self, args: dict):
# ================================================
# ------------ Triplet Extraction ----------------
# ================================================
enrich_kg_from_input(payload.data.text_data, brain_id=payload.brain_id)
enrich_kg_from_input(payload_text, brain_id=payload.brain_id)

cache_adapter.set(
key=f"task:{self.request.id}",
Expand Down
18 changes: 18 additions & 0 deletions tests/test_architecture_refactors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,20 @@ def get_by_uuids(self, *_args, **_kwargs):
def get_neighbors(self, *_args, **_kwargs):
    # Test stub: the stub graph contains no neighbor relationships.
    return {}

def get_nodes_by_uuid(self, *_args, **_kwargs):
    # Test stub: resolves to no nodes regardless of which uuids are asked for.
    return []


class _StubDataAdapter:
def get_observations_list(self, *_args, **_kwargs):
return []

def search(self, *_args, **_kwargs):
return type("SearchResult", (), {"text_chunks": [], "observations": []})()

def get_text_chunks_by_ids(self, *_args, **_kwargs):
return ([], [])


_stub_input_agents = types.ModuleType("src.services.input.agents")
_stub_input_agents.embeddings_adapter = _StubEmbeddingsAdapter()
Expand All @@ -76,12 +85,21 @@ def get_observations_list(self, *_args, **_kwargs):
_stub_kg_main.graph_adapter = _StubGraphAdapter()
_stub_kg_main.vector_store_adapter = _StubVectorStoreAdapter()
_stub_kg_main.embeddings_adapter = _StubEmbeddingsAdapter()
_stub_kg_main.kg_agent = types.SimpleNamespace(
retrieve_neighbors=lambda *_args, **_kwargs: []
)
sys.modules.setdefault("src.services.kg_agent.main", _stub_kg_main)

_stub_data_main = types.ModuleType("src.services.data.main")
_stub_data_main.data_adapter = _StubDataAdapter()
sys.modules.setdefault("src.services.data.main", _stub_data_main)

_stub_ner = types.ModuleType("src.utils.nlp.ner")
_stub_ner._entity_extractor = types.SimpleNamespace(
extract_entities=lambda *_args, **_kwargs: []
)
sys.modules.setdefault("src.utils.nlp.ner", _stub_ner)

from src.services.api.controllers.entities import get_entity_status
from src.services.api.controllers.retrieve import retrieve_data
from src.utils.vector_search import VectorSearchFacade
Expand Down
Loading