# Captures the identifier immediately preceding an opening parenthesis,
# i.e. the name in a call-like expression such as ``foo(`` or ``foo (``.
_CALL_RE = re.compile(r"(\w+)\s*\(")

# Captures a capitalised identifier (at least two characters) that follows
# ``:`` or ``->``, which is where annotation-style type references appear.
_TYPE_REF_RE = re.compile(r"(?::|->)\s*([A-Z]\w+)")

_CLOSURE_MIN_EDGE_WEIGHT = 0.5

# Captures the first word after an assertion-style keyword (case-insensitive),
# whether written as ``assert(foo)`` or ``assert foo``. At least one ``(`` or
# whitespace character must separate the keyword from the captured word.
_INVARIANT_RE = re.compile(
    r"\b(?:assert|require|ensure|precondition|postcondition|invariant)\s*[(\s]+(\w+)",
    re.IGNORECASE,
)
2122
2223
2324@dataclass (frozen = True )
@@ -30,10 +31,27 @@ class InformationNeed:
3031
def _match_strength_typed(frag: Fragment, need: InformationNeed) -> float:
    """Score how strongly ``frag`` answers ``need``.

    The score depends on whether the fragment defines the needed symbol,
    merely mentions it, is a signature-only fragment, or is a test.

    Args:
        frag: Candidate fragment. Reads ``symbol_name``, ``kind``,
            ``identifiers`` and ``path``.
        need: The information need. Reads ``symbol``, ``need_type`` and
            ``scope``. ``symbol`` is compared against the lower-cased
            fragment name, so it is presumably already lower-case
            (verify against the code that builds needs).

    Returns:
        1.0  full (non-signature) definition of the symbol, in scope;
        0.3  full definition, but ``need.scope`` names a different path;
        0.8  "impact" need and the fragment mentions (does not define) it;
        0.7  signature fragment that defines the symbol;
        0.6  "test" need, test-named fragment that mentions the symbol;
        0.3  any other fragment that mentions the symbol;
        0.0  no relation.
    """
    sym = need.symbol
    # Lower the fragment name once; it feeds both the definition check and
    # the test-name check.
    name = frag.symbol_name.lower() if frag.symbol_name is not None else None
    defines = name is not None and name == sym
    is_signature = "_signature" in frag.kind
    mentions = sym in frag.identifiers
    is_test_frag = name is not None and name.startswith("test_")

    if defines and not is_signature:
        # A need without a scope accepts a definition from any path; a scoped
        # need only gives full credit to a definition at exactly that path.
        in_scope = need.scope is None or need.scope == frag.path
        return 1.0 if in_scope else 0.3
    if need.need_type == "impact" and not defines and mentions:
        return 0.8
    if defines:
        # Only signature fragments reach this point: non-signature definitions
        # returned above. This also covers ``need_type == "signature"``, which
        # previously had its own branch that could never be reached.
        return 0.7
    if need.need_type == "test" and is_test_frag and mentions:
        return 0.6
    if mentions:
        return 0.3
    return 0.0
