Skip to content

Commit 52c6ea2

Browse files
committed
typed information needs, hub suppression, seed weights, file discovery
1 parent 13acea6 commit 52c6ea2

8 files changed

Lines changed: 352 additions & 76 deletions

File tree

src/treemapper/diffctx/__init__.py

Lines changed: 120 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
show_file_at_revision,
2525
split_diff_range,
2626
)
27-
from .graph import build_graph
27+
from .graph import Graph, build_graph
2828
from .languages import FILENAME_TO_LANGUAGE
2929
from .ppr import personalized_pagerank
3030
from .render import build_partial_tree
31-
from .select import lazy_greedy_select
31+
from .select import SelectionResult, lazy_greedy_select
3232
from .types import DiffHunk, Fragment, FragmentId, extract_identifiers
33-
from .utility import concepts_from_diff, concepts_from_diff_text
33+
from .utility import concepts_from_diff_text, needs_from_diff
3434

3535
__all__ = ["GitError", "build_diff_context"]
3636

@@ -172,25 +172,26 @@ def _select_with_ppr(
172172
alpha: float,
173173
tau: float,
174174
repo_root: Path | None = None,
175+
seed_weights: dict[FragmentId, float] | None = None,
175176
) -> tuple[list[Fragment], Any]:
176177
graph = build_graph(all_fragments, repo_root=repo_root)
177-
rel_scores = personalized_pagerank(graph, core_ids, alpha=alpha)
178+
rel_scores = personalized_pagerank(graph, core_ids, alpha=alpha, seed_weights=seed_weights)
178179

179-
concepts = concepts_from_diff(all_fragments, core_ids, graph, diff_text)
180-
if not concepts:
181-
concepts = concepts_from_diff_text(diff_text)
180+
needs = needs_from_diff(all_fragments, core_ids, graph, diff_text)
182181

183182
effective_budget = budget_tokens if budget_tokens is not None else _UNLIMITED_BUDGET
184183

185184
result = lazy_greedy_select(
186185
fragments=all_fragments,
187186
core_ids=core_ids,
188187
rel=rel_scores,
189-
concepts=concepts,
188+
needs=needs,
190189
budget_tokens=effective_budget,
191190
tau=tau,
192191
)
193-
return result.selected, result
192+
193+
selected = _coherence_post_pass(result, all_fragments, graph, effective_budget)
194+
return selected.selected, selected
194195

195196

