2424 show_file_at_revision ,
2525 split_diff_range ,
2626)
27- from .graph import build_graph
27+ from .graph import Graph , build_graph
2828from .languages import FILENAME_TO_LANGUAGE
2929from .ppr import personalized_pagerank
3030from .render import build_partial_tree
31- from .select import lazy_greedy_select
31+ from .select import SelectionResult , lazy_greedy_select
3232from .types import DiffHunk , Fragment , FragmentId , extract_identifiers
33- from .utility import concepts_from_diff , concepts_from_diff_text
33+ from .utility import concepts_from_diff_text , needs_from_diff
3434
3535__all__ = ["GitError" , "build_diff_context" ]
3636
@@ -172,25 +172,26 @@ def _select_with_ppr(
172172 alpha : float ,
173173 tau : float ,
174174 repo_root : Path | None = None ,
175+ seed_weights : dict [FragmentId , float ] | None = None ,
175176) -> tuple [list [Fragment ], Any ]:
176177 graph = build_graph (all_fragments , repo_root = repo_root )
177- rel_scores = personalized_pagerank (graph , core_ids , alpha = alpha )
178+ rel_scores = personalized_pagerank (graph , core_ids , alpha = alpha , seed_weights = seed_weights )
178179
179- concepts = concepts_from_diff (all_fragments , core_ids , graph , diff_text )
180- if not concepts :
181- concepts = concepts_from_diff_text (diff_text )
180+ needs = needs_from_diff (all_fragments , core_ids , graph , diff_text )
182181
183182 effective_budget = budget_tokens if budget_tokens is not None else _UNLIMITED_BUDGET
184183
185184 result = lazy_greedy_select (
186185 fragments = all_fragments ,
187186 core_ids = core_ids ,
188187 rel = rel_scores ,
189- concepts = concepts ,
188+ needs = needs ,
190189 budget_tokens = effective_budget ,
191190 tau = tau ,
192191 )
193- return result .selected , result
192+
193+ selected = _coherence_post_pass (result , all_fragments , graph , effective_budget )
194+ return selected .selected , selected
194195
195196
196197def build_diff_context (
@@ -253,10 +254,16 @@ def build_diff_context(
253254
254255 core_ids = _identify_core_fragments (hunks , all_fragments )
255256
257+ signature_frags = _generate_signature_variants (all_fragments , core_ids )
258+ for frag in signature_frags :
259+ frag .token_count = count_tokens (frag .content ).count + _OVERHEAD_PER_FRAGMENT
260+ all_fragments .extend (signature_frags )
261+
256262 if full :
257263 selected = _select_full_mode (all_fragments , changed_files )
258264 _log_full_mode (selected )
259265 else :
266+ seed_weights = _compute_seed_weights (hunks , core_ids , all_fragments )
260267 selected , result = _select_with_ppr (
261268 all_fragments ,
262269 core_ids ,
@@ -265,6 +272,7 @@ def build_diff_context(
265272 alpha ,
266273 tau ,
267274 repo_root = root_dir ,
275+ seed_weights = seed_weights ,
268276 )
269277 _log_ppr_mode (selected , core_ids , budget_tokens , result , alpha , tau )
270278
@@ -288,7 +296,110 @@ def _validate_inputs(root_dir: Path, alpha: float, tau: float, budget_tokens: in
288296 raise ValueError (f"budget_tokens must be > 0, got { budget_tokens } " )
289297
290298
def _coherence_post_pass(
    result: SelectionResult,
    all_fragments: list[Fragment],
    graph: Graph,
    budget: int,
) -> SelectionResult:
    """Add definitions for symbols the selection references but omits.

    For every selected fragment, each neighbour that is not itself selected
    and is reached through an edge whose category is ``"semantic"`` counts
    as a dangling reference. For each dangling symbol name (compared
    case-insensitively) that no selected fragment already carries, one
    fragment with that name is appended if its ``token_count`` fits in the
    unused budget. A fragment whose ``kind`` contains ``"_signature"`` is
    preferred over any other candidate.

    Args:
        result: Outcome of the main selection pass.
        all_fragments: Every candidate fragment the selection could draw from.
        graph: Fragment graph; ``neighbors`` and ``edge_categories`` are read.
        budget: Total token budget that ``result.used_tokens`` counts against.

    Returns:
        ``result`` itself when nothing was added, otherwise a new
        ``SelectionResult`` with the extra fragments appended and
        ``used_tokens`` increased by their token counts. ``reason`` and
        ``utility`` are carried over unchanged.
    """
    selected_ids = {f.id for f in result.selected}
    remaining = budget - result.used_tokens

    frag_by_id: dict[FragmentId, Fragment] = {}
    name_to_frags: dict[str, list[Fragment]] = {}
    for f in all_fragments:
        frag_by_id[f.id] = f
        if f.symbol_name:
            name_to_frags.setdefault(f.symbol_name.lower(), []).append(f)

    dangling_names: set[str] = set()
    for frag in result.selected:
        for nbr_id in graph.neighbors(frag.id):
            if nbr_id in selected_ids:
                continue
            if graph.edge_categories.get((frag.id, nbr_id), "") != "semantic":
                continue
            nbr_frag = frag_by_id.get(nbr_id)
            if nbr_frag is not None and nbr_frag.symbol_name:
                dangling_names.add(nbr_frag.symbol_name.lower())

    added: list[Fragment] = []
    # Iterate in sorted order: the iteration order of a set of strings can
    # differ between interpreter runs (hash randomisation), and because the
    # budget is consumed as we go, that would make the added fragments
    # non-reproducible.
    for name in sorted(dangling_names):
        candidates = name_to_frags.get(name, [])
        # Some fragment with this name is already in the selection.
        if any(c.id in selected_ids for c in candidates):
            continue
        # Prefer a signature variant; otherwise fall back to the first
        # candidate in ``all_fragments`` order.
        pick = next((c for c in candidates if "_signature" in c.kind), None)
        if pick is None:
            pick = candidates[0] if candidates else None
        if pick is None or pick.token_count > remaining:
            continue
        added.append(pick)
        selected_ids.add(pick.id)
        remaining -= pick.token_count

    if not added:
        return result

    return SelectionResult(
        selected=result.selected + added,
        reason=result.reason,
        used_tokens=result.used_tokens + sum(f.token_count for f in added),
        utility=result.utility,
    )
350+
351+
def _compute_seed_weights(
    hunks: list[DiffHunk],
    core_ids: set[FragmentId],
    all_fragments: list[Fragment],
) -> dict[FragmentId, float]:
    """Weight each core fragment by the size of the hunks that touch it.

    A hunk touches a fragment when both share the same ``path`` and the
    hunk's ``core_selection_range`` overlaps the fragment's
    ``start_line``..``end_line`` span (inclusive on both ends). Every
    touching hunk contributes its whole line count (at least 1), not only
    the overlapping part.

    Args:
        hunks: Diff hunks to attribute.
        core_ids: Ids of the fragments that may receive weight.
        all_fragments: All fragments; those not in ``core_ids`` are ignored.

    Returns:
        Mapping from fragment id to accumulated weight. Empty when no core
        fragment is touched by any hunk.
    """
    core_fragments = [frag for frag in all_fragments if frag.id in core_ids]

    weights: dict[FragmentId, float] = {}
    for hunk in hunks:
        first, last = hunk.core_selection_range
        span = max(1, last - first + 1)
        for frag in core_fragments:
            if frag.path != hunk.path:
                continue
            # Skip fragments lying entirely before or after the hunk.
            if frag.start_line > last or frag.end_line < first:
                continue
            weights[frag.id] = weights.get(frag.id, 0) + span
    return weights
369+
370+
291371_CONTAINER_FRAGMENT_KINDS = frozenset ({"class" , "interface" , "struct" })
# Fragment kinds for which a short signature-only variant is generated.
_SIGNATURE_ELIGIBLE_KINDS = frozenset({"function", "class", "method", "struct", "interface", "enum"})
# Fragments with fewer lines than this get no signature variant.
_MIN_LINES_FOR_SIGNATURE = 5


def _generate_signature_variants(fragments: list[Fragment], core_ids: set[FragmentId]) -> list[Fragment]:
    """Build short "signature" variants for eligible non-core fragments.

    A variant holds the first two lines of the source fragment's content,
    keeps its ``identifiers`` and ``symbol_name``, and has the kind
    ``"<kind>_signature"``. Fragments are skipped when they are in
    ``core_ids``, have a kind outside ``_SIGNATURE_ELIGIBLE_KINDS``, have
    fewer than ``_MIN_LINES_FOR_SIGNATURE`` lines, or have empty content.

    Args:
        fragments: Fragments to derive variants from; not modified.
        core_ids: Ids of fragments that must not receive a variant.

    Returns:
        The newly created signature fragments, in input order. Their
        ``token_count`` is not set here.
    """
    signatures: list[Fragment] = []
    # Start from the ids already in use so a variant can never share an id
    # with an existing fragment (nor with an earlier variant).
    seen: set[FragmentId] = {frag.id for frag in fragments}
    for frag in fragments:
        if frag.id in core_ids:
            continue
        if frag.kind not in _SIGNATURE_ELIGIBLE_KINDS:
            continue
        if frag.line_count < _MIN_LINES_FOR_SIGNATURE:
            continue
        lines = frag.content.splitlines()
        if not lines:
            # Empty content would give an id whose end line precedes its start.
            continue
        sig_end = min(2, len(lines))
        sig_id = FragmentId(frag.path, frag.start_line, frag.start_line + sig_end - 1)
        if sig_id in seen:
            continue
        seen.add(sig_id)
        signatures.append(
            Fragment(
                id=sig_id,
                kind=f"{frag.kind}_signature",
                content="\n".join(lines[:sig_end]),
                identifiers=frag.identifiers,
                symbol_name=frag.symbol_name,
            )
        )
    return signatures
292403
293404
294405def _identify_core_fragments (hunks : list [DiffHunk ], all_fragments : list [Fragment ]) -> set [FragmentId ]:
0 commit comments