11#!/usr/bin/env python3
22import argparse
3+ import json
34import math
45import os
6+ import re
57from pathlib import Path
68from concurrent .futures import ProcessPoolExecutor , as_completed
79
@@ -111,6 +113,74 @@ def fixed_target_bounds(agent_arrays, frame_idx, target_slot, half_width=50.0):
111113 return [tx - half_width , tx + half_width , ty - half_width , ty + half_width ]
112114
113115
116+ def collision_partner_slot (agent_arrays , frame_idx , target_slot ):
117+ if target_slot is None or not agent_arrays ["valid" ][frame_idx , target_slot ]:
118+ return None
119+
120+ tx = float (agent_arrays ["x" ][frame_idx , target_slot ])
121+ ty = float (agent_arrays ["y" ][frame_idx , target_slot ])
122+ candidates = []
123+ for slot_idx in np .flatnonzero (agent_arrays ["valid" ][frame_idx ]):
124+ if slot_idx == target_slot :
125+ continue
126+ dx = float (agent_arrays ["x" ][frame_idx , slot_idx ]) - tx
127+ dy = float (agent_arrays ["y" ][frame_idx , slot_idx ]) - ty
128+ candidates .append ((dx * dx + dy * dy , int (slot_idx )))
129+ return min (candidates )[1 ] if candidates else None
130+
131+
132+ def _trajectory_axis_center (values , collision_value , half_width , margin ):
133+ low = float (np .min (values )) - margin
134+ high = float (np .max (values )) + margin
135+ if high - low <= 2 * half_width :
136+ min_center = high - half_width
137+ max_center = low + half_width
138+ return float (np .clip (collision_value , min_center , max_center ))
139+
140+ trajectory_center = 0.5 * (low + high )
141+ return float (np .clip (trajectory_center , collision_value - half_width , collision_value + half_width ))
142+
143+
144+ def collision_trajectory_bounds (
145+ agent_arrays ,
146+ frame_start ,
147+ frame_idx ,
148+ target_slot ,
149+ partner_slot ,
150+ half_width = 50.0 ,
151+ margin = 3.0 ,
152+ ):
153+ if partner_slot is None :
154+ return fixed_target_bounds (agent_arrays , frame_idx , target_slot , half_width = half_width )
155+
156+ focus_slots = [target_slot , partner_slot ]
157+ xs = []
158+ ys = []
159+ for slot_idx in focus_slots :
160+ valid = agent_arrays ["valid" ][frame_start : frame_idx + 1 , slot_idx ]
161+ frames = np .flatnonzero (valid ) + frame_start
162+ xs .extend (agent_arrays ["x" ][frames , slot_idx ].astype (float ).tolist ())
163+ ys .extend (agent_arrays ["y" ][frames , slot_idx ].astype (float ).tolist ())
164+
165+ if not xs :
166+ return fixed_target_bounds (agent_arrays , frame_idx , target_slot , half_width = half_width )
167+
168+ collision_x = 0.5 * (
169+ float (agent_arrays ["x" ][frame_idx , target_slot ]) + float (agent_arrays ["x" ][frame_idx , partner_slot ])
170+ )
171+ collision_y = 0.5 * (
172+ float (agent_arrays ["y" ][frame_idx , target_slot ]) + float (agent_arrays ["y" ][frame_idx , partner_slot ])
173+ )
174+ center_x = _trajectory_axis_center (xs , collision_x , half_width , margin )
175+ center_y = _trajectory_axis_center (ys , collision_y , half_width , margin )
176+ return [
177+ center_x - half_width ,
178+ center_x + half_width ,
179+ center_y - half_width ,
180+ center_y + half_width ,
181+ ]
182+
183+
114184def draw_roads (ax , map_static , bounds ):
115185 min_x , max_x , min_y , max_y = bounds
116186 for elem in map_static .get ("road_elements" , []):
@@ -241,8 +311,19 @@ def render_failure_png(
241311 None ,
242312 )
243313 final_idx = agent_arrays ["valid" ].shape [0 ] - 1
314+ frame_start = 0
315+ if last_n_frames is not None :
316+ frame_start = max (0 , final_idx - int (last_n_frames ) + 1 )
317+ partner_slot = collision_partner_slot (agent_arrays , final_idx , target_slot )
244318 if target_crop_half_width is not None :
245- bounds = fixed_target_bounds (agent_arrays , final_idx , target_slot , half_width = float (target_crop_half_width ))
319+ bounds = collision_trajectory_bounds (
320+ agent_arrays ,
321+ frame_start ,
322+ final_idx ,
323+ target_slot ,
324+ partner_slot ,
325+ half_width = float (target_crop_half_width ),
326+ )
246327 else :
247328 focus_slots = list (slots .values ())
248329 bounds = crop_bounds (agent_arrays , focus_slots )
@@ -258,9 +339,6 @@ def render_failure_png(
258339 ax .set_facecolor ("#ffffff" )
259340 draw_roads (ax , payload ["map" ], bounds )
260341
261- frame_start = 0
262- if last_n_frames is not None :
263- frame_start = max (0 , final_idx - int (last_n_frames ) + 1 )
264342 visible_frame_count = final_idx - frame_start + 1
265343 sample_frames = sorted (set (np .linspace (frame_start , final_idx , min (6 , visible_frame_count ), dtype = int ).tolist ()))
266344 target_color = TARGET_COLOR
@@ -279,6 +357,19 @@ def render_failure_png(
279357
280358 for frame_idx in sample_frames [:- 1 ]:
281359 draw_vehicle (ax , agent_arrays , frame_idx , target_slot , target_color , alpha = 0.18 , edge = target_color , zorder = 4 )
360+ for agent_id , slot_idx in slots .items ():
361+ if slot_idx == target_slot :
362+ continue
363+ draw_vehicle (
364+ ax ,
365+ agent_arrays ,
366+ frame_idx ,
367+ slot_idx ,
368+ other_vehicle_color ,
369+ alpha = 0.18 ,
370+ edge = other_color ,
371+ zorder = 3 ,
372+ )
282373
283374 for agent_id , slot_idx in slots .items ():
284375 if slot_idx == target_slot :
@@ -410,6 +501,15 @@ def _filename_for_row(row):
410501 return f"episode_{ episode_id :06d} _{ fault_label } _{ responsibility_label } .png"
411502
412503
504+ def _load_html_summary (html_path ):
505+ html = Path (html_path ).read_text ()
506+ match = re .search (r"const DATA = (.*?);\n" , html )
507+ if match is None :
508+ raise ValueError (f"Could not find embedded replay data in { html_path } " )
509+ payload = json .loads (match .group (1 ))
510+ return payload .get ("summary" , {})
511+
512+
413513def _render_job (job ):
414514 replay_path , output_path , title , subtitle , last_n_frames = job
415515 render_failure_png (
@@ -502,6 +602,154 @@ def batch_render(
502602 }
503603
504604
605+ def batch_render_csv (
606+ csv_path ,
607+ output_name = "paper_figures" ,
608+ responsibility_threshold = None ,
609+ last_n_frames = 50 ,
610+ workers = 0 ,
611+ limit = None ,
612+ ):
613+ csv_path = Path (csv_path )
614+ episodes_df = pd .read_csv (csv_path )
615+ replay_dir = csv_path .parent / "replays"
616+ if not replay_dir .is_dir ():
617+ raise FileNotFoundError (f"Replay directory not found: { replay_dir } " )
618+ replay_index = {path .name : path for path in replay_dir .glob ("episode_*.replay.zlib" )}
619+
620+ has_replay = _numeric_series (episodes_df , "has_replay" ) > 0
621+ responsibility = pd .concat (
622+ [
623+ _numeric_series (episodes_df , "target_collision_responsibility" ),
624+ _numeric_series (episodes_df , "target_hit_responsibility" ),
625+ ],
626+ axis = 1 ,
627+ ).max (axis = 1 )
628+ selected_mask = has_replay
629+ if responsibility_threshold is not None :
630+ selected_mask &= responsibility > float (responsibility_threshold )
631+ selected = episodes_df [selected_mask ].copy ()
632+
633+ jobs = []
634+ skipped_missing_replay = 0
635+ output_dir = csv_path .parent / output_name
636+ for row in selected .to_dict (orient = "records" ):
637+ episode_id = int (row ["episode_id" ])
638+ filename = f"episode_{ episode_id :06d} .replay.zlib"
639+ replay_path = replay_index .get (filename )
640+ if replay_path is None :
641+ skipped_missing_replay += 1
642+ continue
643+ output_path = output_dir / _filename_for_row (row )
644+ jobs .append ((replay_path , output_path , None , None , last_n_frames ))
645+ if limit is not None and len (jobs ) >= int (limit ):
646+ break
647+
648+ rendered = 0
649+ failed = []
650+ if workers and int (workers ) > 1 and jobs :
651+ with ProcessPoolExecutor (max_workers = int (workers )) as executor :
652+ futures = {executor .submit (_render_job , job ): job for job in jobs }
653+ with tqdm (total = len (futures ), desc = f"Rendering { csv_path .parent .name } " ) as pbar :
654+ for future in as_completed (futures ):
655+ job = futures [future ]
656+ try :
657+ future .result ()
658+ rendered += 1
659+ except Exception as exc :
660+ failed .append ((str (job [0 ]), str (exc )))
661+ pbar .update (1 )
662+ else :
663+ for job in tqdm (jobs , desc = f"Rendering { csv_path .parent .name } " ):
664+ try :
665+ _render_job (job )
666+ rendered += 1
667+ except Exception as exc :
668+ failed .append ((str (job [0 ]), str (exc )))
669+
670+ return {
671+ "csv_path" : str (csv_path ),
672+ "selected" : int (selected_mask .sum ()),
673+ "job_count" : len (jobs ),
674+ "rendered" : rendered ,
675+ "skipped_missing_replay" : skipped_missing_replay ,
676+ "failed" : failed ,
677+ }
678+
679+
680+ def batch_render_directory (
681+ render_dir ,
682+ output_name = "paper_figures_academic_mp8" ,
683+ last_n_frames = 50 ,
684+ workers = 0 ,
685+ limit = None ,
686+ responsibility_threshold = None ,
687+ ):
688+ render_dir = Path (render_dir )
689+ replay_dir = render_dir / "replays"
690+ if not replay_dir .is_dir ():
691+ raise FileNotFoundError (f"Replay directory not found: { replay_dir } " )
692+
693+ jobs = []
694+ missing_html = 0
695+ invalid_html = []
696+ for replay_path in sorted (replay_dir .glob ("episode_*.replay.zlib" )):
697+ episode_stem = replay_path .name .removesuffix (".replay.zlib" )
698+ html_path = render_dir / f"{ episode_stem } .html"
699+ if not html_path .exists ():
700+ missing_html += 1
701+ continue
702+ try :
703+ summary = _load_html_summary (html_path )
704+ except Exception as exc :
705+ invalid_html .append ((str (html_path ), str (exc )))
706+ continue
707+
708+ episode_id = int (summary .get ("episode_id" , episode_stem .removeprefix ("episode_" )))
709+ summary ["episode_id" ] = episode_id
710+ responsibility = max (
711+ _safe_float (summary .get ("target_collision_responsibility" )),
712+ _safe_float (summary .get ("target_hit_responsibility" )),
713+ )
714+ if responsibility_threshold is not None and responsibility <= float (responsibility_threshold ):
715+ continue
716+ output_path = render_dir / output_name / _filename_for_row (summary )
717+ jobs .append ((replay_path , output_path , None , None , last_n_frames ))
718+ if limit is not None and len (jobs ) >= int (limit ):
719+ break
720+
721+ rendered = 0
722+ failed = []
723+ if workers and int (workers ) > 1 and jobs :
724+ with ProcessPoolExecutor (max_workers = int (workers )) as executor :
725+ futures = {executor .submit (_render_job , job ): job for job in jobs }
726+ with tqdm (total = len (futures ), desc = f"Rendering { render_dir .name } " ) as pbar :
727+ for future in as_completed (futures ):
728+ job = futures [future ]
729+ try :
730+ future .result ()
731+ rendered += 1
732+ except Exception as exc :
733+ failed .append ((str (job [0 ]), str (exc )))
734+ pbar .update (1 )
735+ else :
736+ for job in tqdm (jobs , desc = f"Rendering { render_dir .name } " ):
737+ try :
738+ _render_job (job )
739+ rendered += 1
740+ except Exception as exc :
741+ failed .append ((str (job [0 ]), str (exc )))
742+
743+ return {
744+ "render_dir" : str (render_dir ),
745+ "job_count" : len (jobs ),
746+ "rendered" : rendered ,
747+ "missing_html" : missing_html ,
748+ "invalid_html" : invalid_html ,
749+ "failed" : failed ,
750+ }
751+
752+
505753def main ():
506754 parser = argparse .ArgumentParser ()
507755 parser .add_argument ("replay_path" , nargs = "?" )
@@ -515,6 +763,18 @@ def main():
515763 parser .add_argument ("--responsibility-threshold" , type = float , default = 0.2 )
516764 parser .add_argument ("--workers" , type = int , default = 0 , help = "Parallel workers for batch rendering" )
517765 parser .add_argument ("--limit" , type = int , default = None , help = "Optional max number of batch jobs for smoke tests" )
766+ parser .add_argument (
767+ "--batch-render-dir" ,
768+ action = "append" ,
769+ default = [],
770+ help = "Render a directory containing episode HTML files and a replays/ subdirectory" ,
771+ )
772+ parser .add_argument (
773+ "--batch-csv" ,
774+ action = "append" ,
775+ default = [],
776+ help = "Render episodes selected from an explicit per-episode CSV and sibling replays/ directory" ,
777+ )
518778 args = parser .parse_args ()
519779 if args .batch_failure_runs :
520780 summary = batch_render (
@@ -528,6 +788,36 @@ def main():
528788 print (summary )
529789 return
530790
791+ if args .batch_render_dir :
792+ summaries = [
793+ batch_render_directory (
794+ render_dir ,
795+ output_name = args .batch_output_name ,
796+ last_n_frames = args .last_n_frames or 50 ,
797+ workers = args .workers ,
798+ limit = args .limit ,
799+ responsibility_threshold = args .responsibility_threshold ,
800+ )
801+ for render_dir in args .batch_render_dir
802+ ]
803+ print (summaries )
804+ return
805+
806+ if args .batch_csv :
807+ summaries = [
808+ batch_render_csv (
809+ csv_path ,
810+ output_name = args .batch_output_name ,
811+ responsibility_threshold = args .responsibility_threshold ,
812+ last_n_frames = args .last_n_frames or 50 ,
813+ workers = args .workers ,
814+ limit = args .limit ,
815+ )
816+ for csv_path in args .batch_csv
817+ ]
818+ print (summaries )
819+ return
820+
531821 if not args .replay_path or not args .output_path :
532822 parser .error ("replay_path and output_path are required unless --batch-failure-runs is used" )
533823 render_failure_png (
0 commit comments