44import heapq
55import logging
66import math
7+ import os
78import statistics
89from dataclasses import dataclass , field
910from pathlib import Path
@@ -230,6 +231,90 @@ def _compute_r_cap(rel: dict[FragmentId, float]) -> float:
230231 return max (med + UTILITY .r_cap_sigma * std , 1e-9 )
231232
232233
234+ def _collect_greedy_densities (
235+ candidates : list [Fragment ],
236+ rel : dict [FragmentId , float ],
237+ needs : tuple [InformationNeed , ...],
238+ utility_state : UtilityState ,
239+ ) -> list [tuple [str , int , int , float , float , float ]]:
240+ result : list [tuple [str , int , int , float , float , float ]] = []
241+ for frag in candidates :
242+ if frag .token_count > 0 :
243+ density = compute_density (frag , rel .get (frag .id , 0.0 ), needs , utility_state )
244+ gain = marginal_gain (frag , rel .get (frag .id , 0.0 ), needs , utility_state )
245+ result .append ((str (frag .path ), frag .start_line , frag .token_count , rel .get (frag .id , 0.0 ), gain , density ))
246+ return result
247+
248+
249+ def _write_greedy_dump (
250+ path : str ,
251+ tau : float ,
252+ threshold : float ,
253+ baseline_k : int ,
254+ n_candidates : int ,
255+ n_selected : int ,
256+ remaining_budget : int ,
257+ densities : list [tuple [str , int , int , float , float , float ]],
258+ ) -> None :
259+ import json as _json
260+
261+ with open (path , "w" ) as f :
262+ f .write (
263+ _json .dumps (
264+ {
265+ "tau" : tau ,
266+ "threshold" : threshold ,
267+ "baseline_k" : baseline_k ,
268+ "n_candidates" : n_candidates ,
269+ "n_selected_noncore" : n_selected ,
270+ "remaining_budget" : remaining_budget ,
271+ }
272+ )
273+ + "\n "
274+ )
275+ for fpath , start , tokens , ppr , gain , density in sorted (densities , key = lambda x : - x [5 ]):
276+ f .write (
277+ _json .dumps (
278+ {
279+ "path" : fpath ,
280+ "start" : start ,
281+ "tokens" : tokens ,
282+ "ppr" : round (ppr , 6 ),
283+ "gain" : round (gain , 4 ),
284+ "density" : round (density , 6 ),
285+ }
286+ )
287+ + "\n "
288+ )
289+
290+
291+ def _build_signature_lookup (fragments : list [Fragment ], core_fragments : list [Fragment ]) -> dict [FragmentId , Fragment ]:
292+ sig_by_loc : dict [tuple [Path , int ], Fragment ] = {}
293+ for f in fragments :
294+ if "_signature" in f .kind :
295+ sig_by_loc [(f .path , f .start_line )] = f
296+ sig_lookup : dict [FragmentId , Fragment ] = {}
297+ for cf in core_fragments :
298+ key = (cf .path , cf .start_line )
299+ if key in sig_by_loc :
300+ sig_lookup [cf .id ] = sig_by_loc [key ]
301+ return sig_lookup
302+
303+
304+ def _init_selection_state (
305+ core_ids : set [FragmentId ],
306+ rel : dict [FragmentId , float ],
307+ budget_tokens : int ,
308+ file_importance : dict [Path , float ] | None ,
309+ ) -> _SelectionState :
310+ state = _SelectionState (remaining_budget = budget_tokens )
311+ state .utility_state .r_cap = _compute_r_cap (rel )
312+ state .utility_state .changed_dirs = frozenset (cid .path .parent for cid in core_ids )
313+ if file_importance is not None :
314+ state .utility_state .file_importance = file_importance
315+ return state
316+
317+
233318def lazy_greedy_select (
234319 fragments : list [Fragment ],
235320 core_ids : set [FragmentId ],
@@ -251,21 +336,8 @@ def lazy_greedy_select(
251336 core_fragments .sort (key = lambda f : (f .token_count if f .token_count > 0 else 10 ** 9 , f .line_count , f .start_line ))
252337 non_core_fragments = [f for f in fragments if f .id not in core_ids ]
253338
254- sig_by_loc : dict [tuple [Path , int ], Fragment ] = {}
255- for f in fragments :
256- if "_signature" in f .kind :
257- sig_by_loc [(f .path , f .start_line )] = f
258- sig_lookup : dict [FragmentId , Fragment ] = {}
259- for cf in core_fragments :
260- key = (cf .path , cf .start_line )
261- if key in sig_by_loc :
262- sig_lookup [cf .id ] = sig_by_loc [key ]
263-
264- state = _SelectionState (remaining_budget = budget_tokens )
265- state .utility_state .r_cap = _compute_r_cap (rel )
266- state .utility_state .changed_dirs = frozenset (cid .path .parent for cid in core_ids )
267- if file_importance is not None :
268- state .utility_state .file_importance = file_importance
339+ sig_lookup = _build_signature_lookup (fragments , core_fragments )
340+ state = _init_selection_state (core_ids , rel , budget_tokens , file_importance )
269341 _select_core_fragments (core_fragments , rel , needs , state , budget_tokens , sig_lookup )
270342
271343 if state .remaining_budget <= 0 :
@@ -291,8 +363,23 @@ def lazy_greedy_select(
291363 id_to_frag : dict [FragmentId , Fragment ] = {}
292364 heap = _build_initial_heap (candidates , rel , needs , state .utility_state , id_to_frag )
293365
366+ dump_greedy = os .environ .get ("DIFFCTX_DUMP_GREEDY" )
367+ pre_greedy_densities = _collect_greedy_densities (candidates , rel , needs , state .utility_state ) if dump_greedy else None
368+
294369 selections_for_baseline , threshold = _run_greedy_loop_heap (heap , id_to_frag , state , rel , needs , tau , baseline_k )
295370
371+ if dump_greedy and pre_greedy_densities is not None :
372+ _write_greedy_dump (
373+ dump_greedy ,
374+ tau ,
375+ threshold ,
376+ baseline_k ,
377+ len (candidates ),
378+ selections_for_baseline ,
379+ state .remaining_budget ,
380+ pre_greedy_densities ,
381+ )
382+
296383 greedy_utility = utility_value (state .utility_state )
297384 base_selected_ids = _IntervalIndex ()
298385 for f in base_selected :
0 commit comments