Skip to content

Commit a67dcc7

Browse files
committed
feat(diffctx): typed match strength, scope disambiguation, keyed state, invariant needs
1 parent f448040 commit a67dcc7

1 file changed

Lines changed: 43 additions & 13 deletions

File tree

src/treemapper/diffctx/utility.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
_CALL_RE = re.compile(r"(\w+)\s*\(")
1919
_TYPE_REF_RE = re.compile(r"(?::|->)\s*([A-Z]\w+)")
2020
_CLOSURE_MIN_EDGE_WEIGHT = 0.5
21+
_INVARIANT_RE = re.compile(r"\b(?:assert|require|ensure|precondition|postcondition|invariant)\s*[(\s]+(\w+)", re.IGNORECASE)
2122

2223

2324
@dataclass(frozen=True)
@@ -30,10 +31,27 @@ class InformationNeed:
3031

3132
def _match_strength_typed(frag: Fragment, need: InformationNeed) -> float:
3233
sym = need.symbol
33-
if frag.symbol_name and frag.symbol_name.lower() == sym:
34+
defines = frag.symbol_name is not None and frag.symbol_name.lower() == sym
35+
is_signature = "_signature" in frag.kind
36+
mentions = sym in frag.identifiers
37+
is_test_frag = frag.symbol_name is not None and frag.symbol_name.lower().startswith("test_")
38+
39+
scope_match = need.scope is None or need.scope == frag.path
40+
41+
if defines and not is_signature:
42+
if not scope_match:
43+
return 0.3
3444
return 1.0
35-
if sym in frag.identifiers:
36-
return 0.5
45+
if need.need_type == "impact" and not defines and mentions:
46+
return 0.8
47+
if is_signature and defines:
48+
return 0.7
49+
if need.need_type == "signature" and defines:
50+
return 0.7
51+
if need.need_type == "test" and is_test_frag and mentions:
52+
return 0.6
53+
if mentions:
54+
return 0.3
3755
return 0.0
3856

3957

@@ -202,6 +220,17 @@ def _collect_test_needs(
202220
needs[key] = InformationNeed("test", tested, None, 0.6)
203221

204222

223+
def _collect_invariant_needs(
224+
diff_text: str,
225+
needs: dict[tuple[str, str], InformationNeed],
226+
) -> None:
227+
for line in _extract_changed_lines(diff_text):
228+
for m in _INVARIANT_RE.finditer(line):
229+
sym = m.group(1).lower()
230+
if len(sym) >= 3 and sym not in CODE_STOPWORDS:
231+
needs.setdefault(("invariant", sym), InformationNeed("invariant", sym, None, 0.85))
232+
233+
205234
def needs_from_diff(
206235
all_fragments: list[Fragment],
207236
core_ids: set[FragmentId],
@@ -213,6 +242,7 @@ def needs_from_diff(
213242

214243
core_symbol_names = _collect_core_needs(all_fragments, core_ids, needs)
215244
_collect_diff_line_needs(diff_text, needs)
245+
_collect_invariant_needs(diff_text, needs)
216246
_collect_test_needs(all_fragments, core_symbol_names, needs)
217247

218248
base_symbols = {n.symbol for n in needs.values()}
@@ -238,8 +268,8 @@ def needs_from_diff(
238268

239269
@dataclass
240270
class UtilityState:
241-
max_rel: dict[str, float] = field(default_factory=dict)
242-
priorities: dict[str, float] = field(default_factory=dict)
271+
max_rel: dict[tuple[str, str], float] = field(default_factory=dict)
272+
priorities: dict[tuple[str, str], float] = field(default_factory=dict)
243273
structural_sum: float = 0.0
244274
eta: float = UTILITY.eta
245275
gamma: float = UTILITY.gamma
@@ -294,16 +324,15 @@ def marginal_gain(
294324
continue
295325
has_match = True
296326
a_fz = _augmented_score(m, rel_score, state)
297-
old_max = state.max_rel.get(need.symbol, 0.0)
327+
nkey = (need.need_type, need.symbol)
328+
old_max = state.max_rel.get(nkey, 0.0)
298329
new_max = max(old_max, a_fz)
299330
gain += need.priority * (_phi(new_max) - _phi(old_max))
300331

301-
# Diversity floor: after needs saturate (U₁ gain 0), high-PPR
332+
# Diversity floor: after needs saturate (U1 gain -> 0), high-PPR
302333
# fragments still get nonzero gain proportional to unsatisfied needs.
303-
# Prevents garbage with accidental identifier overlap from winning
304-
# over structurally relevant fragments when all gains are near-zero.
305334
if needs and rel_score >= _MIN_REL_FOR_BONUS and (gain > 0 or rel_score >= _STRONG_REL_THRESHOLD):
306-
total_covered = sum(min(state.max_rel.get(n.symbol, 0.0), 1.0) for n in needs)
335+
total_covered = sum(min(state.max_rel.get((n.need_type, n.symbol), 0.0), 1.0) for n in needs)
307336
unsatisfied = max(0.0, 1.0 - total_covered / max(1, len(needs)))
308337
floor = rel_score * _RELATEDNESS_BONUS * unsatisfied
309338
gain = max(gain, floor)
@@ -333,9 +362,10 @@ def apply_fragment(
333362
continue
334363
has_match = True
335364
a_fz = _augmented_score(m, rel_score, state)
336-
old_max = state.max_rel.get(need.symbol, 0.0)
337-
state.max_rel[need.symbol] = max(old_max, a_fz)
338-
state.priorities[need.symbol] = max(state.priorities.get(need.symbol, 0.0), need.priority)
365+
nkey = (need.need_type, need.symbol)
366+
old_max = state.max_rel.get(nkey, 0.0)
367+
state.max_rel[nkey] = max(old_max, a_fz)
368+
state.priorities[nkey] = max(state.priorities.get(nkey, 0.0), need.priority)
339369
if has_match:
340370
r_norm = min(rel_score / state.r_cap, 1.0) if state.r_cap > 0 else 0.0
341371
state.structural_sum += state.gamma * r_norm

0 commit comments

Comments
 (0)