196197
def build_diff_context(
@@ -253,10 +254,16 @@ def build_diff_context(
253254

254255
core_ids = _identify_core_fragments(hunks, all_fragments)
255256

257+
signature_frags = _generate_signature_variants(all_fragments, core_ids)
258+
for frag in signature_frags:
259+
frag.token_count = count_tokens(frag.content).count + _OVERHEAD_PER_FRAGMENT
260+
all_fragments.extend(signature_frags)
261+
256262
if full:
257263
selected = _select_full_mode(all_fragments, changed_files)
258264
_log_full_mode(selected)
259265
else:
266+
seed_weights = _compute_seed_weights(hunks, core_ids, all_fragments)
260267
selected, result = _select_with_ppr(
261268
all_fragments,
262269
core_ids,
@@ -265,6 +272,7 @@ def build_diff_context(
265272
alpha,
266273
tau,
267274
repo_root=root_dir,
275+
seed_weights=seed_weights,
268276
)
269277
_log_ppr_mode(selected, core_ids, budget_tokens, result, alpha, tau)
270278

@@ -288,7 +296,110 @@ def _validate_inputs(root_dir: Path, alpha: float, tau: float, budget_tokens: in
288296
raise ValueError(f"budget_tokens must be > 0, got {budget_tokens}")
289297

290298

299+
def _coherence_post_pass(
    result: SelectionResult,
    all_fragments: list[Fragment],
    graph: Graph,
    budget: int,
) -> SelectionResult:
    """Top up a selection with symbols its fragments reference but that were left out.

    For each selected fragment, every neighbor that is not selected and is
    reached through an edge whose category is ``"semantic"`` contributes its
    lower-cased ``symbol_name`` as a "dangling" name. For each dangling name
    that no selected fragment already carries, one fragment with that name is
    added if it fits in the leftover budget; a fragment whose ``kind`` contains
    ``"_signature"`` is preferred over any other.

    Args:
        result: Selection produced by the earlier selection step.
        all_fragments: Every candidate fragment (used for name and id lookups).
        graph: Graph exposing ``neighbors()`` and ``edge_categories``.
        budget: Token budget the final selection has to stay within.

    Returns:
        ``result`` itself when nothing was added; otherwise a new
        ``SelectionResult`` with the extra fragments appended and
        ``used_tokens`` increased. ``reason`` and ``utility`` are carried over
        unchanged.
    """
    if not result.selected:
        # Nothing selected means nothing can dangle.
        return result

    selected_ids = {frag.id for frag in result.selected}
    remaining = budget - result.used_tokens

    frag_by_id: dict[FragmentId, Fragment] = {}
    name_to_frags: dict[str, list[Fragment]] = {}
    for frag in all_fragments:
        frag_by_id[frag.id] = frag
        if frag.symbol_name:
            name_to_frags.setdefault(frag.symbol_name.lower(), []).append(frag)

    dangling_names: set[str] = set()
    for frag in result.selected:
        for nbr_id in graph.neighbors(frag.id):
            if nbr_id in selected_ids:
                continue
            if graph.edge_categories.get((frag.id, nbr_id), "") != "semantic":
                continue
            nbr_frag = frag_by_id.get(nbr_id)
            if nbr_frag is not None and nbr_frag.symbol_name:
                dangling_names.add(nbr_frag.symbol_name.lower())

    added: list[Fragment] = []
    # Sorted iteration keeps the outcome reproducible: set order for strings
    # varies between interpreter runs, and with a tight budget the order
    # decides which names still fit.
    for name in sorted(dangling_names):
        candidates = name_to_frags.get(name, [])
        if any(cand.id in selected_ids for cand in candidates):
            # Some fragment carrying this name is already part of the output.
            continue

        # Prefer the cheap signature variant; otherwise take the first fragment.
        pick = next((cand for cand in candidates if "_signature" in cand.kind), None)
        if pick is None and candidates:
            pick = candidates[0]
        if pick is None or pick.token_count > remaining:
            continue

        added.append(pick)
        selected_ids.add(pick.id)
        remaining -= pick.token_count

    if not added:
        return result

    return SelectionResult(
        selected=result.selected + added,
        reason=result.reason,
        used_tokens=result.used_tokens + sum(frag.token_count for frag in added),
        utility=result.utility,
    )
350+
351+
352+
def _compute_seed_weights(
    hunks: list[DiffHunk],
    core_ids: set[FragmentId],
    all_fragments: list[Fragment],
) -> dict[FragmentId, float]:
    """Weight each core fragment by the size of the diff hunks that touch it.

    A hunk touches a fragment when both live in the same path and the hunk's
    ``core_selection_range`` overlaps the fragment's ``start_line``..``end_line``
    span (inclusive). Every touching hunk adds its full size in lines (at least
    1) to the fragment's weight, regardless of how large the overlap is.

    Args:
        hunks: Diff hunks providing ``path`` and ``core_selection_range``.
        core_ids: Ids of the fragments treated as core.
        all_fragments: All fragments; only those whose id is in ``core_ids``
            are considered.

    Returns:
        Mapping from fragment id to accumulated weight. Fragments touched by no
        hunk are absent; the mapping is empty when nothing matches.
    """
    # Index the core fragments by path once instead of rescanning (and
    # re-filtering) the full fragment list for every hunk.
    core_by_path: dict[Path, list[Fragment]] = {}
    for frag in all_fragments:
        if frag.id in core_ids:
            core_by_path.setdefault(frag.path, []).append(frag)

    weights: dict[FragmentId, float] = {}
    for hunk in hunks:
        h_start, h_end = hunk.core_selection_range
        hunk_size = float(max(1, h_end - h_start + 1))
        for frag in core_by_path.get(hunk.path, ()):
            if frag.start_line <= h_end and frag.end_line >= h_start:
                weights[frag.id] = weights.get(frag.id, 0.0) + hunk_size
    return weights
369+
370+
291371
_CONTAINER_FRAGMENT_KINDS = frozenset({"class", "interface", "struct"})
372+
_SIGNATURE_ELIGIBLE_KINDS = frozenset({"function", "class", "method", "struct", "interface", "enum"})
373+
_MIN_LINES_FOR_SIGNATURE = 5
374+
375+
376+
def _generate_signature_variants(fragments: list[Fragment], core_ids: set[FragmentId]) -> list[Fragment]:
    """Build short "signature" stand-ins for non-core definition fragments.

    A fragment qualifies when it is not core, its ``kind`` is listed in
    ``_SIGNATURE_ELIGIBLE_KINDS`` and it has at least
    ``_MIN_LINES_FOR_SIGNATURE`` lines. The variant keeps the first two lines of
    the content, reuses the source fragment's identifiers and symbol name, and
    gets the kind ``"<kind>_signature"``.

    Args:
        fragments: Fragments to derive signature variants from.
        core_ids: Ids of core fragments, which never get a variant.

    Returns:
        The newly created signature fragments; the input list is not modified.
        ``token_count`` is not set here.
    """
    signatures: list[Fragment] = []
    # Start from the ids already in use so a signature can never share an id
    # with an existing fragment (for example a two-line fragment starting on
    # the same line), which would make id-keyed lookups ambiguous.
    taken: set[FragmentId] = {frag.id for frag in fragments}

    for frag in fragments:
        if frag.id in core_ids or frag.kind not in _SIGNATURE_ELIGIBLE_KINDS:
            continue
        if frag.line_count < _MIN_LINES_FOR_SIGNATURE:
            continue

        sig_lines = frag.content.splitlines()[:2]
        if not sig_lines:
            # Empty content would yield an id whose end line precedes its start.
            continue

        sig_id = FragmentId(frag.path, frag.start_line, frag.start_line + len(sig_lines) - 1)
        if sig_id in taken:
            continue
        taken.add(sig_id)

        signatures.append(
            Fragment(
                id=sig_id,
                kind=f"{frag.kind}_signature",
                content="\n".join(sig_lines),
                identifiers=frag.identifiers,
                symbol_name=frag.symbol_name,
            )
        )
    return signatures
292403

293404

294405
def _identify_core_fragments(hunks: list[DiffHunk], all_fragments: list[Fragment]) -> set[FragmentId]:

src/treemapper/diffctx/edges/__init__.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
from collections.abc import Callable
34
from pathlib import Path
45
from typing import TYPE_CHECKING
56

7+
from ..types import FragmentId
68
from .base import EdgeBuilder, EdgeDict
79
from .config import get_config_builders
810
from .document import get_document_builders
@@ -14,6 +16,17 @@
1416
if TYPE_CHECKING:
1517
from ..types import Fragment
1618

19+
EdgeCategories = dict[tuple[FragmentId, FragmentId], str]
20+
21+
_BUILDER_CATEGORIES: list[tuple[str, Callable[[], list[type[EdgeBuilder]]]]] = [
22+
("semantic", get_semantic_builders),
23+
("structural", get_structural_builders),
24+
("config", get_config_builders),
25+
("document", get_document_builders),
26+
("similarity", get_similarity_builders),
27+
("history", get_history_builders),
28+
]
29+
1730

1831
def get_all_builders() -> list[EdgeBuilder]:
1932
all_builder_classes = (
@@ -27,12 +40,17 @@ def get_all_builders() -> list[EdgeBuilder]:
2740
return [cls() for cls in all_builder_classes]
2841

2942

30-
def collect_all_edges(fragments: list[Fragment], repo_root: Path | None = None) -> tuple[EdgeDict, EdgeCategories]:
    """Run every registered edge builder and merge the results.

    Builders are visited category by category in the order given by
    ``_BUILDER_CATEGORIES``. When several builders emit the same ``(src, dst)``
    edge, the strictly largest weight is kept together with the category of the
    builder that produced it; on a tie the earlier builder's category stays.
    Edges whose weight is not greater than 0.0 are never recorded.

    Returns:
        A pair ``(edges, categories)``: edge weights keyed by ``(src, dst)`` and
        the winning category name for each recorded edge.
    """
    best_weights: EdgeDict = {}
    winning_category: EdgeCategories = {}

    for category_name, list_builder_classes in _BUILDER_CATEGORIES:
        for builder_cls in list_builder_classes():
            produced = builder_cls().build(fragments, repo_root)
            for edge_key, edge_weight in produced.items():
                current_best = best_weights.get(edge_key, 0.0)
                if edge_weight > current_best:
                    best_weights[edge_key] = edge_weight
                    winning_category[edge_key] = category_name

    return best_weights, winning_category
3654

3755

3856
def discover_all_related_files(
@@ -49,6 +67,7 @@ def discover_all_related_files(
4967

5068
__all__ = [
5169
"EdgeBuilder",
70+
"EdgeCategories",
5271
"EdgeDict",
5372
"collect_all_edges",
5473
"discover_all_related_files",

src/treemapper/diffctx/edges/semantic/python.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
from collections import defaultdict
66
from pathlib import Path
77

8+
from ...config.weights import LANG_WEIGHTS
89
from ...python_semantics import PyFragmentInfo, analyze_python_fragment
910
from ...types import Fragment, FragmentId
1011
from ..base import EdgeBuilder, EdgeDict, add_semantic_edges, path_to_module
1112

1213
_PYTHON_EXTS = {".py", ".pyi", ".pyw"}
1314

14-
_CALL_WEIGHT = 0.85
15-
_SYMBOL_REF_WEIGHT = 0.95
16-
_TYPE_REF_WEIGHT = 0.60
15+
_PY_WEIGHTS = LANG_WEIGHTS["python"]
16+
_CALL_WEIGHT = _PY_WEIGHTS.call
17+
_SYMBOL_REF_WEIGHT = _PY_WEIGHTS.symbol_ref
18+
_TYPE_REF_WEIGHT = _PY_WEIGHTS.type_ref
1719

1820
_PY_IMPORT_RE = re.compile(r"(?:from\s{1,20}(\.{0,3}[\w.]{0,200})\s{1,20}import|import\s{1,20}([\w.]{1,200}))")
1921

src/treemapper/diffctx/edges/structural/test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,39 @@ class TestEdgeBuilder(EdgeBuilder):
9393
weight_naming = EDGE_WEIGHTS["test_naming"].forward
9494
reverse_weight_factor = EDGE_WEIGHTS["test_reverse"].forward / EDGE_WEIGHTS["test_direct"].forward
9595

96+
def discover_related_files(
    self,
    changed_files: list[Path],
    all_candidate_files: list[Path],
    repo_root: Path | None = None,
) -> list[Path]:
    """Pair changed files with their tests (or tests with their targets) by file name.

    For a changed test file, candidates whose stem equals the extracted target
    name and whose suffix matches are returned. For any other changed file,
    candidates named ``test_<stem>`` or ``<stem>_test`` with the same suffix
    that are themselves test files are returned. Stems and suffixes are
    compared case-insensitively.

    Args:
        changed_files: Files touched by the diff.
        all_candidate_files: Files that may be pulled in; changed files are
            excluded from the candidates.
        repo_root: Not used by this builder.

    Returns:
        Matching candidate paths in discovery order, each listed once.
    """
    changed_set = set(changed_files)

    candidate_by_stem: dict[str, list[Path]] = defaultdict(list)
    for candidate in all_candidate_files:
        if candidate not in changed_set:
            candidate_by_stem[candidate.stem.lower()].append(candidate)

    discovered: list[Path] = []
    already_found: set[Path] = set()

    for changed in changed_files:
        suffix = changed.suffix.lower()

        if _is_test_file(changed):
            target = _extract_target_name_from_test(changed.stem)
            if not target:
                continue
            # The index keys are lower-cased, so the lookup key must be too;
            # assumes the helper returns a str when it returns anything truthy.
            matches = [
                cand
                for cand in candidate_by_stem.get(target.lower(), [])
                if cand.suffix.lower() == suffix
            ]
        else:
            stem = changed.stem.lower()
            matches = [
                cand
                for test_stem in (f"test_{stem}", f"{stem}_test")
                for cand in candidate_by_stem.get(test_stem, [])
                if cand.suffix.lower() == suffix and _is_test_file(cand)
            ]

        # Several changed files can map to the same candidate; report it once.
        for cand in matches:
            if cand not in already_found:
                already_found.add(cand)
                discovered.append(cand)

    return discovered
128+
96129
def build(self, fragments: list[Fragment], repo_root: Path | None = None) -> EdgeDict:
97130
edges: EdgeDict = {}
98131

src/treemapper/diffctx/graph.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
class Graph:
1616
adjacency: dict[FragmentId, dict[FragmentId, float]] = field(default_factory=dict)
1717
nodes: set[FragmentId] = field(default_factory=set)
18+
edge_categories: dict[tuple[FragmentId, FragmentId], str] = field(default_factory=dict)
1819

1920
def add_node(self, node: FragmentId) -> None:
2021
self.nodes.add(node)
@@ -41,25 +42,36 @@ def build_graph(fragments: list[Fragment], repo_root: Path | None = None) -> Gra
4142
graph.nodes.add(frag.id)
4243

4344
all_edges: dict[tuple[FragmentId, FragmentId], float] = {}
45+
edge_categories: dict[tuple[FragmentId, FragmentId], str] = {}
4446

45-
plugin_edges = collect_all_edges(fragments, repo_root)
47+
plugin_edges, plugin_categories = collect_all_edges(fragments, repo_root)
4648
for (src, dst), weight in plugin_edges.items():
47-
all_edges[(src, dst)] = max(all_edges.get((src, dst), 0.0), weight)
49+
if weight > all_edges.get((src, dst), 0.0):
50+
all_edges[(src, dst)] = weight
51+
edge_categories[(src, dst)] = plugin_categories.get((src, dst), "generic")
4852

4953
embedding_edges = _build_embedding_edges(fragments, clamp_lexical_weight)
5054
for (src, dst), weight in embedding_edges.items():
51-
all_edges[(src, dst)] = max(all_edges.get((src, dst), 0.0), weight)
55+
if weight > all_edges.get((src, dst), 0.0):
56+
all_edges[(src, dst)] = weight
57+
edge_categories[(src, dst)] = "similarity"
5258

53-
all_edges = _apply_hub_suppression(all_edges)
59+
all_edges = _apply_hub_suppression(all_edges, edge_categories)
5460

5561
for (src, dst), weight in all_edges.items():
5662
graph.add_edge(src, dst, weight)
5763

64+
graph.edge_categories = edge_categories
65+
5866
return graph
5967

6068

69+
_SUPPRESSION_EXEMPT = frozenset({"semantic", "structural", "config", "document"})
70+
71+
6172
def _apply_hub_suppression(
6273
edges: dict[tuple[FragmentId, FragmentId], float],
74+
edge_categories: dict[tuple[FragmentId, FragmentId], str],
6375
) -> dict[tuple[FragmentId, FragmentId], float]:
6476
if not edges:
6577
return edges
@@ -71,11 +83,16 @@ def _apply_hub_suppression(
7183
if not in_degree:
7284
return edges
7385

86+
degrees = sorted(in_degree.values())
87+
mid = len(degrees) // 2
88+
d_median = (degrees[mid] + degrees[~mid]) / 2.0
89+
7490
suppressed: dict[tuple[FragmentId, FragmentId], float] = {}
7591
for (src, dst), weight in edges.items():
7692
deg = in_degree.get(dst, 0)
77-
if deg > 0:
78-
weight = weight / math.log(1 + deg)
93+
cat = edge_categories.get((src, dst), "generic")
94+
if deg > d_median and cat not in _SUPPRESSION_EXEMPT:
95+
weight = weight / max(1.0, math.log(1 + deg))
7996
suppressed[(src, dst)] = weight
8097

8198
return suppressed

0 commit comments

Comments
 (0)