Skip to content

Commit 4e34269

Browse files
committed
Refactor gym envs. Add visualizations. Fix double compile of step
1 parent 9ad96b6 commit 4e34269

22 files changed

Lines changed: 837 additions & 801 deletions

benchmark/main.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import jax
99
import jax.numpy as jnp
1010
import numpy as np
11+
from jax.errors import JaxRuntimeError
1112
from ml_collections import config_dict
1213

1314
import crazyflow # noqa: F401, ensure gymnasium envs are registered
@@ -42,15 +43,15 @@ def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float
4243

4344

4445
def profile_gym_env_step(
45-
sim_config: config_dict.ConfigDict, n_steps: int, device: str
46+
sim_config: config_dict.ConfigDict, n_steps: int, device: str, print_summary: bool = True
4647
) -> list[float]:
4748
"""Profile the Crazyflow gym environment step performance."""
4849
times = []
4950
device = jax.devices(device)[0]
5051

5152
envs = gymnasium.make_vec(
5253
"DroneReachPos-v0",
53-
time_horizon_in_seconds=3,
54+
max_episode_time=3,
5455
num_envs=sim_config.n_worlds,
5556
device=sim_config.device,
5657
freq=sim_config.freq,
@@ -61,7 +62,7 @@ def profile_gym_env_step(
6162
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
6263
action[..., 0] = 0.3
6364
# Step through env once to ensure JIT compilation
64-
envs.reset(seed=42)
65+
envs.reset()
6566
envs.step(action)
6667

6768
jax.block_until_ready(envs.unwrapped.sim.data) # Ensure JIT compiled dynamics
@@ -74,12 +75,15 @@ def profile_gym_env_step(
7475
times.append(time.perf_counter() - tstart)
7576

7677
envs.close()
77-
print("Gym env step performance:")
78-
analyze_timings(times, n_steps, envs.unwrapped.sim.n_worlds, envs.unwrapped.sim.freq)
78+
if print_summary:
79+
print("Gym env step performance:")
80+
analyze_timings(times, n_steps, envs.unwrapped.sim.n_worlds, sim_config.freq)
7981
return times
8082

8183

82-
def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str) -> list[float]:
84+
def profile_step(
85+
sim_config: config_dict.ConfigDict, n_steps: int, device: str, print_summary: bool = True
86+
) -> list[float]:
8387
"""Profile the Crazyflow simulator step performance."""
8488
sim = Sim(**sim_config)
8589
times = []
@@ -100,8 +104,9 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str)
100104
jax.block_until_ready(sim.data)
101105
times.append(time.perf_counter() - tstart)
102106

103-
print("Sim step performance:")
104-
analyze_timings(times, n_steps, sim.n_worlds, sim.freq)
107+
if print_summary:
108+
print("Sim step performance:")
109+
analyze_timings(times, n_steps, sim.n_worlds, sim.freq)
105110
return times
106111

107112

@@ -184,25 +189,24 @@ def main(device: str = "cpu", n_worlds_exp: int = 6):
184189
skip_sim, skip_gym = False, False
185190
# Test with increasing number of parallel environments (worlds)
186191
for n_worlds in [10**i for i in range(n_worlds_exp + 1)]:
192+
sim_config.n_worlds = n_worlds
193+
print("-" * 80)
187194
if not skip_sim:
188-
print(f"\nTesting with {n_worlds} parallel environments:")
189-
sim_config.n_worlds = n_worlds
190-
191195
# Test with a single step first to see if we should continue
192196
sim_config.freq = 500 # Test sim at 500 hz
193-
single_step_time = profile_step(sim_config, 2, device)[1]
197+
single_step_time = profile_step(sim_config, 2, device, print_summary=False)[1]
194198

195199
# If single step takes too long, skip this and remaining tests
196200
if single_step_time > max_seconds_per_run / n_steps: # threshold for the tests
197201
print(
198-
f" Skipping benchmark for {n_worlds} and higher - single step took "
202+
f" Skipping benchmark for {n_worlds} and higher - projected time "
199203
f"{single_step_time * n_steps:.2f}s (> 1m)"
200204
)
201205
skip_sim = True
202206

203207
if not skip_sim:
204208
# Configure simulator
205-
print(f" Running simulator benchmark ({n_worlds} worlds)...")
209+
print(f"Running simulator benchmark ({n_worlds} worlds)...")
206210
# Run simulator benchmark using existing function
207211
times_sim = profile_step(sim_config, n_steps, device)
208212

@@ -233,33 +237,30 @@ def main(device: str = "cpu", n_worlds_exp: int = 6):
233237
f.flush()
234238

235239
if not skip_gym:
236-
print(f" Running gym environment benchmark ({n_worlds} worlds)...")
240+
print(f"Running gym environment benchmark ({n_worlds} worlds)...")
237241
# Run gym environment benchmark using existing function
238242
sim_config.freq = 50 # Test gym at 50 hz
239243
try:
240-
single_step_time = profile_gym_env_step(sim_config, 2, device)[1]
244+
step_times = profile_gym_env_step(sim_config, 2, device, print_summary=False)
245+
single_step_time = step_times[1]
241246
# If single step takes too long, skip this test only
242247
if single_step_time > max_seconds_per_run / n_steps: # threshold for the tests
243248
print(
244-
f" Skipping benchmark for {n_worlds} - single step took "
249+
f" Skipping benchmark for {n_worlds} - projected time "
245250
f"{single_step_time * n_steps:.2f}s (> 1m)"
246251
)
247252
skip_gym = True
248-
except ValueError as e:
249-
if "RESOURCE_EXHAUSTED" in str(e):
250-
print(f" Skipping benchmark for {n_worlds} - resource exhausted")
251-
skip_gym = True
252-
else:
253-
raise e
253+
except JaxRuntimeError:
254+
print(f" Skipping benchmark for {n_worlds} - resource exhausted")
255+
skip_gym = True
254256

255257
if not skip_gym:
256258
try:
257259
times_gym = profile_gym_env_step(sim_config, n_steps, device)
258-
except ValueError as e:
259-
if "RESOURCE_EXHAUSTED" in str(e):
260-
print(f" Skipping benchmark for {n_worlds} - resource exhausted")
261-
continue # Only continue, we might still be able to benchmark sim
262-
raise e
260+
except JaxRuntimeError:
261+
print(f" Skipping benchmark for {n_worlds} - resource exhausted")
262+
skip_gym = True
263+
continue
263264

264265
# Calculate metrics for CSV
265266
total_time = sum(times_gym)

benchmark/performance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from crazyflow.sim import Sim
1414

1515
if TYPE_CHECKING:
16-
from crazyflow.gymnasium_envs import CrazyflowEnvReachGoal
16+
from crazyflow.envs import ReachPosEnv
1717

1818

1919
def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
@@ -44,7 +44,7 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
4444
def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
4545
device = jax.devices(device)[0]
4646

47-
envs: CrazyflowEnvReachGoal = gymnasium.make_vec(
47+
envs: ReachPosEnv = gymnasium.make_vec(
4848
"DroneReachPos-v0", time_horizon_in_seconds=2, num_envs=sim_config.n_worlds, **sim_config
4949
)
5050

benchmark/plot.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,25 @@ def format_log_axes(ax: plt.Axes, dfs: dict[str, pd.DataFrame], prefix: str):
111111
ax.set_xticklabels([xticklabels[i] for i in valid_indices])
112112

113113
# Get min and max y values for plots
114-
min_y = min([df["fps"].min() for key, df in dfs.items() if key.startswith(prefix)])
115-
max_y = max([df["fps"].max() for key, df in dfs.items() if key.startswith(prefix)])
114+
min_y = min([df["fps"].min() for key, df in dfs.items()])
115+
max_y = max([df["fps"].max() for key, df in dfs.items()])
116116

117117
# Create logarithmic y-ticks
118-
yticks = np.array([1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000])
119-
mask = (yticks >= min_y * 0.1) & (yticks <= max_y * 10)
120-
valid_indices = np.nonzero(mask)[0]
121-
ax.set_yticks(yticks[valid_indices])
122-
yticklabels = ["1", "10", "100", "1K", "10K", "100K", "1M", "10M", "100M"]
123-
ax.set_yticklabels([yticklabels[i] for i in valid_indices])
118+
# Generate yticks based on data range
119+
min_power = int(np.floor(np.log10(min_y)))
120+
max_power = int(np.ceil(np.log10(max_y)))
121+
yticks = np.array([10**i for i in range(min_power, max_power + 1)])
122+
ax.set_yticks(yticks)
123+
yticklabels = []
124+
abbrev = {1e9: "B", 1e6: "M", 1e3: "K"}
125+
for i in yticks:
126+
for divisor, suffix in sorted(abbrev.items(), reverse=True):
127+
if i >= divisor:
128+
yticklabels.append(f"{int(i // divisor)}{suffix}")
129+
break
130+
else:
131+
yticklabels.append(f"{int(i)}")
132+
ax.set_yticklabels(yticklabels)
124133

125134
# Remove minor ticks for cleaner appearance
126135
ax.minorticks_off()

crazyflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import crazyflow.gymnasium_envs # noqa: F401, ensure gymnasium envs are registered
1+
import crazyflow.envs # noqa: F401, ensure gymnasium envs are registered
22
from crazyflow.control import Control
33
from crazyflow.sim import Physics, Sim
44

crazyflow/envs/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from gymnasium.envs.registration import register
2+
3+
from crazyflow.envs.figure_8_env import FigureEightEnv
4+
from crazyflow.envs.landing_env import LandingEnv
5+
from crazyflow.envs.norm_actions_wrapper import NormalizeActions
6+
from crazyflow.envs.reach_pos_env import ReachPosEnv
7+
from crazyflow.envs.reach_vel_env import ReachVelEnv
8+
9+
__all__ = ["ReachPosEnv", "ReachVelEnv", "LandingEnv", "NormalizeActions", "FigureEightEnv"]
10+
11+
register(id="DroneReachPos-v0", vector_entry_point=ReachPosEnv)
12+
13+
register(id="DroneReachVel-v0", vector_entry_point=ReachVelEnv)
14+
15+
register(id="DroneLanding-v0", vector_entry_point=LandingEnv)
16+
17+
register(id="DroneFigureEightTrajectory-v0", vector_entry_point=FigureEightEnv)

0 commit comments

Comments
 (0)