@@ -219,6 +219,22 @@ def _print_nontrivial_report(
219219 print (f" HIT: { f } max_ppr={ max_score :.6f} " )
220220
221221
def _classify_nontrivial_stages(
    nontrivial: set,
    selected: set,
    sel_dump: set,
    fragmented: set,
    universe: set,
) -> dict[str, str]:
    """Label every nontrivial file with the stage it reached.

    Args:
        nontrivial: Files to label; each one becomes a key of the result.
        selected: Files that were selected; members get the label "selected".
        sel_dump: Forwarded unchanged to ``_classify_failure_stage``.
        fragmented: Forwarded unchanged to ``_classify_failure_stage``.
        universe: Forwarded unchanged to ``_classify_failure_stage``.

    Returns:
        A mapping from each file in ``nontrivial`` to "selected" when the
        file is in ``selected``, otherwise to whatever label
        ``_classify_failure_stage`` returns for it.
    """
    return {
        path: (
            "selected"
            if path in selected
            # Only files that missed selection are sent to the failure classifier.
            else _classify_failure_stage(path, sel_dump, fragmented, universe)
        )
        for path in nontrivial
    }
236+
237+
222238def evaluate_one (inst : dict , budget : int ) -> dict :
223239 iid = inst ["instance_id" ]
224240 print ("\n " + "=" * 78 )
@@ -293,12 +309,7 @@ def evaluate_one(inst: dict, budget: int) -> dict:
293309 nt_recall = len (nontrivial_hits ) / len (nontrivial ) if nontrivial else 0.0
294310 patch_coverage = len (p_set & selected ) / len (p_set ) if p_set else 0.0
295311
296- stage_per_file : dict [str , str ] = {}
297- for f in nontrivial :
298- if f in selected :
299- stage_per_file [f ] = "selected"
300- else :
301- stage_per_file [f ] = _classify_failure_stage (f , sel_dump , fragmented , universe )
312+ stage_per_file = _classify_nontrivial_stages (nontrivial , selected , sel_dump , fragmented , universe )
302313
303314 return {
304315 "id" : iid ,
@@ -328,6 +339,40 @@ def _print_threshold_sanity_check():
328339 print (f"diffctx _LOW_RELEVANCE_THRESHOLD = { v } " , file = sys .stderr )
329340
330341
def _filter_nontrivial_instances(insts: list) -> list:
    """Keep the instances whose gold context reaches beyond the patch.

    An instance is kept when at least one normalized gold file path is
    absent from the union of the patch's added, deleted and modified files.

    Args:
        insts: Instances indexable by "gold_context" and "patch".
            "gold_context" is either a JSON string or an already-decoded
            iterable of entries that each have a "file" key.

    Returns:
        A new list with the kept instances, in their original order.
    """

    def _has_gold_outside_patch(inst) -> bool:
        raw_gold = inst["gold_context"]
        # The gold context may arrive JSON-encoded or already decoded.
        gold_entries = json.loads(raw_gold) if isinstance(raw_gold, str) else raw_gold
        gold_files = {normalize_gold_path(entry["file"]) for entry in gold_entries}
        # NOTE(review): the three results are combined with `|` and subtracted
        # from a set, so they are presumably sets — confirm in patch_files_detailed.
        added, deleted, modified = patch_files_detailed(inst["patch"])
        patch_files = added | deleted | modified
        return bool(gold_files - patch_files)

    return [inst for inst in insts if _has_gold_outside_patch(inst)]
351+
352+
def _print_ok_summary(ok: list[dict]) -> None:
    """Print averaged metrics and a stage-wise breakdown for successful runs.

    Args:
        ok: Result dicts of successful instances. Each must provide
            "patch_coverage", "file_recall", "nt_recall" and
            "n_deleted_in_patch"; "stage_per_file" (file -> stage label)
            is optional and treated as empty when missing.

    Returns:
        None. All output goes to stdout. Nothing is printed when ``ok`` is
        empty, because the averages would otherwise divide by zero.
    """
    if not ok:
        # No successful instance: there is nothing to average or break down.
        return

    n_ok = len(ok)
    print(f"\n Avg patch_coverage: {sum(r['patch_coverage'] for r in ok) / n_ok:.3f} ")
    print(f"Avg file_recall: {sum(r['file_recall'] for r in ok) / n_ok:.3f} ")
    print(f"Avg nontrivial: {sum(r['nt_recall'] for r in ok) / n_ok:.3f} ")
    total_deleted = sum(r["n_deleted_in_patch"] for r in ok)
    print(f"Total deleted files across all instances: {total_deleted} ")

    # Tally how many nontrivial gold files ended up in each stage.
    stages: dict[str, int] = {}
    for r in ok:
        for stage in r.get("stage_per_file", {}).values():
            stages[stage] = stages.get(stage, 0) + 1
    total_nt = sum(stages.values())

    if total_nt:
        print(f"\n Stage-wise breakdown ({total_nt} nontrivial gold files):")
        # Most frequent stage first; the sort is stable, so ties keep
        # first-seen order.
        for stage, count in sorted(stages.items(), key=lambda item: -item[1]):
            pct = 100 * count / total_nt
            print(f" {stage:50s} : {count:4d} ({pct:5.1f} %)")

    print("\n If patch_coverage < 0.95: BUG — diffctx is losing files from its own diff input.")
    print("If patch_coverage > 0.95: not a patch-loss bug, look elsewhere.")
374+
375+
331376def main ():
332377 _print_threshold_sanity_check ()
333378
@@ -342,14 +387,7 @@ def main():
342387 ds = load_dataset ("Contextbench/ContextBench" , "contextbench_verified" , split = "train" )
343388 insts = list (ds )
344389 if args .nontrivial_only :
345- kept = []
346- for i in insts :
347- gb = json .loads (i ["gold_context" ]) if isinstance (i ["gold_context" ], str ) else i ["gold_context" ]
348- gold = {normalize_gold_path (g ["file" ]) for g in gb }
349- added , deleted , modified = patch_files_detailed (i ["patch" ])
350- if gold - (added | deleted | modified ):
351- kept .append (i )
352- insts = kept
390+ insts = _filter_nontrivial_instances (insts )
353391 insts = insts [: args .limit ]
354392
355393 print (f"Diagnosing { len (insts )} nontrivial instances at budget={ args .budget } \n " )
@@ -370,26 +408,7 @@ def main():
370408 for r in fail :
371409 print (f" FAIL [{ r ['status' ]} ]: { r ['id' ]} " )
372410 if ok :
373- print (f"\n Avg patch_coverage: { sum (r ['patch_coverage' ] for r in ok )/ len (ok ):.3f} " )
374- print (f"Avg file_recall: { sum (r ['file_recall' ] for r in ok )/ len (ok ):.3f} " )
375- print (f"Avg nontrivial: { sum (r ['nt_recall' ] for r in ok )/ len (ok ):.3f} " )
376- total_deleted = sum (r ["n_deleted_in_patch" ] for r in ok )
377- print (f"Total deleted files across all instances: { total_deleted } " )
378-
379- stages : dict [str , int ] = {}
380- total_nt = 0
381- for r in ok :
382- for _f , stage in r .get ("stage_per_file" , {}).items ():
383- stages [stage ] = stages .get (stage , 0 ) + 1
384- total_nt += 1
385- if total_nt :
386- print (f"\n Stage-wise breakdown ({ total_nt } nontrivial gold files):" )
387- for stage in sorted (stages , key = lambda s : - stages [s ]):
388- pct = 100 * stages [stage ] / total_nt
389- print (f" { stage :50s} : { stages [stage ]:4d} ({ pct :5.1f} %)" )
390-
391- print ("\n If patch_coverage < 0.95: BUG — diffctx is losing files from its own diff input." )
392- print ("If patch_coverage > 0.95: not a patch-loss bug, look elsewhere." )
411+ _print_ok_summary (ok )
393412
394413
395414if __name__ == "__main__" :
0 commit comments