3856
3957
@@ -202,6 +220,17 @@ def _collect_test_needs(
202220 needs [key ] = InformationNeed ("test" , tested , None , 0.6 )
203221
204222
def _collect_invariant_needs(
    diff_text: str,
    needs: dict[tuple[str, str], InformationNeed],
) -> None:
    """Register an "invariant" need for symbols named in assertion-style lines.

    Scans each line yielded by ``_extract_changed_lines(diff_text)`` with
    ``_INVARIANT_RE`` and lower-cases every captured word. Words shorter than
    three characters, or present in ``CODE_STOPWORDS``, are ignored.

    Args:
        diff_text: Diff text handed to ``_extract_changed_lines``.
        needs: Mapping keyed by ``(need_type, symbol)``; updated in place.
            An entry already stored under ``("invariant", symbol)`` is kept.

    Returns:
        None. The result is the mutation of ``needs``.
    """
    for line in _extract_changed_lines(diff_text):
        for m in _INVARIANT_RE.finditer(line):
            sym = m.group(1).lower()
            if len(sym) < 3 or sym in CODE_STOPWORDS:
                continue
            # Positional arguments presumably map to (need_type, symbol,
            # scope, priority) -- confirm against the InformationNeed fields.
            needs.setdefault(("invariant", sym), InformationNeed("invariant", sym, None, 0.85))
232+
233+
205234def needs_from_diff (
206235 all_fragments : list [Fragment ],
207236 core_ids : set [FragmentId ],
@@ -213,6 +242,7 @@ def needs_from_diff(
213242
214243 core_symbol_names = _collect_core_needs (all_fragments , core_ids , needs )
215244 _collect_diff_line_needs (diff_text , needs )
245+ _collect_invariant_needs (diff_text , needs )
216246 _collect_test_needs (all_fragments , core_symbol_names , needs )
217247
218248 base_symbols = {n .symbol for n in needs .values ()}
@@ -238,8 +268,8 @@ def needs_from_diff(
238268
239269@dataclass
240270class UtilityState :
241- max_rel : dict [str , float ] = field (default_factory = dict )
242- priorities : dict [str , float ] = field (default_factory = dict )
271+ max_rel : dict [tuple [ str , str ] , float ] = field (default_factory = dict )
272+ priorities : dict [tuple [ str , str ] , float ] = field (default_factory = dict )
243273 structural_sum : float = 0.0
244274 eta : float = UTILITY .eta
245275 gamma : float = UTILITY .gamma
@@ -294,16 +324,15 @@ def marginal_gain(
294324 continue
295325 has_match = True
296326 a_fz = _augmented_score (m , rel_score , state )
297- old_max = state .max_rel .get (need .symbol , 0.0 )
327+ nkey = (need .need_type , need .symbol )
328+ old_max = state .max_rel .get (nkey , 0.0 )
298329 new_max = max (old_max , a_fz )
299330 gain += need .priority * (_phi (new_max ) - _phi (old_max ))
300331
301- # Diversity floor: after needs saturate (U₁ gain → 0), high-PPR
332+ # Diversity floor: after needs saturate (U1 gain -> 0), high-PPR
302333 # fragments still get nonzero gain proportional to unsatisfied needs.
303- # Prevents garbage with accidental identifier overlap from winning
304- # over structurally relevant fragments when all gains are near-zero.
305334 if needs and rel_score >= _MIN_REL_FOR_BONUS and (gain > 0 or rel_score >= _STRONG_REL_THRESHOLD ):
306- total_covered = sum (min (state .max_rel .get (n . symbol , 0.0 ), 1.0 ) for n in needs )
335+ total_covered = sum (min (state .max_rel .get (( n . need_type , n . symbol ) , 0.0 ), 1.0 ) for n in needs )
307336 unsatisfied = max (0.0 , 1.0 - total_covered / max (1 , len (needs )))
308337 floor = rel_score * _RELATEDNESS_BONUS * unsatisfied
309338 gain = max (gain , floor )
@@ -333,9 +362,10 @@ def apply_fragment(
333362 continue
334363 has_match = True
335364 a_fz = _augmented_score (m , rel_score , state )
336- old_max = state .max_rel .get (need .symbol , 0.0 )
337- state .max_rel [need .symbol ] = max (old_max , a_fz )
338- state .priorities [need .symbol ] = max (state .priorities .get (need .symbol , 0.0 ), need .priority )
365+ nkey = (need .need_type , need .symbol )
366+ old_max = state .max_rel .get (nkey , 0.0 )
367+ state .max_rel [nkey ] = max (old_max , a_fz )
368+ state .priorities [nkey ] = max (state .priorities .get (nkey , 0.0 ), need .priority )
339369 if has_match :
340370 r_norm = min (rel_score / state .r_cap , 1.0 ) if state .r_cap > 0 else 0.0
341371 state .structural_sum += state .gamma * r_norm
0 commit comments