Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 110 additions & 10 deletions kaizen/backend/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import datetime
import logging
from abc import ABC, abstractmethod

from pydantic_settings import BaseSettings
from kaizen.schema.core import Namespace, Entity, RecordedEntity

from kaizen.schema.conflict_resolution import EntityUpdate
from kaizen.schema.core import Entity, Namespace, RecordedEntity
from kaizen.schema.exceptions import KaizenException
from kaizen.utils.utils import serialize_content

# NOTE(review): basicConfig() configures the *root* logger, and as a top-level
# statement it runs whenever this module is imported. Applications embedding
# this package may prefer to own logging configuration themselves — confirm
# this side effect is intended before changing it.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the backend code in this file.
logger = logging.getLogger("entities-db")
Expand Down Expand Up @@ -36,15 +41,6 @@ def search_namespaces(self, limit: int = 10) -> list[Namespace]:
def delete_namespace(self, namespace_id: str):
    """Delete the namespace identified by ``namespace_id``.

    The base class supplies no behaviour; concrete backends implement it.
    """

@abstractmethod
def update_entities(
    self,
    namespace_id: str,
    entities: list[Entity],
    enable_conflict_resolution: bool = True,
) -> list[EntityUpdate]:
    """Add or update ``entities`` in a namespace and report what changed.

    Abstract: the base class supplies no behaviour here.
    """

