Skip to content

Commit 922e931

Browse files
committed
bugs fixed
1 parent e5696a3 commit 922e931

15 files changed

Lines changed: 278 additions & 138 deletions

contextdb/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
__version__ = "0.4.0"
22

3-
from contextdb.api.condb import ConDB, ConDBError, LLMNotConfiguredError, QueryResult, TreeNotFoundError
3+
from contextdb.api.condb import (
4+
ConDB,
5+
ConDBError,
6+
LLMNotConfiguredError,
7+
QueryResult,
8+
TreeNotFoundError,
9+
open, # noqa: A004
10+
)
411
from contextdb.api.context_tree import ContextTree
5-
from contextdb.api.condb import open # noqa: A004
612
from contextdb.core.storage import Entity, Node, StorageProtocol, TreeDB
713
from contextdb.llm import LLMClient, LLMProtocol
814
from contextdb.retriever import BeamRetriever, BlockRetriever, ManualRetriever, RetrievalResult

contextdb/config/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
"""Configuration loader for ContextDB."""
22

33
import os
4-
import yaml
54
from pathlib import Path
6-
from typing import Any
75

6+
import yaml
87
from dotenv import load_dotenv
98

109
CONFIG_DIR = Path(__file__).parent

contextdb/retriever/algorithm/block_cutter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def _generate_block_content(self, nodes: list[dict]) -> str:
367367
meta_lines.append(f" range: {page_start}-{page_end}")
368368

369369
node_metadata.append(meta_lines)
370-
metadata_chars += sum(len(l) for l in meta_lines)
370+
metadata_chars += sum(len(line) for line in meta_lines)
371371

372372
text = payload.get("text") or payload.get("content") or ""
373373
node_texts.append(text)

contextdb/retriever/algorithm/block_retriever.py

Lines changed: 140 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
]
5959

6060

61-
class BlockRetriever(BaseRetriever):
61+
class BlockRetriever(BlockRetrieverFilesystemSupport, BlockRetrieverPromptCacheSupport, BaseRetriever):
6262

