Skip to content

Commit 44dc9fe

Browse files
committed
fix: scalability + OOM fixes for diffctx on large repos
1 parent e3141fa commit 44dc9fe

6 files changed

Lines changed: 209 additions & 43 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ __pycache__/
33
*.py[cod]
44
*$py.class
55
test-repos
6+
TODO.md
67

78
# C extensions
89
*.so

src/treemapper/diffctx/__init__.py

Lines changed: 159 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,101 @@ def _build_preferred_revs(base_rev: str | None, head_rev: str | None) -> list[st
102102
return revs
103103

104104

105+
_MAX_GENERATED_FRAGMENTS = LIMITS.max_generated_fragments
106+
107+
108+
_GENERATED_FILENAME_PATTERNS = frozenset(
109+
{
110+
".pb.go",
111+
"_pb2.py",
112+
"_pb2_grpc.py",
113+
".pb.h",
114+
".pb.cc",
115+
".pb.swift",
116+
".min.js",
117+
".min.css",
118+
".designer.cs",
119+
}
120+
)
121+
122+
_GENERATED_FILENAME_SUFFIXES = ("_generated.", "OuterClass.java")
123+
124+
_GENERATED_PATH_SEGMENTS = frozenset(
125+
{
126+
"generated",
127+
"gen-java",
128+
"gen-go",
129+
"gen-py",
130+
"gen-cpp",
131+
"gen-swift",
132+
"__generated__",
133+
"autogen",
134+
}
135+
)
136+
137+
_GENERATED_CONTENT_MARKERS = (
138+
"@generated",
139+
"do not edit",
140+
"code generated",
141+
"auto-generated",
142+
"this file is generated",
143+
"generated by",
144+
"automatically generated",
145+
"auto generated",
146+
)
147+
148+
149+
def _is_generated_file(path: Path, content: str) -> bool:
150+
name = path.name
151+
for pattern in _GENERATED_FILENAME_PATTERNS:
152+
if name.endswith(pattern):
153+
return True
154+
for suffix in _GENERATED_FILENAME_SUFFIXES:
155+
if name.endswith(suffix):
156+
return True
157+
158+
for part in path.parts:
159+
if part.lower() in _GENERATED_PATH_SEGMENTS:
160+
return True
161+
162+
header_lower = "\n".join(content[:2000].splitlines()[:5]).lower()
163+
for marker in _GENERATED_CONTENT_MARKERS:
164+
if marker in header_lower:
165+
return True
166+
167+
return False
168+
169+
105170
def _process_files_for_fragments(
106171
files: list[Path],
107172
root_dir: Path,
108173
preferred_revs: list[str],
109174
seen_frag_ids: set[FragmentId],
110175
) -> list[Fragment]:
176+
max_frags = LIMITS.max_fragments
111177
fragments: list[Fragment] = []
112178
for file_path in files:
113179
content = _read_file_content(file_path, root_dir, preferred_revs)
114180
if content is None:
115181
continue
116-
for frag in fragment_file(file_path, content):
117-
if frag.id not in seen_frag_ids:
118-
fragments.append(frag)
119-
seen_frag_ids.add(frag.id)
182+
file_frags = [f for f in fragment_file(file_path, content) if f.id not in seen_frag_ids]
183+
184+
is_generated = _is_generated_file(file_path, content)
185+
cap = _MAX_GENERATED_FRAGMENTS if is_generated else max_frags
186+
187+
if len(file_frags) > cap:
188+
file_frags.sort(key=lambda f: f.line_count, reverse=True)
189+
file_frags = file_frags[:cap]
190+
logging.debug(
191+
"diffctx: capped %s to %d fragments%s",
192+
file_path.name,
193+
cap,
194+
" (generated)" if is_generated else "",
195+
)
196+
197+
for frag in file_frags:
198+
fragments.append(frag)
199+
seen_frag_ids.add(frag.id)
120200
return fragments
121201

122202

@@ -301,11 +381,7 @@ def _filter_low_relevance_fragments(
301381
rel: dict[FragmentId, float],
302382
) -> list[Fragment]:
303383
changed_paths = {fid.path for fid in core_ids}
304-
kept = [
305-
f
306-
for f in fragments
307-
if f.path in changed_paths or rel.get(f.id, 0.0) >= _LOW_RELEVANCE_THRESHOLD
308-
]
384+
kept = [f for f in fragments if f.path in changed_paths or rel.get(f.id, 0.0) >= _LOW_RELEVANCE_THRESHOLD]
309385
removed = len(fragments) - len(kept)
310386
if removed:
311387
logging.debug("diffctx: filtered %d low-relevance fragments (threshold=%.4f)", removed, _LOW_RELEVANCE_THRESHOLD)
@@ -317,6 +393,8 @@ def _ensure_changed_files_represented(
317393
all_fragments: list[Fragment],
318394
changed_files: list[Path],
319395
remaining_budget: int,
396+
root_dir: Path,
397+
preferred_revs: list[str],
320398
) -> list[Fragment]:
321399
selected_paths = {f.path for f in selected}
322400
changed_paths = set(changed_files)
@@ -336,6 +414,20 @@ def _ensure_changed_files_represented(
336414

337415
for path in sorted(missing_paths):
338416
candidates = frags_by_path.get(path, [])
417+
418+
if not candidates:
419+
content = _read_file_content(path, root_dir, preferred_revs)
420+
if content and content.strip():
421+
lines = content.splitlines()
422+
frag = Fragment(
423+
id=FragmentId(path=path, start_line=1, end_line=len(lines)),
424+
kind="chunk",
425+
content=content,
426+
identifiers=extract_identifiers(content),
427+
)
428+
frag.token_count = count_tokens(content).count + _OVERHEAD_PER_FRAGMENT
429+
candidates = [frag]
430+
339431
if not candidates:
340432
continue
341433
best = max(candidates, key=lambda f: f.token_count if f.token_count > 0 else 0)
@@ -433,18 +525,21 @@ def build_diff_context(
433525
seen_frag_ids: set[FragmentId] = set()
434526
all_fragments = _process_files_for_fragments(changed_files, root_dir, preferred_revs, seen_frag_ids)
435527

436-
all_candidate_files = _collect_candidate_files(root_dir, set(changed_files), combined_spec)
528+
all_candidate_files, is_large_repo = _collect_candidate_files(root_dir, set(changed_files), combined_spec)
437529
all_candidate_files = _filter_whitelist(all_candidate_files, root_dir, wl_spec)
438530

439531
edge_discovered = discover_all_related_files(changed_files, all_candidate_files, root_dir)
440532
edge_discovered = [_normalize_path(p, root_dir) for p in edge_discovered]
441533
all_fragments.extend(_process_files_for_fragments(edge_discovered, root_dir, preferred_revs, seen_frag_ids))
442534

443-
expanded_files = _expand_universe_by_rare_identifiers(
444-
root_dir, expansion_concepts, changed_files + edge_discovered, combined_spec
445-
)
446-
expanded_files = [_normalize_path(p, root_dir) for p in expanded_files]
447-
all_fragments.extend(_process_files_for_fragments(expanded_files, root_dir, preferred_revs, seen_frag_ids))
535+
if not is_large_repo:
536+
expanded_files = _expand_universe_by_rare_identifiers(
537+
root_dir, expansion_concepts, changed_files + edge_discovered, combined_spec
538+
)
539+
expanded_files = [_normalize_path(p, root_dir) for p in expanded_files]
540+
all_fragments.extend(_process_files_for_fragments(expanded_files, root_dir, preferred_revs, seen_frag_ids))
541+
else:
542+
logging.debug("diffctx: skipping rare-identifier expansion for large repo")
448543

449544
for frag in all_fragments:
450545
frag.token_count = count_tokens(frag.content).count + _OVERHEAD_PER_FRAGMENT
@@ -473,7 +568,7 @@ def build_diff_context(
473568
)
474569
effective_budget = budget_tokens if budget_tokens is not None else _UNLIMITED_BUDGET
475570
remaining = effective_budget - result.used_tokens
476-
selected = _ensure_changed_files_represented(selected, all_fragments, changed_files, remaining)
571+
selected = _ensure_changed_files_represented(selected, all_fragments, changed_files, remaining, root_dir, preferred_revs)
477572
_log_ppr_mode(selected, core_ids, budget_tokens, result, alpha, tau)
478573

479574
if no_content:
@@ -717,7 +812,39 @@ def _is_candidate_file(file_path: Path, root_dir: Path, included_set: set[Path],
717812
return True
718813

719814

720-
def _collect_candidate_files(root_dir: Path, included_set: set[Path], combined_spec: pathspec.PathSpec) -> list[Path]:
815+
_MAX_CANDIDATE_FILES = LIMITS.max_candidate_files
816+
817+
818+
def _prioritize_candidates(
819+
candidates: list[Path],
820+
changed_files: set[Path],
821+
) -> list[Path]:
822+
changed_dirs: set[Path] = set()
823+
changed_extensions: set[str] = set()
824+
for f in changed_files:
825+
changed_dirs.add(f.parent)
826+
if f.parent.parent != f.parent:
827+
changed_dirs.add(f.parent.parent)
828+
if f.suffix:
829+
changed_extensions.add(f.suffix.lower())
830+
831+
priority: list[Path] = []
832+
rest: list[Path] = []
833+
for c in candidates:
834+
if c.parent in changed_dirs or c.suffix.lower() in changed_extensions:
835+
priority.append(c)
836+
else:
837+
rest.append(c)
838+
839+
budget = _MAX_CANDIDATE_FILES - len(priority)
840+
if budget > 0:
841+
priority.extend(rest[:budget])
842+
return priority[:_MAX_CANDIDATE_FILES]
843+
844+
845+
def _collect_candidate_files(
846+
root_dir: Path, included_set: set[Path], combined_spec: pathspec.PathSpec
847+
) -> tuple[list[Path], bool]:
721848
try:
722849
result = subprocess.run(
723850
["git", "ls-files", "-z"],
@@ -729,18 +856,27 @@ def _collect_candidate_files(root_dir: Path, included_set: set[Path], combined_s
729856
if result.returncode == 0 and result.stdout:
730857
out = result.stdout.decode("utf-8", errors="surrogateescape")
731858
files = [root_dir / f for f in out.split("\0") if f]
732-
return [f for f in files if _is_candidate_file(f, root_dir, included_set, combined_spec)]
859+
candidates = [f for f in files if _is_candidate_file(f, root_dir, included_set, combined_spec)]
860+
is_large_repo = len(candidates) > _MAX_CANDIDATE_FILES
861+
if is_large_repo:
862+
logging.debug(
863+
"diffctx: %d candidates exceed cap %d, prioritizing by proximity",
864+
len(candidates),
865+
_MAX_CANDIDATE_FILES,
866+
)
867+
candidates = _prioritize_candidates(candidates, included_set)
868+
return candidates, is_large_repo
733869
except (subprocess.SubprocessError, OSError):
734870
pass
735871
logging.warning("diffctx: git ls-files failed, falling back to rglob (limit %d files)", _FALLBACK_MAX_FILES)
736-
candidates: list[Path] = []
872+
fallback: list[Path] = []
737873
for f in root_dir.rglob("*"):
738-
if len(candidates) >= _FALLBACK_MAX_FILES:
874+
if len(fallback) >= _FALLBACK_MAX_FILES:
739875
logging.warning("diffctx: fallback scan hit limit, results may be incomplete")
740876
break
741877
if _is_candidate_file(f, root_dir, included_set, combined_spec):
742-
candidates.append(f)
743-
return candidates
878+
fallback.append(f)
879+
return fallback, False
744880

745881

746882
def _build_ident_index(files: list[Path], concepts: frozenset[str]) -> dict[str, list[Path]]:
@@ -822,7 +958,7 @@ def _expand_universe_by_rare_identifiers(
822958
return []
823959

824960
included_set = set(already_included)
825-
files = _collect_candidate_files(root_dir, included_set, combined_spec)
961+
files, _ = _collect_candidate_files(root_dir, included_set, combined_spec)
826962
inverted_index = _build_ident_index(files, concepts)
827963
return _collect_expansion_files(inverted_index, concepts, included_set)
828964

src/treemapper/diffctx/config/limits.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
class AlgorithmLimits:
88
max_file_size: int = 100_000
99
max_fragments: int = 200
10+
max_generated_fragments: int = 5
11+
max_candidate_files: int = 5000
12+
skip_expensive_threshold: int = 2000
1013
rare_identifier_threshold: int = 3
1114
max_expansion_files: int = 20
1215
overhead_per_fragment: int = 18

src/treemapper/diffctx/edges/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
from collections.abc import Callable
45
from pathlib import Path
56
from typing import TYPE_CHECKING
@@ -40,10 +41,20 @@ def get_all_builders() -> list[EdgeBuilder]:
4041
return [cls() for cls in all_builder_classes]
4142

4243

43-
def collect_all_edges(fragments: list[Fragment], repo_root: Path | None = None) -> tuple[EdgeDict, EdgeCategories]:
44+
_EXPENSIVE_CATEGORIES = frozenset({"similarity", "history"})
45+
46+
47+
def collect_all_edges(
48+
fragments: list[Fragment],
49+
repo_root: Path | None = None,
50+
skip_expensive: bool = False,
51+
) -> tuple[EdgeDict, EdgeCategories]:
4452
all_edges: EdgeDict = {}
4553
edge_categories: EdgeCategories = {}
4654
for category, get_builders in _BUILDER_CATEGORIES:
55+
if skip_expensive and category in _EXPENSIVE_CATEGORIES:
56+
logging.debug("diffctx: skipping %s edge builders (skip_expensive=True)", category)
57+
continue
4758
for cls in get_builders():
4859
builder = cls()
4960
cat = builder.category or category

src/treemapper/diffctx/graph.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dataclasses import dataclass, field
66
from pathlib import Path
77

8+
from .config import LIMITS
89
from .edges import collect_all_edges
910
from .edges.similarity.lexical import clamp_lexical_weight
1011
from .embeddings import _build_embedding_edges
@@ -14,6 +15,7 @@
1415
@dataclass
1516
class Graph:
1617
adjacency: dict[FragmentId, dict[FragmentId, float]] = field(default_factory=dict)
18+
reverse_adjacency: dict[FragmentId, dict[FragmentId, float]] = field(default_factory=dict)
1719
nodes: set[FragmentId] = field(default_factory=set)
1820
edge_categories: dict[tuple[FragmentId, FragmentId], str] = field(default_factory=dict)
1921

@@ -28,6 +30,12 @@ def add_edge(self, src: FragmentId, dst: FragmentId, weight: float) -> None:
2830
self.adjacency[src] = {}
2931
existing = self.adjacency[src].get(dst, 0.0)
3032
self.adjacency[src][dst] = max(existing, weight)
33+
34+
if dst not in self.reverse_adjacency:
35+
self.reverse_adjacency[dst] = {}
36+
existing_rev = self.reverse_adjacency[dst].get(src, 0.0)
37+
self.reverse_adjacency[dst][src] = max(existing_rev, weight)
38+
3139
self.nodes.add(src)
3240
self.nodes.add(dst)
3341

@@ -41,20 +49,25 @@ def build_graph(fragments: list[Fragment], repo_root: Path | None = None) -> Gra
4149
for frag in fragments:
4250
graph.nodes.add(frag.id)
4351

52+
skip_expensive = len(fragments) > LIMITS.skip_expensive_threshold
53+
if skip_expensive:
54+
logging.debug("diffctx: %d fragments exceed threshold, skipping expensive edge builders", len(fragments))
55+
4456
all_edges: dict[tuple[FragmentId, FragmentId], float] = {}
4557
edge_categories: dict[tuple[FragmentId, FragmentId], str] = {}
4658

47-
plugin_edges, plugin_categories = collect_all_edges(fragments, repo_root)
59+
plugin_edges, plugin_categories = collect_all_edges(fragments, repo_root, skip_expensive=skip_expensive)
4860
for (src, dst), weight in plugin_edges.items():
4961
if weight > all_edges.get((src, dst), 0.0):
5062
all_edges[(src, dst)] = weight
5163
edge_categories[(src, dst)] = plugin_categories.get((src, dst), "generic")
5264

53-
embedding_edges = _build_embedding_edges(fragments, clamp_lexical_weight)
54-
for (src, dst), weight in embedding_edges.items():
55-
if weight > all_edges.get((src, dst), 0.0):
56-
all_edges[(src, dst)] = weight
57-
edge_categories[(src, dst)] = "similarity"
65+
if not skip_expensive:
66+
embedding_edges = _build_embedding_edges(fragments, clamp_lexical_weight)
67+
for (src, dst), weight in embedding_edges.items():
68+
if weight > all_edges.get((src, dst), 0.0):
69+
all_edges[(src, dst)] = weight
70+
edge_categories[(src, dst)] = "similarity"
5871

5972
all_edges = _apply_hub_suppression(all_edges, edge_categories)
6073

0 commit comments

Comments
 (0)