@abstractmethod
def search_entities(
self, namespace_id: str, query: str | None = None, filters: dict | None = None, limit: int = 10
Expand All @@ -54,3 +50,107 @@ def search_entities(
@abstractmethod
def delete_entity_by_id(self, namespace_id: str, entity_id: str):
    """Delete the entity ``entity_id`` from namespace ``namespace_id``.

    Abstract: concrete backends provide the implementation.
    """

# ── update_entities template method ──────────────────────────────

@abstractmethod
def _validate_namespace(self, namespace_id: str) -> None:
    """Check that ``namespace_id`` exists.

    Implementations raise ``NamespaceNotFoundException`` when it does not;
    ``update_entities`` calls this hook before doing any other work.
    """

@abstractmethod
def _add_entity(self, namespace_id: str, entity_type: str, content_str: str, timestamp: int, metadata: dict) -> str:
    """Store a new entity and return the ID assigned to it, as a string.

    ``content_str`` is the already-serialized content and ``timestamp`` is
    the integer epoch-seconds value computed once per ``update_entities`` call.
    """

@abstractmethod
def _update_entity(self, namespace_id: str, entity_id: str, entity_type: str, content_str: str, timestamp: int, metadata: dict) -> None:
    """Overwrite the stored entity ``entity_id`` in place with the new values.

    Called by ``update_entities`` for every ``"UPDATE"`` event.
    """

@abstractmethod
def _delete_entity(self, namespace_id: str, entity_id: str) -> None:
    """Remove the entity ``entity_id`` from the namespace.

    Called by ``update_entities`` for every ``"DELETE"`` event.
    """

def _post_update(self, namespace_id: str) -> None:
    """Optional hook run by ``update_entities`` once every mutation is applied.

    The default implementation does nothing; backends override it when they
    need a final step.
    """
    return None

def update_entities(
self,
namespace_id: str,
entities: list[Entity],
enable_conflict_resolution: bool = True,
) -> list[EntityUpdate]:
from kaizen.llm.conflict_resolution.conflict_resolution import resolve_conflicts

self._validate_namespace(namespace_id)
if not entities:
logger.warning("No entities to update.")
return []

entity_type = entities[0].type
if not all(entity.type == entity_type for entity in entities):
raise KaizenException("All entities must have the same type.")

now = datetime.datetime.now(datetime.UTC)
timestamp = int(now.timestamp())

entities_with_temporary_ids: list[RecordedEntity] = []
for i, entity in enumerate(entities):
entity_data = entity.model_dump()
if entity_data.get("metadata") is None:
entity_data["metadata"] = {}
entities_with_temporary_ids.append(
RecordedEntity(
**entity_data,
created_at=datetime.datetime.now(datetime.UTC),
id=f"Unprocessed_Entity_{i}",
)
)

if enable_conflict_resolution:
old_entities: list[RecordedEntity] = []
for entity in entities:
query_str = serialize_content(entity.content)
old_entities.extend(
self.search_entities(
namespace_id=namespace_id,
query=query_str,
filters={"type": entity_type},
limit=10,
)
)

updates = resolve_conflicts(old_entities, entities_with_temporary_ids)
for update in updates:
content_str = serialize_content(update.content)
metadata = update.metadata or {}
match update.event:
case "ADD":
update.id = self._add_entity(namespace_id, entity_type, content_str, timestamp, metadata)
case "UPDATE":
self._update_entity(namespace_id, update.id, entity_type, content_str, timestamp, metadata)
case "DELETE":
self._delete_entity(namespace_id, update.id)
case "NONE":
pass
else:
updates = []
for entity in entities:
content_str = serialize_content(entity.content)
metadata = entity.metadata or {}
entity_id = self._add_entity(namespace_id, entity_type, content_str, timestamp, metadata)
updates.append(
EntityUpdate(
id=entity_id,
type=entity_type,
content=entity.content,
event="ADD",
metadata=metadata,
)
)

self._post_update(namespace_id)
return updates
144 changes: 52 additions & 92 deletions kaizen/backend/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from kaizen.backend.base import BaseEntityBackend
from kaizen.config.filesystem import FilesystemSettings, filesystem_settings
from kaizen.llm.conflict_resolution.conflict_resolution import resolve_conflicts
from kaizen.schema.conflict_resolution import EntityUpdate
from kaizen.schema.core import Entity, Namespace, RecordedEntity
from kaizen.schema.exceptions import (
Expand Down Expand Up @@ -40,6 +39,8 @@ def __init__(self, config: FilesystemSettings | None = None):
self.data_dir = Path(self.config.data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self._lock = Lock()
# Holds the loaded namespace data during update_entities so hooks can access it.
self._active_data: FilesystemNamespace | None = None

def _namespace_file(self, namespace_id: str) -> Path:
"""Get the path to a namespace's JSON file."""
Expand All @@ -65,6 +66,11 @@ def details(self) -> dict:
"""Return details about the backend."""
return {"data_dir": str(self.data_dir)}

def _validate_namespace(self, namespace_id: str) -> None:
    """Ensure the namespace's backing file exists.

    Raises:
        NamespaceNotFoundException: If the path returned by
            ``_namespace_file`` does not exist.
    """
    if self._namespace_file(namespace_id).exists():
        return
    raise NamespaceNotFoundException(f"Namespace `{namespace_id}` not found")

def create_namespace(self, namespace_id: str | None = None) -> Namespace:
"""Create a new namespace for entities to exist in."""
namespace_id = namespace_id or "ns_" + str(uuid.uuid4()).replace("-", "_")
Expand Down Expand Up @@ -124,105 +130,56 @@ def delete_namespace(self, namespace_id: str):
return # Already deleted, no-op
file_path.unlink()

# ── update_entities hooks ────────────────────────────────────────

def _add_entity(self, namespace_id: str, entity_type: str, content_str: str, timestamp: int, metadata: dict) -> str:
assert self._active_data is not None
entity_id = str(self._active_data.next_id)
self._active_data.next_id += 1
created_at_iso = datetime.datetime.fromtimestamp(timestamp, datetime.UTC).isoformat()
self._active_data.entities.append(
{
"id": entity_id,
"type": entity_type,
"content": content_str,
"created_at": created_at_iso,
"metadata": metadata,
}
)
return entity_id

def _update_entity(self, namespace_id: str, entity_id: str, entity_type: str, content_str: str, timestamp: int, metadata: dict) -> None:
assert self._active_data is not None
created_at_iso = datetime.datetime.fromtimestamp(timestamp, datetime.UTC).isoformat()
for ent in self._active_data.entities:
if ent["id"] == entity_id:
ent["content"] = content_str
ent["created_at"] = created_at_iso
ent["metadata"] = metadata
break

def _delete_entity(self, namespace_id: str, entity_id: str) -> None:
    """Drop every loaded record whose ``"id"`` equals ``entity_id``.

    The entity list is replaced by a new list rather than edited in place.
    """
    store = self._active_data
    assert store is not None
    survivors = []
    for record in store.entities:
        if record["id"] != entity_id:
            survivors.append(record)
    store.entities = survivors

def _post_update(self, namespace_id: str) -> None:
    """Save the mutated namespace data and leave "update in progress" mode.

    Refreshes ``num_entities`` from the entity list, saves the data through
    ``_save_namespace_data`` and clears ``_active_data``.

    ``_active_data`` is cleared even when saving fails. Otherwise a failed
    save would leave it set, and ``search_entities`` would keep answering
    from that stale, unsaved snapshot.

    Raises:
        Exception: Whatever ``_save_namespace_data`` raises is propagated.
    """
    data = self._active_data
    assert data is not None
    try:
        data.num_entities = len(data.entities)
        self._save_namespace_data(namespace_id, data)
    finally:
        self._active_data = None

def update_entities(
    self,
    namespace_id: str,
    entities: list[Entity],
    enable_conflict_resolution: bool = True,
) -> list[EntityUpdate]:
    """Run the base ``update_entities`` template under the backend lock.

    The namespace data is loaded once into ``_active_data`` so the
    ``_add_entity`` / ``_update_entity`` / ``_delete_entity`` hooks and
    ``search_entities`` all work on the same in-memory copy;
    ``_post_update`` saves it at the end.

    Args:
        namespace_id: Namespace to load and update.
        entities: Entities to add or reconcile; passed through to the base
            implementation.
        enable_conflict_resolution: Passed through to the base implementation.

    Returns:
        The list of updates produced by the base implementation.

    Raises:
        Exception: Anything raised by ``_load_namespace_data`` or by the
            base implementation is propagated.
    """
    with self._lock:
        self._active_data = self._load_namespace_data(namespace_id)
        try:
            return super().update_entities(namespace_id, entities, enable_conflict_resolution)
        finally:
            # The base template skips _post_update when it returns early
            # (empty input) or raises. Clear the snapshot here so a later
            # search_entities call never reads stale data outside the lock.
            self._active_data = None
# ── search ───────────────────────────────────────────────────────

def _search_entities_internal(
self,
Expand Down Expand Up @@ -291,6 +248,9 @@ def search_entities(
limit: int = 10,
) -> list[RecordedEntity]:
"""Search for entities in a namespace."""
# If called during update_entities (inside the lock), use the active data
if self._active_data is not None:
return self._search_entities_internal(self._active_data, query, filters, limit)
with self._lock:
data = self._load_namespace_data(namespace_id)
return self._search_entities_internal(data, query, filters, limit)
Expand Down
Loading
Loading