6363
def __init__(
6464
self,
@@ -118,22 +118,14 @@ def __init__(
118118
self.max_tokens_per_block,
119119
self.min_tokens_per_block,
120120
)
121-
self._filesystem_support = BlockRetrieverFilesystemSupport(self)
122-
self._prompt_cache_support = BlockRetrieverPromptCacheSupport(self)
123121
self._plan_cache: dict[str, BlockTreePlan] = {}
124122
self._precomputed_tree_id: str = ""
125-
126-
def __getattr__(self, name: str):
127-
if name.startswith("__"):
128-
raise AttributeError(name)
129-
130-
for support_name in ("_filesystem_support", "_prompt_cache_support"):
131-
support = self.__dict__.get(support_name)
132-
if support is None:
133-
continue
134-
if hasattr(type(support), name):
135-
return getattr(support, name)
136-
raise AttributeError(f"{type(self).__name__!s} object has no attribute {name!r}")
123+
self._fs_node_cache: dict[tuple[str, str], dict[str, Any]] = {}
124+
self._fs_attrs_cache: dict[tuple[str, str], dict[str, Any]] = {}
125+
self._fs_path_cache: dict[tuple[str, str], str] = {}
126+
self._fs_is_dir_cache: dict[tuple[str, str], bool] = {}
127+
self._fs_children_cache: dict[tuple[str, str], list[dict[str, Any]]] = {}
128+
self._fs_block_render_cache: dict[tuple[str, ...], tuple[str, int]] = {}
137129

138130
def retrieve(
139131
self,
@@ -162,6 +154,7 @@ def _retrieve_fs(
162154
return self._empty_result()
163155

164156
if tree_id != self._precomputed_tree_id:
157+
self._clear_fs_lookup_cache()
165158
self.token_counter.clear_cache()
166159
self.token_counter.precompute_tree_tokens(self.storage, tree_id)
167160
self._precomputed_tree_id = tree_id
@@ -781,18 +774,18 @@ def _process_block(
781774
def _update_frontier(self, node_ids, tree_id, beam_size):
782775
next_frontier = []
783776
for node_id in node_ids:
784-
node = self.storage.get_node(tree_id, node_id)
785-
attrs = {}
786-
if node and node.attrs_json:
787-
try:
788-
attrs = json.loads(node.attrs_json)
789-
except json.JSONDecodeError:
790-
attrs = {}
791-
792-
frontier_path = attrs.get("rel_path", "") if self.mode == "filesystem" else (node.path if node else "")
777+
if self.mode == "filesystem":
778+
node = self._get_cached_fs_node_dict(tree_id, node_id)
779+
attrs = self._get_cached_fs_attrs(tree_id, node_id, node=node)
780+
frontier_path = attrs.get("rel_path", "")
781+
title = attrs.get("title", "")
782+
else:
783+
node = self.storage.get_node(tree_id, node_id)
784+
frontier_path = node.path if node else ""
785+
title = ""
793786
next_frontier.append({
794787
"node_id": node_id,
795-
"title": attrs.get("title", ""),
788+
"title": title,
796789
"path": frontier_path,
797790
})
798791
if beam_size and len(next_frontier) >= beam_size:
@@ -806,7 +799,13 @@ def _update_beams(self, node_ids, tree_id, beam_size):
806799
def _frontier_has_children(self, tree_id: str, frontier: list[dict[str, str]]) -> bool:
807800
for frontier_node in frontier:
808801
node_id = frontier_node.get("node_id", "")
809-
if node_id and self.storage.get_children(tree_id, node_id):
802+
if not node_id:
803+
continue
804+
if self.mode == "filesystem":
805+
if self._fs_node_has_children(tree_id, node_id):
806+
return True
807+
continue
808+
if self.storage.get_children(tree_id, node_id):
810809
return True
811810
return False
812811

@@ -829,14 +828,14 @@ def _override_done_if_dirs(self, result: BlockResult, tree_id: str, beams: list[
829828
return self._override_done_if_frontier_dirs(result, tree_id, beams)
830829

831830
def _is_fs_directory_id(self, tree_id: str, node_id: str) -> bool:
832-
node = self.storage.get_node(tree_id, node_id)
833-
if not node or not node.attrs_json:
834-
return False
835-
try:
836-
attrs = json.loads(node.attrs_json)
837-
except json.JSONDecodeError:
838-
return False
839-
return bool(attrs.get("is_dir", False))
831+
key = (tree_id, node_id)
832+
if key in self._fs_is_dir_cache:
833+
return self._fs_is_dir_cache[key]
834+
835+
attrs = self._get_cached_fs_attrs(tree_id, node_id)
836+
is_dir = bool(attrs.get("is_dir", False))
837+
self._fs_is_dir_cache[key] = is_dir
838+
return is_dir
840839

841840
# ---- allowed node filtering (dynamic, but content stays fixed) ----
842841

@@ -874,19 +873,120 @@ def _get_node_paths(self, tree_id: str, node_ids: list[str]) -> dict[str, str]:
874873

875874
cursor = self.storage.conn.cursor()
876875
path_map: dict[str, str] = {}
876+
missing: list[str] = []
877+
seen_missing: set[str] = set()
878+
879+
for node_id in node_ids:
880+
key = (tree_id, node_id)
881+
cached_path = self._fs_path_cache.get(key)
882+
if cached_path is not None:
883+
path_map[node_id] = cached_path
884+
elif node_id not in seen_missing:
885+
seen_missing.add(node_id)
886+
missing.append(node_id)
887+
888+
if not missing:
889+
return {node_id: path_map[node_id] for node_id in node_ids if node_id in path_map}
890+
877891
chunk_size = 500
878892

879-
for i in range(0, len(node_ids), chunk_size):
880-
chunk = node_ids[i:i + chunk_size]
893+
for i in range(0, len(missing), chunk_size):
894+
chunk = missing[i:i + chunk_size]
881895
placeholders = ",".join("?" for _ in chunk)
882896
cursor.execute(
883897
f"SELECT node_id, path FROM nodes WHERE tree_id = ? AND node_id IN ({placeholders})",
884898
(tree_id, *chunk),
885899
)
886900
for row in cursor.fetchall():
887-
path_map[row["node_id"]] = row["path"]
901+
node_id = row["node_id"]
902+
path = row["path"]
903+
self._fs_path_cache[(tree_id, node_id)] = path
904+
path_map[node_id] = path
905+
906+
return {node_id: path_map[node_id] for node_id in node_ids if node_id in path_map}
907+
908+
@staticmethod
909+
def _parse_fs_attrs(attrs_value: Any) -> dict[str, Any]:
910+
if isinstance(attrs_value, dict):
911+
return attrs_value
912+
if isinstance(attrs_value, str) and attrs_value:
913+
try:
914+
parsed = json.loads(attrs_value)
915+
except json.JSONDecodeError:
916+
return {}
917+
return parsed if isinstance(parsed, dict) else {}
918+
return {}
919+
920+
def _remember_fs_node(self, tree_id: str, node: dict[str, Any]) -> dict[str, Any]:
921+
node_id = node.get("node_id")
922+
if not node_id:
923+
return node
924+
925+
key = (tree_id, node_id)
926+
existing = self._fs_node_cache.get(key)
927+
if existing:
928+
merged = {**existing, **node}
929+
if "entity" not in node and "entity" in existing:
930+
merged["entity"] = existing["entity"]
931+
node = merged
932+
933+
self._fs_node_cache[key] = node
934+
path = node.get("path")
935+
if isinstance(path, str):
936+
self._fs_path_cache[key] = path
937+
938+
attrs = self._parse_fs_attrs(node.get("attrs") if "attrs" in node else node.get("attrs_json"))
939+
self._fs_attrs_cache[key] = attrs
940+
self._fs_is_dir_cache[key] = bool(attrs.get("is_dir", False))
941+
return node
942+
943+
def _get_cached_fs_node_dict(self, tree_id: str, node_id: str) -> dict[str, Any] | None:
944+
key = (tree_id, node_id)
945+
cached = self._fs_node_cache.get(key)
946+
if cached is not None:
947+
return cached
888948

889-
return path_map
949+
node = self.storage.get_node(tree_id, node_id)
950+
if not node:
951+
return None
952+
return self._remember_fs_node(tree_id, node.to_dict())
953+
954+
def _get_cached_fs_attrs(
955+
self,
956+
tree_id: str,
957+
node_id: str,
958+
node: dict[str, Any] | None = None,
959+
) -> dict[str, Any]:
960+
key = (tree_id, node_id)
961+
cached = self._fs_attrs_cache.get(key)
962+
if cached is not None:
963+
return cached
964+
965+
node = node or self._get_cached_fs_node_dict(tree_id, node_id)
966+
if not node:
967+
attrs: dict[str, Any] = {}
968+
else:
969+
attrs = self._parse_fs_attrs(node.get("attrs") if "attrs" in node else node.get("attrs_json"))
970+
self._remember_fs_node(tree_id, node)
971+
972+
self._fs_attrs_cache[key] = attrs
973+
self._fs_is_dir_cache[key] = bool(attrs.get("is_dir", False))
974+
return attrs
975+
976+
def _fs_node_has_children(self, tree_id: str, node_id: str) -> bool:
977+
key = (tree_id, node_id)
978+
cached = self._fs_children_cache.get(key)
979+
if cached is not None:
980+
return bool(cached)
981+
return bool(self._get_direct_children_nodes(tree_id, node_id))
982+
983+
def _clear_fs_lookup_cache(self) -> None:
984+
self._fs_node_cache.clear()
985+
self._fs_attrs_cache.clear()
986+
self._fs_path_cache.clear()
987+
self._fs_is_dir_cache.clear()
988+
self._fs_children_cache.clear()
989+
self._fs_block_render_cache.clear()
890990

891991
# ---- DB helpers ----
892992

@@ -920,12 +1020,14 @@ def _empty_result(self):
9201020

9211021
def clear_cache(self):
9221022
self._plan_cache.clear()
1023+
self._clear_fs_lookup_cache()
9231024

9241025
def clear_plan_cache(self, tree_id=None):
9251026
if tree_id:
9261027
self._plan_cache.pop(tree_id, None)
9271028
else:
9281029
self._plan_cache.clear()
1030+
self._clear_fs_lookup_cache()
9291031

9301032
def get_cache_stats(self):
9311033
return {"plan_cache_size": len(self._plan_cache)}

0 commit comments

Comments
 (0)