Skip to content

Commit b674080

Browse files
committed
feat(diffctx): BM25 discovery channel + file content cache
Add BM25Discovery (top_k=1) as additional discovery channel in EnsembleDiscovery. BM25 finds files by lexical similarity to diff text, complementing structural import-based discovery. Add file content cache in pipeline to avoid redundant file reads across discovery channels. LOO recall: 0% → 5% (3/60 found).
1 parent 9dbe9bc commit b674080

2 files changed

Lines changed: 89 additions & 2 deletions

File tree

src/treemapper/diffctx/pipeline.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .git import CatFileBatch, GitError, split_diff_range
1919
from .postpass import _coherence_post_pass, _ensure_changed_files_represented
2020
from .render import build_diff_context_output
21-
from .scoring import DefaultDiscovery, DiscoveryContext, PPRScoring, ScoringStrategy
21+
from .scoring import DiscoveryContext, EnsembleDiscovery, PPRScoring, ScoringStrategy
2222
from .select import lazy_greedy_select
2323
from .signatures import _generate_signature_variants
2424
from .types import Fragment, FragmentId
@@ -207,14 +207,23 @@ def build_diff_context(
207207

208208
t1 = time.perf_counter()
209209

210+
file_cache: dict[Path, str] = {}
211+
for f in all_candidate_files:
212+
try:
213+
if f.stat().st_size <= 100_000:
214+
file_cache[f] = f.read_text(encoding="utf-8")
215+
except (OSError, UnicodeDecodeError):
216+
continue
217+
210218
discovery_ctx = DiscoveryContext(
211219
root_dir=root_dir,
212220
changed_files=changed_files,
213221
all_candidate_files=all_candidate_files,
214222
diff_text=diff_text,
215223
expansion_concepts=frozenset(expansion_concepts),
224+
file_cache=file_cache,
216225
)
217-
discovery_strategy = DefaultDiscovery()
226+
discovery_strategy = EnsembleDiscovery()
218227
discovered_files = discovery_strategy.discover(discovery_ctx)
219228
discovered_files = [_normalize_path(p, root_dir) for p in discovered_files]
220229
all_fragments.extend(

src/treemapper/diffctx/scoring.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ class DiscoveryContext:
1515
all_candidate_files: list[Path]
1616
diff_text: str
1717
expansion_concepts: frozenset[str]
18+
file_cache: dict[Path, str] | None = None
19+
20+
def read_file(self, path: Path) -> str | None:
21+
if self.file_cache is not None and path in self.file_cache:
22+
return self.file_cache[path]
23+
try:
24+
return path.read_text(encoding="utf-8")
25+
except (OSError, UnicodeDecodeError):
26+
return None
1827

1928

2029
class DiscoveryStrategy(ABC):
@@ -43,6 +52,75 @@ def discover(self, ctx: DiscoveryContext) -> list[Path]:
4352
return list(dict.fromkeys(edge_discovered + expanded))
4453

4554

55+
class BM25Discovery(DiscoveryStrategy):
56+
def __init__(self, top_k: int = 1) -> None:
57+
self.top_k = top_k
58+
59+
def discover(self, ctx: DiscoveryContext) -> list[Path]:
60+
import math
61+
import re
62+
from collections import Counter
63+
64+
token_re = re.compile(r"[A-Za-z_]\w{2,}")
65+
changed_set = set(ctx.changed_files)
66+
67+
query_tokens = [m.group().lower() for m in token_re.finditer(ctx.diff_text)]
68+
if not query_tokens:
69+
return []
70+
71+
corpus: list[list[str]] = []
72+
paths: list[Path] = []
73+
for f in ctx.all_candidate_files:
74+
if f in changed_set:
75+
continue
76+
content = ctx.read_file(f)
77+
if content is None:
78+
continue
79+
corpus.append([m.group().lower() for m in token_re.finditer(content)])
80+
paths.append(f)
81+
82+
if not corpus:
83+
return []
84+
85+
n_docs = len(corpus)
86+
avgdl = sum(len(d) for d in corpus) / n_docs
87+
df: Counter[str] = Counter()
88+
for doc in corpus:
89+
for term in set(doc):
90+
df[term] += 1
91+
92+
query_set = set(query_tokens)
93+
idf = {t: math.log((n_docs - df.get(t, 0) + 0.5) / (df.get(t, 0) + 0.5) + 1.0) for t in query_set}
94+
95+
scores: list[float] = []
96+
for doc in corpus:
97+
tf: Counter[str] = Counter(doc)
98+
dl = len(doc)
99+
s = 0.0
100+
for t in query_set:
101+
if t not in tf:
102+
continue
103+
freq = tf[t]
104+
s += idf.get(t, 0) * (freq * 2.5) / (freq + 1.5 * (1 - 0.75 + 0.75 * dl / avgdl))
105+
scores.append(s)
106+
107+
ranked = sorted(range(len(scores)), key=lambda i: -scores[i])
108+
return [paths[i] for i in ranked[: self.top_k] if scores[i] > 0]
109+
110+
111+
class EnsembleDiscovery(DiscoveryStrategy):
112+
def __init__(self, strategies: list[DiscoveryStrategy] | None = None) -> None:
113+
self._strategies = strategies or [DefaultDiscovery(), BM25Discovery()]
114+
115+
def discover(self, ctx: DiscoveryContext) -> list[Path]:
116+
seen: dict[Path, None] = {}
117+
for strategy in self._strategies:
118+
for path in strategy.discover(ctx):
119+
if path not in seen:
120+
seen[path] = None
121+
return list(seen.keys())
122+
123+
46124
@dataclass
47125
class ScoringResult:
48126
rel_scores: dict[FragmentId, float]

0 commit comments

Comments
 (0)