diff --git a/examples/experimental/guidance_density.py b/examples/experimental/guidance_density.py
new file mode 100644
index 000000000..7a18680a8
--- /dev/null
+++ b/examples/experimental/guidance_density.py
@@ -0,0 +1,350 @@
+import torch
+import dataclasses
+import os
+import sys
+import mediapy
+import logging
+import numpy as np
+from time import perf_counter
+from tqdm import tqdm
+from pathlib import Path
+from box import Box
+import yaml
+from PIL import Image
+
+from gpudrive.env.config import EnvConfig
+from gpudrive.env.env_torch import GPUDriveTorchEnv
+from gpudrive.env.dataset import SceneDataLoader
+from gpudrive.datatypes.observation import GlobalEgoState
+from gpudrive.datatypes.metadata import Metadata
+from gpudrive.datatypes.info import Info
+from gpudrive.utils.checkpoint import load_agent
+from gpudrive.visualize.utils import img_from_fig
+from gpudrive.datatypes.trajectory import to_local_frame
+import madrona_gpudrive
+
+# Env Settings
+MAX_AGENTS = (
+ madrona_gpudrive.kMaxAgentCount
+) # TODO: Set to 128 for real eval
+NUM_ENVS = 1
+DEVICE = "cuda" # where to run the env rollouts
+INIT_STEPS = 10
+DATASET_SIZE = 20
+RENDER = True
+LOG_DIR = "examples/eval/figures_data/wosac/"
+GUIDANCE_MODE = (
+ "log_replay" # Options: "vbd_amortized", "vbd_online", "log_replay"
+)
+GUIDANCE_DROPOUT_MODE = "avg" # Options: "max", "avg", "remove_all"
+GUIDANCE_DROPOUT_PROB_RANGE = np.arange(0.0, 1.1, 0.33)
+SMOOTHEN_TRAJECTORY = True
+
+DATA_PATH = "data/processed/wosac/validation_interactive/json"
+
+CPT_PATH = "checkpoints/model_guidance_logs__R_10000__06_07_13_55_31_201_013500.pt"
+
+# Load agent
+agent = load_agent(path_to_cpt=CPT_PATH).to(DEVICE)
+
+config = agent.config
+
+# Create data loader
+val_loader = SceneDataLoader(
+ root=DATA_PATH,
+ batch_size=NUM_ENVS,
+ dataset_size=DATASET_SIZE,
+ sample_with_replacement=False,
+ shuffle=True,
+ file_prefix="",
+ seed=10,
+)
+
+# Override default environment settings to match those the agent was trained with
+env_config = EnvConfig(
+ ego_state=config.ego_state,
+ road_map_obs=config.road_map_obs,
+ partner_obs=config.partner_obs,
+ reward_type=config.reward_type,
+ guidance_speed_weight=config.guidance_speed_weight,
+ guidance_heading_weight=config.guidance_heading_weight,
+ smoothness_weight=config.smoothness_weight,
+ norm_obs=config.norm_obs,
+ add_previous_action=config.add_previous_action,
+ guidance=config.guidance,
+ add_reference_pos_xy=config.add_reference_pos_xy,
+ add_reference_speed=config.add_reference_speed,
+ add_reference_heading=config.add_reference_heading,
+ dynamics_model=config.dynamics_model,
+ collision_behavior=config.collision_behavior,
+ goal_behavior=config.goal_behavior,
+ polyline_reduction_threshold=config.polyline_reduction_threshold,
+ remove_non_vehicles=config.remove_non_vehicles,
+ lidar_obs=False,
+ obs_radius=config.obs_radius,
+ action_space_steer_disc=config.action_space_steer_disc,
+ action_space_accel_disc=config.action_space_accel_disc,
+ init_mode="wosac_eval",
+ init_steps=INIT_STEPS,
+ guidance_mode=GUIDANCE_MODE,
+ guidance_dropout_prob=GUIDANCE_DROPOUT_PROB_RANGE[0], # Set to 0 for the first run
+ guidance_dropout_mode=GUIDANCE_DROPOUT_MODE,
+ smoothen_trajectory=SMOOTHEN_TRAJECTORY,
+)
+
+# Make environment
+env = GPUDriveTorchEnv(
+ config=env_config,
+ data_loader=val_loader,
+ max_cont_agents=MAX_AGENTS,
+ device=DEVICE,
+)
+
+def transform_trajectories_to_local_frame(global_trajectories, ego_pos, ego_yaw, device):
+ """
+ Transform trajectories from simulator coordinates to local (ego-centric) coordinates.
+
+ Args:
+ global_trajectories: Tensor of shape [n_rollouts, n_steps, 2] (x, y positions in simulator coords)
+ ego_pos: Tensor of shape [2] (ego x, y position in simulator coords)
+ ego_yaw: Scalar tensor (ego heading angle)
+ device: Device to run computations on
+
+ Returns:
+ local_trajectories: Tensor of shape [n_rollouts, n_steps, 2] in ego frame
+ """
+ n_rollouts, n_steps, _ = global_trajectories.shape
+ local_trajectories = torch.zeros_like(global_trajectories)
+
+ for rollout_idx in range(n_rollouts):
+ for step_idx in range(n_steps):
+ global_pos = global_trajectories[rollout_idx, step_idx, :]
+ local_pos = to_local_frame(
+ global_pos_xy=global_pos.unsqueeze(0), # Add batch dimension
+ ego_pos=ego_pos,
+ ego_yaw=ego_yaw,
+ device=device,
+ )
+ local_trajectories[rollout_idx, step_idx, :] = local_pos.squeeze(0)
+
+ return local_trajectories
+
+# Create output directory
+os.makedirs('guidance_density', exist_ok=True)
+
+scene_count = 0
+while scene_count < DATASET_SIZE:
+
+ # Save Trajectories for the first controlled agent in global coordinates
+ all_global_trajectories = []
+
+ # Get the first controlled agent index for this scene
+ control_mask = env.cont_agent_mask.clone().cpu()
+ first_controlled_agent_idx = torch.where(control_mask[0])[0][0].item()
+
+ # For each guidance density, rollout and collect agent trajectories
+ for GUIDANCE_DROPOUT_PROB in GUIDANCE_DROPOUT_PROB_RANGE:
+
+ # Update the environment configuration for the current guidance dropout probability
+ env.config.guidance_dropout_prob = GUIDANCE_DROPOUT_PROB
+ control_mask = env.cont_agent_mask.clone().cpu()
+ next_obs = env.reset(mask=control_mask)
+
+ global_trajectories = []
+
+ # Zero out actions for parked vehicles
+ info = Info.from_tensor(
+ env.sim.info_tensor(),
+ backend=env.backend,
+ device=env.device,
+ )
+
+ zero_action_mask = (info.off_road == 1) | (
+ info.collided_with_vehicle == 1
+ ) & (info.type == int(madrona_gpudrive.EntityType.Vehicle))
+
+ # Guidance logging
+ num_guidance_points = env.valid_guidance_points
+ guidance_densities = num_guidance_points / env.reference_traj_len
+ print(
+ f"Avg guidance points per agent: {num_guidance_points.cpu().numpy().mean():.2f} which is {guidance_densities.mean().item()*100:.2f} % of the trajectory length (mode = {env.config.guidance_dropout_mode}) \n"
+ )
+
+ # Get position in simulator coordinates (with world mean already subtracted)
+ global_agent_states = GlobalEgoState.from_tensor(
+ env.sim.absolute_self_observation_tensor(),
+ backend=env.backend,
+ device="cpu",
+ )
+
+ # Store trajectory for the first controlled agent only (already in simulator coordinates)
+ first_agent_pos = global_agent_states.pos_xy[0, first_controlled_agent_idx, :]
+ global_trajectories.append(first_agent_pos)
+
+ done_list = [env.get_dones()]
+
+ for time_step in range(env.episode_len - env.init_steps):
+
+ # Predict actions
+ action, _, _, _ = agent(next_obs)
+
+ action_template = torch.zeros(
+ (env.num_worlds, madrona_gpudrive.kMaxAgentCount), dtype=torch.int64, device=env.device
+ )
+ action_template[control_mask] = action.to(env.device)
+
+ # Find the integer key for the "do nothing" action (zero steering, zero acceleration)
+ DO_NOTHING_ACTION_INT = [
+ key
+ for key, value in env.action_key_to_values.items()
+ if abs(value[0]) == 0.0
+ and abs(value[1]) == 0.0
+ and abs(value[2]) == 0.0
+ ][0]
+ action_template[zero_action_mask] = DO_NOTHING_ACTION_INT
+
+ # Step
+ env.step_dynamics(action_template)
+
+ # Get next observation
+ next_obs = env.get_obs(control_mask)
+
+ # Save to trajectories in simulator coordinates (world mean already subtracted)
+ global_agent_states = GlobalEgoState.from_tensor(
+ env.sim.absolute_self_observation_tensor(),
+ backend=env.backend,
+ device="cpu",
+ )
+
+ # Store trajectory for the first controlled agent only (already in simulator coordinates)
+ first_agent_pos = global_agent_states.pos_xy[0, first_controlled_agent_idx, :]
+ global_trajectories.append(first_agent_pos)
+
+ reward = env.get_rewards()
+ done = env.get_dones()
+ done_list.append(done)
+
+ _ = done_list.pop()
+
+ # Stack trajectories for this rollout: shape [n_steps, 2]
+ rollout_trajectory = torch.stack(global_trajectories, dim=0)
+ all_global_trajectories.append(rollout_trajectory)
+
+ # Stack all rollouts: shape [n_rollouts, n_steps, 2]
+ all_global_trajectories = torch.stack(all_global_trajectories, dim=0)
+
+ # Reset environment to get the initial state for egocentric transformation
+ control_mask = env.cont_agent_mask.clone().cpu()
+ env.config.guidance_dropout_prob = 0.0 # Reset to no dropout for final state
+ _ = env.reset(mask=control_mask)
+
+ # Get ego agent's initial position and heading for coordinate transformation
+ # Use simulator coordinates (world mean already subtracted) - this is what to_local_frame expects
+ global_agent_states = GlobalEgoState.from_tensor(
+ env.sim.absolute_self_observation_tensor(),
+ backend=env.backend,
+ device="cpu",
+ )
+
+ ego_pos = global_agent_states.pos_xy[0, first_controlled_agent_idx, :] # [x, y] in simulator coords
+ ego_yaw = global_agent_states.rotation_angle[0, first_controlled_agent_idx] # scalar
+
+ # Transform all trajectories to ego-centric coordinates
+ local_trajectories = transform_trajectories_to_local_frame(
+ all_global_trajectories, ego_pos, ego_yaw, device="cpu"
+ )
+
+ # Create a combined trajectory array with weights for coloring
+ # Shape: [n_rollouts, n_steps, 2]
+ trajectory_weights = 1 - GUIDANCE_DROPOUT_PROB_RANGE # Higher weight = more guidance
+
+ # Plot using the egocentric view
+ fig = env.vis.plot_agent_observation(
+ env_idx=0,
+ agent_idx=first_controlled_agent_idx,
+ figsize=(12, 12),
+ trajectory=None, # We'll modify the function to handle multiple trajectories
+ step_reward=None,
+ route_progress=None,
+ )
+
+ # Manually add the multiple trajectories with different colors based on guidance density
+ if fig is not None:
+ ax = fig.get_axes()[0]
+
+ # Clear any existing legend
+ legend = ax.get_legend()
+ if legend:
+ legend.remove()
+
+ # Color trajectories based on guidance density
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ # Create colormap for guidance density (reverse so high guidance = dark, low guidance = light)
+ colors = sns.color_palette("mako", len(GUIDANCE_DROPOUT_PROB_RANGE))
+ colors = colors[::-1] # Reverse the color order
+
+ # Store trajectory lines for legend
+ trajectory_lines = []
+ trajectory_labels = []
+
+ for rollout_idx, (trajectory, dropout_prob) in enumerate(zip(local_trajectories, GUIDANCE_DROPOUT_PROB_RANGE)):
+ # Filter out invalid points
+ valid_mask = (
+ (trajectory[:, 0] != 0) &
+ (trajectory[:, 1] != 0) &
+ (torch.abs(trajectory[:, 0]) < 1000) &
+ (torch.abs(trajectory[:, 1]) < 1000)
+ )
+
+ if valid_mask.sum() > 1:
+ valid_trajectory = trajectory[valid_mask]
+
+ # Plot trajectory line
+ line, = ax.plot(
+ valid_trajectory[:, 0].cpu().numpy(),
+ valid_trajectory[:, 1].cpu().numpy(),
+ color=colors[rollout_idx],
+ linewidth=2.0,
+ alpha=0.5,
+ zorder=1
+ )
+
+ # Store for legend
+ trajectory_lines.append(line)
+ trajectory_labels.append(f'Guidance: {(1-dropout_prob)*100:.0f}%')
+
+ # Plot trajectory points
+ ax.scatter(
+ valid_trajectory[:, 0].cpu().numpy(),
+ valid_trajectory[:, 1].cpu().numpy(),
+ color=colors[rollout_idx],
+ s=15,
+ alpha=0.4,
+ zorder=1
+ )
+
+ # Add legend with only trajectory lines
+ ax.legend(trajectory_lines, trajectory_labels, loc='upper right', fontsize=14, framealpha=0.9)
+
+ # Add title
+ ax.set_title(f'Egocentric View - Scene {scene_count}\nAgent {first_controlled_agent_idx} Trajectories by Guidance Density',
+ fontsize=14, pad=20)
+
+ # Save the figure
+ fig.savefig(f'guidance_density/scene_{scene_count}_agent_{first_controlled_agent_idx}.pdf', format='pdf', bbox_inches='tight')
+
+ plt.close(fig)
+
+ print(f"Processed scene {scene_count} with agent {first_controlled_agent_idx}")
+ scene_count += 1
+
+ try:
+ env.swap_data_batch()
+ except StopIteration:
+ # If we run out of scenes, break the loop
+ print("No more scenes in the dataset.")
+ break
+
+print(f"Generated {scene_count} egocentric visualization figures in 'guidance_density/' directory")
\ No newline at end of file
diff --git a/examples/experimental/notebooks/06_guidance_density.ipynb b/examples/experimental/notebooks/06_guidance_density.ipynb
new file mode 100644
index 000000000..311073cca
--- /dev/null
+++ b/examples/experimental/notebooks/06_guidance_density.ipynb
@@ -0,0 +1,414 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "701fe64b",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "RuntimeError",
+ "evalue": "module compiled against ABI version 0x1000009 but this version of numpy is 0x2000000",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+ "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
+ "\u001b[31mRuntimeError\u001b[39m: module compiled against ABI version 0x1000009 but this version of numpy is 0x2000000"
+ ]
+ },
+ {
+ "ename": "RuntimeError",
+ "evalue": "module compiled against ABI version 0x1000009 but this version of numpy is 0x2000000",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+ "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
+ "\u001b[31mRuntimeError\u001b[39m: module compiled against ABI version 0x1000009 but this version of numpy is 0x2000000"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nadarenator/miniconda3/envs/gpudrive/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import dataclasses\n",
+ "import os\n",
+ "import sys\n",
+ "import mediapy\n",
+ "import logging\n",
+ "import numpy as np\n",
+ "from time import perf_counter\n",
+ "from tqdm import tqdm\n",
+ "from pathlib import Path\n",
+ "from box import Box\n",
+ "import yaml\n",
+ "from PIL import Image\n",
+ "\n",
+ "from gpudrive.env.config import EnvConfig\n",
+ "from gpudrive.env.env_torch import GPUDriveTorchEnv\n",
+ "from gpudrive.env.dataset import SceneDataLoader\n",
+ "from gpudrive.datatypes.observation import GlobalEgoState\n",
+ "from gpudrive.datatypes.metadata import Metadata\n",
+ "from gpudrive.datatypes.info import Info\n",
+ "from gpudrive.utils.checkpoint import load_agent\n",
+ "from gpudrive.visualize.utils import img_from_fig\n",
+ "import madrona_gpudrive\n",
+ "\n",
+ "working_dir = Path.cwd()\n",
+ "while working_dir.name != 'gpudrive':\n",
+ " working_dir = working_dir.parent\n",
+ " if working_dir == Path.home():\n",
+ " raise FileNotFoundError(\"Base directory 'gpudrive' not found\")\n",
+ "os.chdir(working_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1b92f4a2",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
+ "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
+ "\u001b[1;31mClick here for more info. \n",
+ "\u001b[1;31mView Jupyter log for further details."
+ ]
+ }
+ ],
+ "source": [
+ "# Env Settings\n",
+ "MAX_AGENTS = (\n",
+ " madrona_gpudrive.kMaxAgentCount\n",
+ ") # TODO: Set to 128 for real eval\n",
+ "NUM_ENVS = 1\n",
+ "DEVICE = \"cuda\" # where to run the env rollouts\n",
+ "INIT_STEPS = 10\n",
+ "DATASET_SIZE = 5\n",
+ "RENDER = True\n",
+ "LOG_DIR = \"examples/eval/figures_data/wosac/\"\n",
+ "GUIDANCE_MODE = (\n",
+ " \"log_replay\" # Options: \"vbd_amortized\", \"vbd_online\", \"log_replay\"\n",
+ ")\n",
+ "GUIDANCE_DROPOUT_MODE = \"avg\" # Options: \"max\", \"avg\", \"remove_all\"\n",
+ "GUIDANCE_DROPOUT_PROB_RANGE = np.arange(0.0, 1.1, 0.1)\n",
+ "SMOOTHEN_TRAJECTORY = True\n",
+ "\n",
+ "DATA_PATH = \"data/processed/wosac/validation_interactive/json\"\n",
+ "\n",
+ "CPT_PATH = \"checkpoints/model_guidance_logs__R_10000__05_14_16_54_46_975_002500.pt\"\n",
+ "\n",
+ "# Load agent\n",
+ "agent = load_agent(path_to_cpt=CPT_PATH).to(DEVICE)\n",
+ "\n",
+ "config = agent.config\n",
+ "\n",
+ "# Save Trajectories:\n",
+ "all_trajectories = []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "784e282f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nadarenator/Desktop/Projects/gpudrive/src/dynamics.hpp(16): warning #177-D: variable \"clipSpeed\" was declared but never referenced\n",
+ " auto clipSpeed = [maxSpeed](float speed)\n",
+ " ^\n",
+ "\n",
+ "Remark: The warnings can be suppressed with \"-diag-suppress \"\n",
+ "\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/src/dynamics.hpp(21): warning #177-D: variable \"polarToVector2D\" was declared but never referenced\n",
+ " auto polarToVector2D = [](float r, float theta)\n",
+ " ^\n",
+ "\n",
+ "\n",
+ "\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/src/dynamics.hpp(16): warning #177-D: variable \"clipSpeed\" was declared but never referenced\n",
+ " auto clipSpeed = [maxSpeed](float speed)\n",
+ " ^\n",
+ "\n",
+ "Remark: The warnings can be suppressed with \"-diag-suppress \"\n",
+ "\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/src/dynamics.hpp(21): warning #177-D: variable \"polarToVector2D\" was declared but never referenced\n",
+ " auto polarToVector2D = [](float r, float theta)\n",
+ " ^\n",
+ "\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/src/level_gen.cpp(323): warning #177-D: function \"madrona_gpudrive::createFloorPlane\" was declared but never referenced\n",
+ " static void createFloorPlane(Engine &ctx)\n",
+ " ^\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Compiling GPU engine code:\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/device/memory.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/device/state.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/device/crash.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/device/consts.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/device/taskgraph.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/device/taskgraph_utils.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/device/sort_archetype.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/device/host_print.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../common/hashmap.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../common/navmesh.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../core/base.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../physics/physics.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../physics/geo.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../physics/xpbd.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../physics/tgs.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../physics/narrowphase.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../physics/broadphase.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/external/madrona/src/mw/../render/ecs_system.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/src/sim.cpp\n",
+ "/home/nadarenator/Desktop/Projects/gpudrive/src/level_gen.cpp\n",
+ "Initialization finished\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Smoothing guidance data: 100%|\u001b[34m██████████\u001b[0m| 1/1 [00:00<00:00, 139.38it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Avg guidance points per agent: 75.07 which is 82.49 % of the trajectory length (mode = avg) \n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# For each guidance density, create the env, rollout and collect agent trajectories\n",
+ "for GUIDANCE_DROPOUT_PROB in GUIDANCE_DROPOUT_PROB_RANGE:\n",
+ " trajectories = []\n",
+ " # Create data loader\n",
+ " val_loader = SceneDataLoader(\n",
+ " root=DATA_PATH,\n",
+ " batch_size=NUM_ENVS,\n",
+ " dataset_size=DATASET_SIZE,\n",
+ " sample_with_replacement=False,\n",
+ " shuffle=True,\n",
+ " file_prefix=\"\",\n",
+ " seed=10,\n",
+ " )\n",
+ "\n",
+ " # Override default environment settings to match those the agent was trained with\n",
+ " # TODO(dc): Clean this up\n",
+ " env_config = EnvConfig(\n",
+ " ego_state=config.ego_state,\n",
+ " road_map_obs=config.road_map_obs,\n",
+ " partner_obs=config.partner_obs,\n",
+ " reward_type=config.reward_type,\n",
+ " guidance_speed_weight=config.guidance_speed_weight,\n",
+ " guidance_heading_weight=config.guidance_heading_weight,\n",
+ " smoothness_weight=config.smoothness_weight,\n",
+ " norm_obs=config.norm_obs,\n",
+ " add_previous_action=config.add_previous_action,\n",
+ " guidance=config.guidance,\n",
+ " add_reference_pos_xy=config.add_reference_pos_xy,\n",
+ " add_reference_speed=config.add_reference_speed,\n",
+ " add_reference_heading=config.add_reference_heading,\n",
+ " dynamics_model=config.dynamics_model,\n",
+ " collision_behavior=config.collision_behavior,\n",
+ " goal_behavior=config.goal_behavior,\n",
+ " polyline_reduction_threshold=config.polyline_reduction_threshold,\n",
+ " remove_non_vehicles=config.remove_non_vehicles,\n",
+ " lidar_obs=False,\n",
+ " obs_radius=config.obs_radius,\n",
+ " max_steer_angle=config.max_steer_angle,\n",
+ " max_accel_value=config.max_accel_value,\n",
+ " action_space_steer_disc=config.action_space_steer_disc,\n",
+ " action_space_accel_disc=config.action_space_accel_disc,\n",
+ " # Override action space\n",
+ " steer_actions=torch.round(\n",
+ " torch.linspace(\n",
+ " -config.max_steer_angle,\n",
+ " config.max_steer_angle,\n",
+ " config.action_space_steer_disc,\n",
+ " ),\n",
+ " decimals=3,\n",
+ " ),\n",
+ " accel_actions=torch.round(\n",
+ " torch.linspace(\n",
+ " -config.max_accel_value,\n",
+ " config.max_accel_value,\n",
+ " config.action_space_accel_disc,\n",
+ " ),\n",
+ " decimals=3,\n",
+ " ),\n",
+ " init_mode=\"wosac_eval\",\n",
+ " init_steps=INIT_STEPS,\n",
+ " guidance_mode=GUIDANCE_MODE,\n",
+ " guidance_dropout_prob=GUIDANCE_DROPOUT_PROB,\n",
+ " guidance_dropout_mode=GUIDANCE_DROPOUT_MODE,\n",
+ " smoothen_trajectory=SMOOTHEN_TRAJECTORY,\n",
+ " )\n",
+ "\n",
+ " # Make environment\n",
+ " env = GPUDriveTorchEnv(\n",
+ " config=env_config,\n",
+ " data_loader=val_loader,\n",
+ " max_cont_agents=MAX_AGENTS,\n",
+ " device=DEVICE,\n",
+ " )\n",
+ "\n",
+ " # Zero out actions for parked vehicles\n",
+ " info = Info.from_tensor(\n",
+ " env.sim.info_tensor(),\n",
+ " backend=env.backend,\n",
+ " device=env.device,\n",
+ " )\n",
+ "\n",
+ " zero_action_mask = (info.off_road == 1) | (\n",
+ " info.collided_with_vehicle == 1\n",
+ " ) & (info.type == int(madrona_gpudrive.EntityType.Vehicle))\n",
+ "\n",
+ " control_mask = env.cont_agent_mask.clone().cpu()\n",
+ "\n",
+ " next_obs = env.reset(mask=control_mask)\n",
+ "\n",
+ " # Guidance logging\n",
+ " num_guidance_points = env.valid_guidance_points\n",
+ " guidance_densities = num_guidance_points / env.reference_traj_len\n",
+ " print(\n",
+ " f\"Avg guidance points per agent: {num_guidance_points.cpu().numpy().mean():.2f} which is {guidance_densities.mean().item()*100:.2f} % of the trajectory length (mode = {env.config.guidance_dropout_mode}) \\n\"\n",
+ " )\n",
+ "\n",
+ " pos_xy = GlobalEgoState.from_tensor(\n",
+ " env.sim.absolute_self_observation_tensor(),\n",
+ " backend=env.backend,\n",
+ " device=\"cpu\",\n",
+ " ).pos_xy[control_mask]\n",
+ "\n",
+ " trajectories.append(pos_xy)\n",
+ "\n",
+ " done_list = [env.get_dones()]\n",
+ "\n",
+ " for time_step in range(env.episode_len - env.init_steps):\n",
+ "\n",
+ " # Predict actions\n",
+ " action, _, _, _ = agent(next_obs)\n",
+ "\n",
+ " action_template = torch.zeros(\n",
+ " (env.num_worlds, madrona_gpudrive.kMaxAgentCount), dtype=torch.int64, device=env.device\n",
+ " )\n",
+ " action_template[control_mask] = action.to(env.device)\n",
+ "\n",
+ " # Find the integer key for the \"do nothing\" action (zero steering, zero acceleration)\n",
+ " # Check using env.action_key_to_values[DO_NOTHING_ACTION_INT]\n",
+ " DO_NOTHING_ACTION_INT = [\n",
+ " key\n",
+ " for key, value in env.action_key_to_values.items()\n",
+ " if abs(value[0]) == 0.0\n",
+ " and abs(value[1]) == 0.0\n",
+ " and abs(value[2]) == 0.0\n",
+ " ][0]\n",
+ " action_template[zero_action_mask] = DO_NOTHING_ACTION_INT\n",
+ "\n",
+ " # Step\n",
+ " env.step_dynamics(action_template)\n",
+ "\n",
+ " # Get next observation\n",
+ " next_obs = env.get_obs(control_mask)\n",
+ "\n",
+ " # Save to trajectories\n",
+ " pos_xy = GlobalEgoState.from_tensor(\n",
+ " env.sim.absolute_self_observation_tensor(),\n",
+ " backend=env.backend,\n",
+ " device=\"cpu\",\n",
+ " ).pos_xy[control_mask]\n",
+ " trajectories.append(pos_xy)\n",
+ "\n",
+ " # NOTE(dc): Make sure to decouple the obs from the reward function\n",
+ " reward = env.get_rewards()\n",
+ " done = env.get_dones()\n",
+ " done_list.append(done)\n",
+ " \n",
+ " _ = done_list.pop()\n",
+ "\n",
+ " trajectories = torch.stack(trajectories, dim=0).cpu().permute(1, 0, 2)\n",
+ " all_trajectories.append(trajectories)\n",
+ "\n",
+ "all_trajectories = torch.stack(all_trajectories, dim=0).cpu()\n",
+ "all_trajectories = all_trajectories.unsqueeze(0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "faec5d2f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "81\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Plot trajectories and save\n",
+ "_ = env.reset(mask=control_mask)\n",
+ "\n",
+ "fig = env.vis.plot_simulator_state(\n",
+ " env_indices=[0],\n",
+ " agent_positions=all_trajectories,\n",
+ " zoom_radius=70,\n",
+ " multiple_rollouts=True,\n",
+ " line_alpha=0.5,\n",
+ " line_width=1.0,\n",
+ " weights=GUIDANCE_DROPOUT_PROB_RANGE,\n",
+ " colorbar=True,\n",
+ ")[0]\n",
+ "\n",
+ "Image.fromarray(img_from_fig(fig))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "gpudrive",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/experimental/rollout_gifs.py b/examples/experimental/rollout_gifs.py
new file mode 100644
index 000000000..877d00bec
--- /dev/null
+++ b/examples/experimental/rollout_gifs.py
@@ -0,0 +1,283 @@
+import torch
+import argparse
+import os
+import numpy as np
+import mediapy as media
+from PIL import Image
+from tqdm import tqdm
+
+from gpudrive.env.config import EnvConfig
+from gpudrive.env.env_torch import GPUDriveTorchEnv
+from gpudrive.env.dataset import SceneDataLoader
+from gpudrive.datatypes.observation import GlobalEgoState
+from gpudrive.datatypes.info import Info
+from gpudrive.utils.checkpoint import load_agent
+from gpudrive.visualize.utils import img_from_fig
+import madrona_gpudrive
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Generate GIFs of scenes with specified guidance density')
+ parser.add_argument('--guidance_dropout_prob', type=float, default=0.0,
+ help='Guidance dropout probability (0.0 = full guidance, 1.0 = no guidance)')
+ parser.add_argument('--dataset_size', type=int, default=10,
+ help='Number of scenes to process')
+ parser.add_argument('--data_path', type=str,
+ default="data/processed/wosac/validation_interactive/json",
+ help='Path to dataset')
+ parser.add_argument('--checkpoint_path', type=str,
+ default="checkpoints/model_guidance_logs__R_10000__06_07_13_55_31_201_013500.pt",
+ help='Path to agent checkpoint')
+ parser.add_argument('--output_dir', type=str, default='guidance_gifs',
+ help='Output directory for GIFs')
+ parser.add_argument('--device', type=str, default='cuda',
+ help='Device to run on (cuda/cpu)')
+ parser.add_argument('--zoom_radius', type=int, default=45,
+ help='Zoom radius for visualization')
+ parser.add_argument('--fps', type=int, default=10,
+ help='Frames per second for GIF')
+ parser.add_argument('--render_frequency', type=int, default=1,
+ help='Render every N timesteps (1 = every timestep)')
+ parser.add_argument('--guidance_mode', type=str, default='log_replay',
+ choices=['log_replay', 'vbd_amortized', 'vbd_online'],
+ help='Guidance mode to use')
+ parser.add_argument('--guidance_dropout_mode', type=str, default='avg',
+ choices=['max', 'avg', 'remove_all'],
+ help='Guidance dropout mode')
+
+ return parser.parse_args()
+
+
+def setup_environment(args):
+ """Setup the environment and agent"""
+
+ # Environment constants
+ MAX_AGENTS = madrona_gpudrive.kMaxAgentCount
+ NUM_ENVS = 1
+ INIT_STEPS = 10
+
+ # Load agent
+ print(f"Loading agent from {args.checkpoint_path}...")
+ agent = load_agent(path_to_cpt=args.checkpoint_path).to(args.device)
+ config = agent.config
+
+ # Create data loader
+ print(f"Loading dataset from {args.data_path}...")
+ val_loader = SceneDataLoader(
+ root=args.data_path,
+ batch_size=NUM_ENVS,
+ dataset_size=args.dataset_size,
+ sample_with_replacement=False,
+ shuffle=True,
+ file_prefix="",
+ seed=10,
+ )
+
+ # Override default environment settings
+ env_config = EnvConfig(
+ ego_state=config.ego_state,
+ road_map_obs=config.road_map_obs,
+ partner_obs=config.partner_obs,
+ reward_type=config.reward_type,
+ guidance_speed_weight=config.guidance_speed_weight,
+ guidance_heading_weight=config.guidance_heading_weight,
+ smoothness_weight=config.smoothness_weight,
+ norm_obs=config.norm_obs,
+ add_previous_action=config.add_previous_action,
+ guidance=config.guidance,
+ add_reference_pos_xy=config.add_reference_pos_xy,
+ add_reference_speed=config.add_reference_speed,
+ add_reference_heading=config.add_reference_heading,
+ dynamics_model=config.dynamics_model,
+ collision_behavior=config.collision_behavior,
+ goal_behavior=config.goal_behavior,
+ polyline_reduction_threshold=config.polyline_reduction_threshold,
+ remove_non_vehicles=config.remove_non_vehicles,
+ lidar_obs=False,
+ obs_radius=config.obs_radius,
+ action_space_steer_disc=config.action_space_steer_disc,
+ action_space_accel_disc=config.action_space_accel_disc,
+ init_mode="wosac_eval",
+ init_steps=INIT_STEPS,
+ guidance_mode=args.guidance_mode,
+ guidance_dropout_prob=args.guidance_dropout_prob,
+ guidance_dropout_mode=args.guidance_dropout_mode,
+ smoothen_trajectory=True,
+ )
+
+ # Make environment
+ env = GPUDriveTorchEnv(
+ config=env_config,
+ data_loader=val_loader,
+ max_cont_agents=MAX_AGENTS,
+ device=args.device,
+ )
+
+ return env, agent
+
+
+def rollout_scene(env, agent, args):
+ """Rollout a single scene and collect frames"""
+
+ control_mask = env.cont_agent_mask.clone().cpu()
+ next_obs = env.reset(mask=control_mask)
+
+ # Zero out actions for parked vehicles
+ info = Info.from_tensor(
+ env.sim.info_tensor(),
+ backend=env.backend,
+ device=env.device,
+ )
+
+ zero_action_mask = (info.off_road == 1) | (
+ info.collided_with_vehicle == 1
+ ) & (info.type == int(madrona_gpudrive.EntityType.Vehicle))
+
+ # Log guidance info
+ num_guidance_points = env.valid_guidance_points
+ guidance_densities = num_guidance_points / env.reference_traj_len
+ guidance_percent = guidance_densities.mean().item() * 100
+
+ print(f" Guidance density: {guidance_percent:.1f}% "
+ f"(avg {num_guidance_points.cpu().numpy().mean():.1f} points)")
+
+ frames = []
+
+ # Get initial frame
+ if 0 % args.render_frequency == 0:
+ fig = env.vis.plot_simulator_state(
+ env_indices=[0],
+ zoom_radius=args.zoom_radius,
+ plot_guidance_pos_xy=True,
+ center_agent_indices=[0],
+ )[0]
+
+ # Add guidance info to the plot
+ ax = fig.get_axes()[0]
+ ax.text(
+ 0.05, 0.90,
+ f"Guidance: {guidance_percent:.1f}%\nDropout prob: {args.guidance_dropout_prob:.2f}",
+ transform=ax.transAxes,
+ fontsize=12,
+ color="white",
+ ha="left", va="top",
+ bbox=dict(facecolor="black", alpha=0.7, edgecolor="none", pad=5)
+ )
+
+ frames.append(img_from_fig(fig))
+
+ # Rollout the episode
+ for time_step in tqdm(range(env.episode_len - env.config.init_steps),
+ desc=" Rolling out", leave=False):
+
+ # Predict actions
+ action, _, _, _ = agent(next_obs)
+
+ action_template = torch.zeros(
+ (env.num_worlds, madrona_gpudrive.kMaxAgentCount),
+ dtype=torch.int64, device=env.device
+ )
+ action_template[control_mask] = action.to(env.device)
+
+ # Find the "do nothing" action for parked vehicles
+ DO_NOTHING_ACTION_INT = [
+ key for key, value in env.action_key_to_values.items()
+ if abs(value[0]) == 0.0 and abs(value[1]) == 0.0 and abs(value[2]) == 0.0
+ ][0]
+ action_template[zero_action_mask] = DO_NOTHING_ACTION_INT
+
+ # Step environment
+ env.step_dynamics(action_template)
+ next_obs = env.get_obs(control_mask)
+
+ # Render frame if needed
+ if (time_step + 1) % args.render_frequency == 0:
+ fig = env.vis.plot_simulator_state(
+ env_indices=[0],
+ zoom_radius=args.zoom_radius,
+ plot_guidance_pos_xy=True,
+ center_agent_indices=[0],
+ )[0]
+
+ # Add guidance info to the plot
+ ax = fig.get_axes()[0]
+ ax.text(
+ 0.05, 0.90,
+ f"Guidance: {guidance_percent:.1f}%\nDropout prob: {args.guidance_dropout_prob:.2f}",
+ transform=ax.transAxes,
+ fontsize=12,
+ color="white",
+ ha="left", va="top",
+ bbox=dict(facecolor="black", alpha=0.7, edgecolor="none", pad=5)
+ )
+
+ frames.append(img_from_fig(fig))
+
+ return frames, guidance_percent
+
+
+def main():
+ args = parse_args()
+
+ # Create output directory
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ print(f"Generating GIFs with guidance dropout probability: {args.guidance_dropout_prob}")
+ print(f"Output directory: {args.output_dir}")
+
+ # Setup environment and agent
+ env, agent = setup_environment(args)
+
+ scene_count = 0
+ successful_scenes = 0
+
+ while scene_count < args.dataset_size:
+ try:
+ print(f"\nProcessing scene {scene_count + 1}/{args.dataset_size}...")
+
+ # Rollout scene and collect frames
+ frames, guidance_percent = rollout_scene(env, agent, args)
+
+ if frames:
+ # Create filename with guidance info
+ guidance_str = f"guidance_{guidance_percent:.1f}pct"
+ dropout_str = f"dropout_{args.guidance_dropout_prob:.2f}"
+ filename = f"scene_{scene_count:03d}_{guidance_str}_{dropout_str}.gif"
+ filepath = os.path.join(args.output_dir, filename)
+
+ # Save GIF
+ print(f" Saving GIF with {len(frames)} frames to {filename}")
+ media.write_video(
+ filepath,
+ np.array(frames),
+ fps=args.fps,
+ codec="gif"
+ )
+
+ successful_scenes += 1
+ else:
+ print(f" Warning: No frames generated for scene {scene_count}")
+
+ scene_count += 1
+
+ # Try to load next scene
+ try:
+ env.swap_data_batch()
+ except StopIteration:
+ print("No more scenes in the dataset.")
+ break
+
+ except Exception as e:
+ print(f" Error processing scene {scene_count}: {e}")
+ scene_count += 1
+ try:
+ env.swap_data_batch()
+ except StopIteration:
+ break
+
+ print(f"\nCompleted! Generated {successful_scenes} GIFs out of {scene_count} scenes.")
+ print(f"GIFs saved in: {args.output_dir}")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/examples/experimental/rollout_velocity.py b/examples/experimental/rollout_velocity.py
new file mode 100644
index 000000000..a096723d8
--- /dev/null
+++ b/examples/experimental/rollout_velocity.py
@@ -0,0 +1,391 @@
+import torch
+import argparse
+import os
+import numpy as np
+import mediapy as media
+from PIL import Image
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+from matplotlib.gridspec import GridSpec
+
+from gpudrive.env.config import EnvConfig
+from gpudrive.env.env_torch import GPUDriveTorchEnv
+from gpudrive.env.dataset import SceneDataLoader
+from gpudrive.datatypes.observation import LocalEgoState
+from gpudrive.datatypes.info import Info
+from gpudrive.utils.checkpoint import load_agent
+from gpudrive.visualize.utils import img_from_fig
+import madrona_gpudrive
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Generate GIFs of scenes with velocity graphs')
+ parser.add_argument('--guidance_dropout_prob', type=float, default=0.0,
+ help='Guidance dropout probability (0.0 = full guidance, 1.0 = no guidance)')
+ parser.add_argument('--dataset_size', type=int, default=10,
+ help='Number of scenes to process')
+ parser.add_argument('--data_path', type=str,
+ default="data/processed/wosac/validation_interactive/json",
+ help='Path to dataset')
+ parser.add_argument('--checkpoint_path', type=str,
+ default="checkpoints/model_guidance_logs__R_10000__06_07_13_55_31_201_013500.pt",
+ help='Path to agent checkpoint')
+ parser.add_argument('--output_dir', type=str, default='velocity_gifs',
+ help='Output directory for GIFs')
+ parser.add_argument('--device', type=str, default='cuda',
+ help='Device to run on (cuda/cpu)')
+ parser.add_argument('--zoom_radius', type=int, default=45,
+ help='Zoom radius for visualization')
+ parser.add_argument('--fps', type=int, default=10,
+ help='Frames per second for GIF')
+ parser.add_argument('--render_frequency', type=int, default=1,
+ help='Render every N timesteps (1 = every timestep)')
+ parser.add_argument('--guidance_mode', type=str, default='log_replay',
+ choices=['log_replay', 'vbd_amortized', 'vbd_online'],
+ help='Guidance mode to use')
+ parser.add_argument('--guidance_dropout_mode', type=str, default='avg',
+ choices=['max', 'avg', 'remove_all'],
+ help='Guidance dropout mode')
+ parser.add_argument('--ego_agent_idx', type=int, default=0,
+ help='Index of the ego agent to track (default: 0 for first controlled agent)')
+ parser.add_argument('--velocity_graph_height', type=float, default=0.3,
+ help='Height ratio of velocity graph relative to total figure (0.0-1.0)')
+
+ return parser.parse_args()
+
+
+def setup_environment(args):
+ """Setup the environment and agent"""
+
+ # Environment constants
+ MAX_AGENTS = madrona_gpudrive.kMaxAgentCount
+ NUM_ENVS = 1
+ INIT_STEPS = 10
+
+ # Load agent
+ print(f"Loading agent from {args.checkpoint_path}...")
+ agent = load_agent(path_to_cpt=args.checkpoint_path).to(args.device)
+ config = agent.config
+
+ # Create data loader
+ print(f"Loading dataset from {args.data_path}...")
+ val_loader = SceneDataLoader(
+ root=args.data_path,
+ batch_size=NUM_ENVS,
+ dataset_size=args.dataset_size,
+ sample_with_replacement=False,
+ shuffle=True,
+ file_prefix="",
+ seed=10,
+ )
+
+ # Override default environment settings
+ env_config = EnvConfig(
+ ego_state=config.ego_state,
+ road_map_obs=config.road_map_obs,
+ partner_obs=config.partner_obs,
+ reward_type=config.reward_type,
+ guidance_speed_weight=config.guidance_speed_weight,
+ guidance_heading_weight=config.guidance_heading_weight,
+ smoothness_weight=config.smoothness_weight,
+ norm_obs=config.norm_obs,
+ add_previous_action=config.add_previous_action,
+ guidance=config.guidance,
+ add_reference_pos_xy=config.add_reference_pos_xy,
+ add_reference_speed=config.add_reference_speed,
+ add_reference_heading=config.add_reference_heading,
+ dynamics_model=config.dynamics_model,
+ collision_behavior=config.collision_behavior,
+ goal_behavior=config.goal_behavior,
+ polyline_reduction_threshold=config.polyline_reduction_threshold,
+ remove_non_vehicles=config.remove_non_vehicles,
+ lidar_obs=False,
+ obs_radius=config.obs_radius,
+ action_space_steer_disc=config.action_space_steer_disc,
+ action_space_accel_disc=config.action_space_accel_disc,
+ init_mode="wosac_eval",
+ init_steps=INIT_STEPS,
+ guidance_mode=args.guidance_mode,
+ guidance_dropout_prob=args.guidance_dropout_prob,
+ guidance_dropout_mode=args.guidance_dropout_mode,
+ smoothen_trajectory=True,
+ )
+
+ # Make environment
+ env = GPUDriveTorchEnv(
+ config=env_config,
+ data_loader=val_loader,
+ max_cont_agents=MAX_AGENTS,
+ device=args.device,
+ )
+
+ return env, agent
+
+
+def create_combined_figure(scene_fig, velocity_data, current_timestep, total_timesteps,
+ args, guidance_percent, ego_agent_idx):
+ """Create a combined figure with scene and velocity graph"""
+
+ # Get the scene image from the existing figure
+ scene_img = img_from_fig(scene_fig)
+ scene_height, scene_width = scene_img.shape[:2]
+
+ # Calculate dimensions
+ velocity_height_ratio = args.velocity_graph_height
+ velocity_height = int(scene_height * velocity_height_ratio)
+ total_height = scene_height + velocity_height
+
+ # Create new figure with custom layout
+ fig = plt.figure(figsize=(12, 12 * total_height / scene_width))
+
+ # Create grid layout: scene on top, velocity graph on bottom
+ gs = GridSpec(2, 1, height_ratios=[1-velocity_height_ratio, velocity_height_ratio],
+ hspace=0.05, left=0.05, right=0.95, top=0.95, bottom=0.05)
+
+ # Add scene as image in top subplot
+ ax_scene = fig.add_subplot(gs[0])
+ ax_scene.imshow(scene_img)
+ ax_scene.set_xlim(0, scene_width)
+ ax_scene.set_ylim(scene_height, 0) # Flip y-axis for image
+ ax_scene.set_aspect('equal')
+ ax_scene.axis('off')
+
+ # Add velocity graph in bottom subplot
+ ax_velocity = fig.add_subplot(gs[1])
+
+ # Extract velocity data up to current timestep
+ timesteps = velocity_data['timesteps'][:current_timestep + 1]
+ velocities = velocity_data['velocities'][:current_timestep + 1]
+
+ # Plot velocity history
+ if len(timesteps) > 1:
+ ax_velocity.plot(timesteps, velocities, 'b-', linewidth=2, alpha=0.7)
+
+ # Add current point
+ if current_timestep < len(velocity_data['velocities']):
+ current_vel = velocity_data['velocities'][current_timestep]
+ ax_velocity.plot(timesteps[-1], current_vel, 'ro', markersize=8,
+ markeredgecolor='black', markeredgewidth=1)
+
+ # Set up the graph
+ ax_velocity.set_xlim(0, total_timesteps)
+ ax_velocity.set_ylim(0, max(max(velocity_data['velocities']) * 1.1, 1.0))
+ ax_velocity.set_xlabel('Time Step', fontsize=12)
+ ax_velocity.set_ylabel('Velocity (m/s)', fontsize=12)
+ ax_velocity.grid(True, alpha=0.3)
+ ax_velocity.set_title(f'Ego Agent {ego_agent_idx} Velocity | Guidance: {guidance_percent:.1f}%',
+ fontsize=14, pad=10)
+
+ # Add vertical line at current timestep
+ ax_velocity.axvline(x=current_timestep, color='red', linestyle='--', alpha=0.7, linewidth=1)
+
+ plt.tight_layout()
+ return fig
+
+
+def rollout_scene(env, agent, args):
+ """Rollout a single scene and collect frames with velocity data"""
+
+ control_mask = env.cont_agent_mask.clone().cpu()
+ next_obs = env.reset(mask=control_mask)
+
+ # Get the ego agent index (validate it exists and is controlled)
+ controlled_indices = torch.where(control_mask[0])[0]
+
+ if args.ego_agent_idx >= len(controlled_indices):
+ print(f" Warning: ego_agent_idx {args.ego_agent_idx} out of range. "
+ f"Only {len(controlled_indices)} controlled agents. Using agent 0.")
+ ego_agent_idx = controlled_indices[0].item()
+ else:
+ ego_agent_idx = controlled_indices[args.ego_agent_idx].item()
+
+ print(f" Tracking ego agent index: {ego_agent_idx}")
+
+ # Zero out actions for parked vehicles
+ info = Info.from_tensor(
+ env.sim.info_tensor(),
+ backend=env.backend,
+ device=env.device,
+ )
+
+ zero_action_mask = (info.off_road == 1) | (
+ info.collided_with_vehicle == 1
+ ) & (info.type == int(madrona_gpudrive.EntityType.Vehicle))
+
+ # Log guidance info
+ num_guidance_points = env.valid_guidance_points
+ guidance_densities = num_guidance_points / env.reference_traj_len
+ guidance_percent = guidance_densities.mean().item() * 100
+
+ print(f" Guidance density: {guidance_percent:.1f}% "
+ f"(avg {num_guidance_points.cpu().numpy().mean():.1f} points)")
+
+ frames = []
+ velocity_data = {
+ 'timesteps': [],
+ 'velocities': [],
+ }
+
+ total_timesteps = env.episode_len - env.config.init_steps
+
+ # Get initial velocity
+ local_ego_states = LocalEgoState.from_tensor(
+ env.sim.self_observation_tensor(),
+ backend=env.backend,
+ device="cpu",
+ )
+ initial_velocity = local_ego_states.speed[0, ego_agent_idx].item()
+ velocity_data['timesteps'].append(0)
+ velocity_data['velocities'].append(initial_velocity)
+
+ # Get initial frame
+ if 0 % args.render_frequency == 0:
+ scene_fig = env.vis.plot_simulator_state(
+ env_indices=[0],
+ zoom_radius=args.zoom_radius,
+ plot_guidance_pos_xy=True,
+ center_agent_indices=[ego_agent_idx],
+ )[0]
+
+ # Create combined figure with velocity graph
+ combined_fig = create_combined_figure(
+ scene_fig, velocity_data, 0, total_timesteps,
+ args, guidance_percent, ego_agent_idx
+ )
+
+ frames.append(img_from_fig(combined_fig))
+ plt.close(scene_fig)
+ plt.close(combined_fig)
+
+ # Rollout the episode
+ for time_step in tqdm(range(total_timesteps), desc=" Rolling out", leave=False):
+
+ # Predict actions
+ action, _, _, _ = agent(next_obs)
+
+ action_template = torch.zeros(
+ (env.num_worlds, madrona_gpudrive.kMaxAgentCount),
+ dtype=torch.int64, device=env.device
+ )
+ action_template[control_mask] = action.to(env.device)
+
+ # Find the "do nothing" action for parked vehicles
+ DO_NOTHING_ACTION_INT = [
+ key for key, value in env.action_key_to_values.items()
+ if abs(value[0]) == 0.0 and abs(value[1]) == 0.0 and abs(value[2]) == 0.0
+ ][0]
+ action_template[zero_action_mask] = DO_NOTHING_ACTION_INT
+
+ # Step environment
+ env.step_dynamics(action_template)
+ next_obs = env.get_obs(control_mask)
+
+ # Get velocity data for ego agent
+ local_ego_states = LocalEgoState.from_tensor(
+ env.sim.self_observation_tensor(),
+ backend=env.backend,
+ device="cpu",
+ )
+
+ # Get speed directly from LocalEgoState
+ velocity = local_ego_states.speed[0, ego_agent_idx].item()
+ velocity_data['timesteps'].append(time_step + 1)
+ velocity_data['velocities'].append(velocity)
+
+ # Render frame if needed
+ if (time_step + 1) % args.render_frequency == 0:
+ scene_fig = env.vis.plot_simulator_state(
+ env_indices=[0],
+ zoom_radius=args.zoom_radius,
+ plot_guidance_pos_xy=True,
+ center_agent_indices=[ego_agent_idx],
+ )[0]
+
+ # Create combined figure with velocity graph
+ combined_fig = create_combined_figure(
+ scene_fig, velocity_data, time_step + 1, total_timesteps,
+ args, guidance_percent, ego_agent_idx
+ )
+
+ frames.append(img_from_fig(combined_fig))
+ plt.close(scene_fig)
+ plt.close(combined_fig)
+
+ return frames, guidance_percent, ego_agent_idx, velocity_data
+
+
+def main():
+ args = parse_args()
+
+ # Validate arguments
+ if not (0.0 <= args.velocity_graph_height <= 1.0):
+ print("Error: velocity_graph_height must be between 0.0 and 1.0")
+ return
+
+ # Create output directory
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ print(f"Generating GIFs with velocity graphs")
+ print(f"Guidance dropout probability: {args.guidance_dropout_prob}")
+ print(f"Ego agent index: {args.ego_agent_idx}")
+ print(f"Velocity graph height ratio: {args.velocity_graph_height}")
+ print(f"Output directory: {args.output_dir}")
+
+ # Setup environment and agent
+ env, agent = setup_environment(args)
+
+ scene_count = 0
+ successful_scenes = 0
+
+ while scene_count < args.dataset_size:
+ try:
+ print(f"\nProcessing scene {scene_count + 1}/{args.dataset_size}...")
+
+ # Rollout scene and collect frames with velocity data
+ frames, guidance_percent, actual_ego_idx, velocity_data = rollout_scene(env, agent, args)
+
+ if frames:
+ # Create filename with guidance and ego agent info
+ guidance_str = f"guidance_{guidance_percent:.1f}pct"
+ dropout_str = f"dropout_{args.guidance_dropout_prob:.2f}"
+ ego_str = f"ego_{actual_ego_idx}"
+ filename = f"scene_{scene_count:03d}_{guidance_str}_{dropout_str}_{ego_str}_velocity.gif"
+ filepath = os.path.join(args.output_dir, filename)
+
+ # Save GIF
+ print(f" Saving GIF with {len(frames)} frames to {filename}")
+ media.write_video(
+ filepath,
+ np.array(frames),
+ fps=args.fps,
+ codec="gif"
+ )
+
+ successful_scenes += 1
+ else:
+ print(f" Warning: No frames generated for scene {scene_count}")
+
+ scene_count += 1
+
+ # Try to load next scene
+ try:
+ env.swap_data_batch()
+ except StopIteration:
+ print("No more scenes in the dataset.")
+ break
+
+ except Exception as e:
+ print(f" Error processing scene {scene_count}: {e}")
+ scene_count += 1
+ try:
+ env.swap_data_batch()
+ except StopIteration:
+ break
+
+ print(f"\nCompleted! Generated {successful_scenes} GIFs with velocity graphs out of {scene_count} scenes.")
+ print(f"GIFs saved in: {args.output_dir}")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/gpudrive/visualize/core.py b/gpudrive/visualize/core.py
index 4dab953e2..dff942c8c 100644
--- a/gpudrive/visualize/core.py
+++ b/gpudrive/visualize/core.py
@@ -547,7 +547,7 @@ def plot_agent_trajectories(
) # Use absolute values for coloring
# Set up a colormap for weights
- weight_cmap = plt.cm.coolwarm
+ weight_cmap = sns.color_palette("flare", as_cmap=True)
weight_norm = plt.Normalize(
vmin=weight_values.min().item(),
vmax=weight_values.max().item(),
@@ -651,7 +651,7 @@ def plot_agent_trajectories(
weight_sm, cax=cbar_ax, orientation="horizontal"
)
cbar.set_label(
- f"Conditioning param value", fontsize=15 * marker_scale
+ f"Guidance Density", fontsize=15 * marker_scale
)
cbar.ax.tick_params(labelsize=12 * marker_scale)
except Exception as e: