Skip to content

Commit 49ed299

Browse files
committed
fix: align diffctx algorithm with paper — 5 theory/code misalignments
1 parent 9eb7e84 commit 49ed299

3 files changed

Lines changed: 62 additions & 13 deletions

File tree

src/treemapper/diffctx/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -274,7 +274,7 @@ def build_diff_context(
274274

275275
core_ids = _identify_core_fragments(hunks, all_fragments)
276276

277-
signature_frags = _generate_signature_variants(all_fragments, core_ids)
277+
signature_frags = _generate_signature_variants(all_fragments)
278278
for frag in signature_frags:
279279
frag.token_count = count_tokens(frag.content).count + _OVERHEAD_PER_FRAGMENT
280280
all_fragments.extend(signature_frags)
@@ -403,12 +403,10 @@ def _compute_seed_weights(
403403
_MIN_LINES_FOR_SIGNATURE = 5
404404

405405

406-
def _generate_signature_variants(fragments: list[Fragment], core_ids: set[FragmentId]) -> list[Fragment]:
406+
def _generate_signature_variants(fragments: list[Fragment]) -> list[Fragment]:
407407
signatures: list[Fragment] = []
408408
seen: set[FragmentId] = set()
409409
for frag in fragments:
410-
if frag.id in core_ids:
411-
continue
412410
if frag.kind not in _SIGNATURE_ELIGIBLE_KINDS:
413411
continue
414412
if frag.line_count < _MIN_LINES_FOR_SIGNATURE:
@@ -456,7 +454,7 @@ def _add_container_headers(core_ids: set[FragmentId], frags_by_path: dict[Path,
456454
if frag.kind not in _CONTAINER_FRAGMENT_KINDS or frag.id in core_ids:
457455
continue
458456
for core_id in core_ids:
459-
if core_id.path == path and core_id.start_line > frag.end_line:
457+
if core_id.path == path and frag.start_line <= core_id.start_line and core_id.end_line <= frag.end_line:
460458
headers_to_add.append(frag.id)
461459
break
462460
core_ids.update(headers_to_add)

src/treemapper/diffctx/select.py

Lines changed: 49 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,20 @@
22

33
import heapq
44
import logging
5+
import math
56
import statistics
67
from dataclasses import dataclass, field
8+
from pathlib import Path
79

810
from .types import Fragment, FragmentId
911
from .utility import InformationNeed, UtilityState, apply_fragment, compute_density, marginal_gain, utility_value
1012

11-
_BASELINE_K = 5
13+
_BASELINE_K_MAX = 5
14+
_CORE_BUDGET_FRACTION = 0.70
15+
16+
17+
def _adaptive_baseline_k(n_candidates: int) -> int:
18+
return min(_BASELINE_K_MAX, math.ceil(0.1 * n_candidates))
1219

1320

1421
@dataclass
@@ -45,16 +52,30 @@ def _select_core_fragments(
4552
rel: dict[FragmentId, float],
4653
needs: tuple[InformationNeed, ...],
4754
state: _SelectionState,
55+
budget_tokens: int,
56+
sig_lookup: dict[FragmentId, Fragment] | None = None,
4857
) -> None:
58+
core_budget = int(budget_tokens * _CORE_BUDGET_FRACTION)
59+
core_used = 0
4960
sorted_core = sorted(core_fragments, key=lambda f: rel.get(f.id, 0.0), reverse=True)
5061

5162
for frag in sorted_core:
5263
if _is_subset_of_selected(frag, state.selected_ids):
5364
continue
65+
if core_used + frag.token_count > core_budget:
66+
sig = sig_lookup.get(frag.id) if sig_lookup else None
67+
if sig and sig.id not in state.selected_ids and core_used + sig.token_count <= core_budget:
68+
state.selected.append(sig)
69+
state.selected_ids.add(sig.id)
70+
state.remaining_budget -= sig.token_count
71+
core_used += sig.token_count
72+
apply_fragment(sig, rel.get(frag.id, 0.0), needs, state.utility_state)
73+
continue
5474

5575
state.selected.append(frag)
5676
state.selected_ids.add(frag.id)
5777
state.remaining_budget -= frag.token_count
78+
core_used += frag.token_count
5879
apply_fragment(frag, rel.get(frag.id, 0.0), needs, state.utility_state)
5980

6081

@@ -159,8 +180,18 @@ def lazy_greedy_select(
159180
core_fragments.sort(key=lambda f: (f.token_count if f.token_count > 0 else 10**9, f.line_count, f.start_line))
160181
non_core_fragments = [f for f in fragments if f.id not in core_ids]
161182

183+
sig_by_loc: dict[tuple[Path, int], Fragment] = {}
184+
for f in fragments:
185+
if "_signature" in f.kind:
186+
sig_by_loc[(f.path, f.start_line)] = f
187+
sig_lookup: dict[FragmentId, Fragment] = {}
188+
for cf in core_fragments:
189+
key = (cf.path, cf.start_line)
190+
if key in sig_by_loc:
191+
sig_lookup[cf.id] = sig_by_loc[key]
192+
162193
state = _SelectionState(remaining_budget=budget_tokens)
163-
_select_core_fragments(core_fragments, rel, needs, state)
194+
_select_core_fragments(core_fragments, rel, needs, state, budget_tokens, sig_lookup)
164195

165196
if state.remaining_budget <= 0:
166197
used = budget_tokens - state.remaining_budget
@@ -179,10 +210,11 @@ def lazy_greedy_select(
179210
base_budget = state.remaining_budget
180211

181212
candidates = [f for f in non_core_fragments if not _overlaps_with_selected(f, state.selected_ids)]
213+
baseline_k = _adaptive_baseline_k(len(candidates))
182214
id_to_frag: dict[FragmentId, Fragment] = {}
183215
heap = _build_initial_heap(candidates, rel, needs, state.utility_state, id_to_frag)
184216

185-
selections_for_baseline, threshold = _run_greedy_loop_heap(heap, id_to_frag, state, rel, needs, tau)
217+
selections_for_baseline, threshold = _run_greedy_loop_heap(heap, id_to_frag, state, rel, needs, tau, baseline_k)
186218

187219
greedy_utility = utility_value(state.utility_state)
188220
base_selected_ids = {f.id for f in base_selected}
@@ -203,7 +235,15 @@ def lazy_greedy_select(
203235
return singleton_result
204236

205237
return _determine_final_result(
206-
state, base_selected, budget_tokens, greedy_utility, selections_for_baseline, threshold, bool(heap), core_ids
238+
state,
239+
base_selected,
240+
budget_tokens,
241+
greedy_utility,
242+
selections_for_baseline,
243+
threshold,
244+
bool(heap),
245+
core_ids,
246+
baseline_k,
207247
)
208248

209249

@@ -214,6 +254,7 @@ def _run_greedy_loop_heap(
214254
rel: dict[FragmentId, float],
215255
needs: tuple[InformationNeed, ...],
216256
tau: float,
257+
baseline_k: int = _BASELINE_K_MAX,
217258
) -> tuple[int, float]:
218259
baseline_densities: list[float] = []
219260
threshold = 0.0
@@ -235,10 +276,10 @@ def _run_greedy_loop_heap(
235276
if best_frag is None or best_density <= 0:
236277
break
237278

238-
if selections_for_baseline < _BASELINE_K:
279+
if selections_for_baseline < baseline_k:
239280
baseline_densities.append(best_density)
240281
selections_for_baseline += 1
241-
if selections_for_baseline == _BASELINE_K and baseline_densities:
282+
if selections_for_baseline == baseline_k and baseline_densities:
242283
threshold = tau * statistics.median(baseline_densities)
243284
elif best_density < threshold:
244285
break
@@ -293,6 +334,7 @@ def _determine_final_result(
293334
threshold: float,
294335
has_remaining_candidates: bool,
295336
core_ids: set[FragmentId],
337+
baseline_k: int = _BASELINE_K_MAX,
296338
) -> SelectionResult:
297339
used = budget_tokens - state.remaining_budget
298340

@@ -302,7 +344,7 @@ def _determine_final_result(
302344
reason = "no_utility"
303345
elif not state.selected or len(state.selected) == len(base_selected):
304346
reason = "no_candidates"
305-
elif selections_for_baseline >= _BASELINE_K and threshold > 0 and has_remaining_candidates:
347+
elif selections_for_baseline >= baseline_k and threshold > 0 and has_remaining_candidates:
306348
reason = "stopped_by_tau"
307349
else:
308350
reason = "no_candidates"

src/treemapper/diffctx/utility.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,9 @@ def _build_sigma(
9191
return sigma
9292

9393

94+
_CLOSURE_EDGE_CATEGORIES = frozenset({"structural", "semantic"})
95+
96+
9497
def _closure_expand_step(
9598
closure: set[str],
9699
frag_by_symbol: dict[str, list[Fragment]],
@@ -103,6 +106,9 @@ def _closure_expand_step(
103106
for nbr_id, weight in graph.neighbors(frag.id).items():
104107
if weight < _CLOSURE_MIN_EDGE_WEIGHT:
105108
continue
109+
cat = graph.edge_categories.get((frag.id, nbr_id), "")
110+
if cat and cat not in _CLOSURE_EDGE_CATEGORIES:
111+
continue
106112
nbr = frag_by_id.get(nbr_id)
107113
if nbr and nbr.symbol_name and nbr.symbol_name.lower() not in closure:
108114
new_symbols.add(nbr.symbol_name.lower())
@@ -282,7 +288,10 @@ def marginal_gain(
282288
gain += _phi(new_max) - _phi(old_max)
283289

284290
if rel_score >= _MIN_REL_FOR_BONUS and (gain > 0 or rel_score >= _STRONG_REL_THRESHOLD):
285-
gain = max(gain, rel_score * _RELATEDNESS_BONUS)
291+
total_covered = sum(min(state.max_rel.get(n.symbol, 0.0), 1.0) for n in needs)
292+
unsatisfied = max(0.0, 1.0 - total_covered / max(1, len(needs)))
293+
floor = rel_score * _RELATEDNESS_BONUS * unsatisfied
294+
gain = max(gain, floor)
286295

287296
return gain
288297

0 commit comments

Comments
 (0)