Skip to content

Commit 7b4ff25

Browse files
committed
Vibe coded improvements of paper figures style
1 parent deff70b commit 7b4ff25

1 file changed

Lines changed: 294 additions & 4 deletions

File tree

scripts/render_mined_failure_png.py

Lines changed: 294 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
#!/usr/bin/env python3
22
import argparse
3+
import json
34
import math
45
import os
6+
import re
57
from pathlib import Path
68
from concurrent.futures import ProcessPoolExecutor, as_completed
79

@@ -111,6 +113,74 @@ def fixed_target_bounds(agent_arrays, frame_idx, target_slot, half_width=50.0):
111113
return [tx - half_width, tx + half_width, ty - half_width, ty + half_width]
112114

113115

116+
def collision_partner_slot(agent_arrays, frame_idx, target_slot):
    """Return the slot of the valid agent nearest to the target at ``frame_idx``.

    The partner is picked purely by squared centre-to-centre distance among
    the slots flagged valid in that frame; ties go to the lowest slot index.

    Returns ``None`` when ``target_slot`` is ``None``, when the target itself
    is not valid in that frame, or when no other agent is valid.
    """
    if target_slot is None:
        return None
    valid_row = agent_arrays["valid"][frame_idx]
    if not valid_row[target_slot]:
        return None

    x_row = agent_arrays["x"][frame_idx]
    y_row = agent_arrays["y"][frame_idx]
    target_x = float(x_row[target_slot])
    target_y = float(y_row[target_slot])

    def squared_gap(slot):
        gap_x = float(x_row[slot]) - target_x
        gap_y = float(y_row[slot]) - target_y
        return gap_x * gap_x + gap_y * gap_y

    # (distance, slot) tuples so that equal distances resolve to the lowest slot.
    ranked = (
        (squared_gap(slot), int(slot))
        for slot in np.flatnonzero(valid_row)
        if slot != target_slot
    )
    nearest = min(ranked, default=None)
    return None if nearest is None else nearest[1]
130+
131+
132+
def _trajectory_axis_center(values, collision_value, half_width, margin):
133+
low = float(np.min(values)) - margin
134+
high = float(np.max(values)) + margin
135+
if high - low <= 2 * half_width:
136+
min_center = high - half_width
137+
max_center = low + half_width
138+
return float(np.clip(collision_value, min_center, max_center))
139+
140+
trajectory_center = 0.5 * (low + high)
141+
return float(np.clip(trajectory_center, collision_value - half_width, collision_value + half_width))
142+
143+
144+
def collision_trajectory_bounds(
    agent_arrays,
    frame_start,
    frame_idx,
    target_slot,
    partner_slot,
    half_width=50.0,
    margin=3.0,
):
    """Compute a square crop framing the target/partner tracks and their meeting point.

    The positions of ``target_slot`` and ``partner_slot`` over the valid
    frames in ``[frame_start, frame_idx]`` are gathered, and each axis centre
    is chosen by ``_trajectory_axis_center`` relative to the midpoint between
    the two agents at ``frame_idx``.

    Falls back to ``fixed_target_bounds`` (a crop centred on the target) when
    there is no partner or no valid sample in the frame range.

    Returns ``[min_x, max_x, min_y, max_y]`` with side length ``2 * half_width``.
    """

    def target_centered():
        return fixed_target_bounds(agent_arrays, frame_idx, target_slot, half_width=half_width)

    if partner_slot is None:
        return target_centered()

    window = slice(frame_start, frame_idx + 1)
    track_x, track_y = [], []
    for slot in (target_slot, partner_slot):
        # Indices are relative to the window, so shift them back to absolute frames.
        live_frames = frame_start + np.flatnonzero(agent_arrays["valid"][window, slot])
        track_x += agent_arrays["x"][live_frames, slot].astype(float).tolist()
        track_y += agent_arrays["y"][live_frames, slot].astype(float).tolist()

    if not track_x:
        return target_centered()

    def midpoint(axis):
        row = agent_arrays[axis][frame_idx]
        return 0.5 * (float(row[target_slot]) + float(row[partner_slot]))

    meet_x = midpoint("x")
    meet_y = midpoint("y")
    center_x = _trajectory_axis_center(track_x, meet_x, half_width, margin)
    center_y = _trajectory_axis_center(track_y, meet_y, half_width, margin)
    return [
        center_x - half_width,
        center_x + half_width,
        center_y - half_width,
        center_y + half_width,
    ]
182+
183+
114184
def draw_roads(ax, map_static, bounds):
115185
min_x, max_x, min_y, max_y = bounds
116186
for elem in map_static.get("road_elements", []):
@@ -241,8 +311,19 @@ def render_failure_png(
241311
None,
242312
)
243313
final_idx = agent_arrays["valid"].shape[0] - 1
314+
frame_start = 0
315+
if last_n_frames is not None:
316+
frame_start = max(0, final_idx - int(last_n_frames) + 1)
317+
partner_slot = collision_partner_slot(agent_arrays, final_idx, target_slot)
244318
if target_crop_half_width is not None:
245-
bounds = fixed_target_bounds(agent_arrays, final_idx, target_slot, half_width=float(target_crop_half_width))
319+
bounds = collision_trajectory_bounds(
320+
agent_arrays,
321+
frame_start,
322+
final_idx,
323+
target_slot,
324+
partner_slot,
325+
half_width=float(target_crop_half_width),
326+
)
246327
else:
247328
focus_slots = list(slots.values())
248329
bounds = crop_bounds(agent_arrays, focus_slots)
@@ -258,9 +339,6 @@ def render_failure_png(
258339
ax.set_facecolor("#ffffff")
259340
draw_roads(ax, payload["map"], bounds)
260341

261-
frame_start = 0
262-
if last_n_frames is not None:
263-
frame_start = max(0, final_idx - int(last_n_frames) + 1)
264342
visible_frame_count = final_idx - frame_start + 1
265343
sample_frames = sorted(set(np.linspace(frame_start, final_idx, min(6, visible_frame_count), dtype=int).tolist()))
266344
target_color = TARGET_COLOR
@@ -279,6 +357,19 @@ def render_failure_png(
279357

280358
for frame_idx in sample_frames[:-1]:
281359
draw_vehicle(ax, agent_arrays, frame_idx, target_slot, target_color, alpha=0.18, edge=target_color, zorder=4)
360+
for agent_id, slot_idx in slots.items():
361+
if slot_idx == target_slot:
362+
continue
363+
draw_vehicle(
364+
ax,
365+
agent_arrays,
366+
frame_idx,
367+
slot_idx,
368+
other_vehicle_color,
369+
alpha=0.18,
370+
edge=other_color,
371+
zorder=3,
372+
)
282373

283374
for agent_id, slot_idx in slots.items():
284375
if slot_idx == target_slot:
@@ -410,6 +501,15 @@ def _filename_for_row(row):
410501
return f"episode_{episode_id:06d}_{fault_label}_{responsibility_label}.png"
411502

412503

504+
def _load_html_summary(html_path):
505+
html = Path(html_path).read_text()
506+
match = re.search(r"const DATA = (.*?);\n", html)
507+
if match is None:
508+
raise ValueError(f"Could not find embedded replay data in {html_path}")
509+
payload = json.loads(match.group(1))
510+
return payload.get("summary", {})
511+
512+
413513
def _render_job(job):
414514
replay_path, output_path, title, subtitle, last_n_frames = job
415515
render_failure_png(
@@ -502,6 +602,154 @@ def batch_render(
502602
}
503603

504604

605+
def batch_render_csv(
    csv_path,
    output_name="paper_figures",
    responsibility_threshold=None,
    last_n_frames=50,
    workers=0,
    limit=None,
):
    """Render PNGs for the episodes selected from a per-episode CSV.

    Replays are looked up in the ``replays/`` directory next to the CSV, and
    images are written under ``<csv dir>/<output_name>/``.

    Args:
        csv_path: Path of the per-episode CSV.
        output_name: Name of the output sub-directory beside the CSV.
        responsibility_threshold: When given, keep only rows whose larger of
            ``target_collision_responsibility`` / ``target_hit_responsibility``
            is strictly greater than this value.
        last_n_frames: Forwarded to each render job.
        workers: Number of worker processes; a process pool is used only
            when this is greater than 1.
        limit: Optional maximum number of jobs to queue. ``0`` (or a negative
            value) queues nothing.

    Returns:
        A dict with ``csv_path``, ``selected`` (rows passing the filters),
        ``job_count``, ``rendered``, ``skipped_missing_replay`` and
        ``failed`` (a list of ``(replay_path, error_message)`` tuples).

    Raises:
        FileNotFoundError: If the sibling ``replays/`` directory is missing.
    """
    csv_path = Path(csv_path)
    episodes_df = pd.read_csv(csv_path)
    replay_dir = csv_path.parent / "replays"
    if not replay_dir.is_dir():
        raise FileNotFoundError(f"Replay directory not found: {replay_dir}")
    replay_index = {path.name: path for path in replay_dir.glob("episode_*.replay.zlib")}

    has_replay = _numeric_series(episodes_df, "has_replay") > 0
    responsibility = pd.concat(
        [
            _numeric_series(episodes_df, "target_collision_responsibility"),
            _numeric_series(episodes_df, "target_hit_responsibility"),
        ],
        axis=1,
    ).max(axis=1)
    selected_mask = has_replay
    if responsibility_threshold is not None:
        # Build a new mask rather than `&=`, which would mutate `has_replay` in place.
        selected_mask = has_replay & (responsibility > float(responsibility_threshold))
    selected = episodes_df[selected_mask].copy()

    max_jobs = None if limit is None else int(limit)
    jobs = []
    skipped_missing_replay = 0
    output_dir = csv_path.parent / output_name
    for row in selected.to_dict(orient="records"):
        # Check the cap before queuing so that limit=0 really queues nothing.
        if max_jobs is not None and len(jobs) >= max_jobs:
            break
        episode_id = int(row["episode_id"])
        filename = f"episode_{episode_id:06d}.replay.zlib"
        replay_path = replay_index.get(filename)
        if replay_path is None:
            skipped_missing_replay += 1
            continue
        output_path = output_dir / _filename_for_row(row)
        jobs.append((replay_path, output_path, None, None, last_n_frames))

    rendered = 0
    failed = []
    if workers and int(workers) > 1 and jobs:
        with ProcessPoolExecutor(max_workers=int(workers)) as executor:
            futures = {executor.submit(_render_job, job): job for job in jobs}
            with tqdm(total=len(futures), desc=f"Rendering {csv_path.parent.name}") as pbar:
                for future in as_completed(futures):
                    job = futures[future]
                    # Best-effort batch: record the failure and keep going.
                    try:
                        future.result()
                        rendered += 1
                    except Exception as exc:
                        failed.append((str(job[0]), str(exc)))
                    pbar.update(1)
    else:
        for job in tqdm(jobs, desc=f"Rendering {csv_path.parent.name}"):
            try:
                _render_job(job)
                rendered += 1
            except Exception as exc:
                failed.append((str(job[0]), str(exc)))

    return {
        "csv_path": str(csv_path),
        "selected": int(selected_mask.sum()),
        "job_count": len(jobs),
        "rendered": rendered,
        "skipped_missing_replay": skipped_missing_replay,
        "failed": failed,
    }
678+
679+
680+
def batch_render_directory(
    render_dir,
    output_name="paper_figures_academic_mp8",
    last_n_frames=50,
    workers=0,
    limit=None,
    responsibility_threshold=None,
):
    """Render PNGs for every replay in a directory that has a matching HTML page.

    For each ``replays/episode_*.replay.zlib`` the sibling
    ``<render_dir>/<episode stem>.html`` supplies the episode summary used
    for filtering and for the output file name. Images are written under
    ``<render_dir>/<output_name>/``.

    Args:
        render_dir: Directory holding the episode HTML files and ``replays/``.
        output_name: Name of the output sub-directory inside ``render_dir``.
        last_n_frames: Forwarded to each render job.
        workers: Number of worker processes; a process pool is used only
            when this is greater than 1.
        limit: Optional maximum number of jobs to queue. ``0`` (or a negative
            value) queues nothing.
        responsibility_threshold: When given, skip episodes whose larger of
            ``target_collision_responsibility`` / ``target_hit_responsibility``
            is less than or equal to this value.

    Returns:
        A dict with ``render_dir``, ``job_count``, ``rendered``,
        ``missing_html``, ``invalid_html`` (a list of ``(html_path, error)``
        tuples) and ``failed`` (a list of ``(replay_path, error)`` tuples).

    Raises:
        FileNotFoundError: If ``render_dir/replays`` is not a directory.
    """
    render_dir = Path(render_dir)
    replay_dir = render_dir / "replays"
    if not replay_dir.is_dir():
        raise FileNotFoundError(f"Replay directory not found: {replay_dir}")

    max_jobs = None if limit is None else int(limit)
    jobs = []
    missing_html = 0
    invalid_html = []
    for replay_path in sorted(replay_dir.glob("episode_*.replay.zlib")):
        # Check the cap before queuing so that limit=0 really queues nothing.
        if max_jobs is not None and len(jobs) >= max_jobs:
            break
        episode_stem = replay_path.name.removesuffix(".replay.zlib")
        html_path = render_dir / f"{episode_stem}.html"
        if not html_path.exists():
            missing_html += 1
            continue
        # Parse the episode id inside the same guard as the HTML load: one
        # page with a null or non-numeric id must not abort the whole batch.
        try:
            summary = _load_html_summary(html_path)
            raw_episode_id = summary.get("episode_id")
            if raw_episode_id is None:
                # The replay file name carries the id as well.
                raw_episode_id = episode_stem.removeprefix("episode_")
            episode_id = int(raw_episode_id)
        except Exception as exc:
            invalid_html.append((str(html_path), str(exc)))
            continue

        summary["episode_id"] = episode_id
        responsibility = max(
            _safe_float(summary.get("target_collision_responsibility")),
            _safe_float(summary.get("target_hit_responsibility")),
        )
        if responsibility_threshold is not None and responsibility <= float(responsibility_threshold):
            continue
        output_path = render_dir / output_name / _filename_for_row(summary)
        jobs.append((replay_path, output_path, None, None, last_n_frames))

    rendered = 0
    failed = []
    if workers and int(workers) > 1 and jobs:
        with ProcessPoolExecutor(max_workers=int(workers)) as executor:
            futures = {executor.submit(_render_job, job): job for job in jobs}
            with tqdm(total=len(futures), desc=f"Rendering {render_dir.name}") as pbar:
                for future in as_completed(futures):
                    job = futures[future]
                    # Best-effort batch: record the failure and keep going.
                    try:
                        future.result()
                        rendered += 1
                    except Exception as exc:
                        failed.append((str(job[0]), str(exc)))
                    pbar.update(1)
    else:
        for job in tqdm(jobs, desc=f"Rendering {render_dir.name}"):
            try:
                _render_job(job)
                rendered += 1
            except Exception as exc:
                failed.append((str(job[0]), str(exc)))

    return {
        "render_dir": str(render_dir),
        "job_count": len(jobs),
        "rendered": rendered,
        "missing_html": missing_html,
        "invalid_html": invalid_html,
        "failed": failed,
    }
751+
752+
505753
def main():
506754
parser = argparse.ArgumentParser()
507755
parser.add_argument("replay_path", nargs="?")
@@ -515,6 +763,18 @@ def main():
515763
parser.add_argument("--responsibility-threshold", type=float, default=0.2)
516764
parser.add_argument("--workers", type=int, default=0, help="Parallel workers for batch rendering")
517765
parser.add_argument("--limit", type=int, default=None, help="Optional max number of batch jobs for smoke tests")
766+
parser.add_argument(
767+
"--batch-render-dir",
768+
action="append",
769+
default=[],
770+
help="Render a directory containing episode HTML files and a replays/ subdirectory",
771+
)
772+
parser.add_argument(
773+
"--batch-csv",
774+
action="append",
775+
default=[],
776+
help="Render episodes selected from an explicit per-episode CSV and sibling replays/ directory",
777+
)
518778
args = parser.parse_args()
519779
if args.batch_failure_runs:
520780
summary = batch_render(
@@ -528,6 +788,36 @@ def main():
528788
print(summary)
529789
return
530790

791+
if args.batch_render_dir:
792+
summaries = [
793+
batch_render_directory(
794+
render_dir,
795+
output_name=args.batch_output_name,
796+
last_n_frames=args.last_n_frames or 50,
797+
workers=args.workers,
798+
limit=args.limit,
799+
responsibility_threshold=args.responsibility_threshold,
800+
)
801+
for render_dir in args.batch_render_dir
802+
]
803+
print(summaries)
804+
return
805+
806+
if args.batch_csv:
807+
summaries = [
808+
batch_render_csv(
809+
csv_path,
810+
output_name=args.batch_output_name,
811+
responsibility_threshold=args.responsibility_threshold,
812+
last_n_frames=args.last_n_frames or 50,
813+
workers=args.workers,
814+
limit=args.limit,
815+
)
816+
for csv_path in args.batch_csv
817+
]
818+
print(summaries)
819+
return
820+
531821
if not args.replay_path or not args.output_path:
532822
parser.error("replay_path and output_path are required unless --batch-failure-runs is used")
533823
render_failure_png(

0 commit comments

Comments
 (0)