88import jax
99import jax .numpy as jnp
1010import numpy as np
11+ from jax .errors import JaxRuntimeError
1112from ml_collections import config_dict
1213
1314import crazyflow # noqa: F401, ensure gymnasium envs are registered
@@ -42,15 +43,15 @@ def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float
4243
4344
4445def profile_gym_env_step (
45- sim_config : config_dict .ConfigDict , n_steps : int , device : str
46+ sim_config : config_dict .ConfigDict , n_steps : int , device : str , print_summary : bool = True
4647) -> list [float ]:
4748 """Profile the Crazyflow gym environment step performance."""
4849 times = []
4950 device = jax .devices (device )[0 ]
5051
5152 envs = gymnasium .make_vec (
5253 "DroneReachPos-v0" ,
53- time_horizon_in_seconds = 3 ,
54+ max_episode_time = 3 ,
5455 num_envs = sim_config .n_worlds ,
5556 device = sim_config .device ,
5657 freq = sim_config .freq ,
@@ -61,7 +62,7 @@ def profile_gym_env_step(
6162 action = np .zeros ((sim_config .n_worlds , 4 ), dtype = np .float32 )
6263 action [..., 0 ] = 0.3
6364 # Step through env once to ensure JIT compilation
64- envs .reset (seed = 42 )
65+ envs .reset ()
6566 envs .step (action )
6667
6768 jax .block_until_ready (envs .unwrapped .sim .data ) # Ensure JIT compiled dynamics
@@ -74,12 +75,15 @@ def profile_gym_env_step(
7475 times .append (time .perf_counter () - tstart )
7576
7677 envs .close ()
77- print ("Gym env step performance:" )
78- analyze_timings (times , n_steps , envs .unwrapped .sim .n_worlds , envs .unwrapped .sim .freq )
78+ if print_summary :
79+ print ("Gym env step performance:" )
80+ analyze_timings (times , n_steps , envs .unwrapped .sim .n_worlds , sim_config .freq )
7981 return times
8082
8183
82- def profile_step (sim_config : config_dict .ConfigDict , n_steps : int , device : str ) -> list [float ]:
84+ def profile_step (
85+ sim_config : config_dict .ConfigDict , n_steps : int , device : str , print_summary : bool = True
86+ ) -> list [float ]:
8387 """Profile the Crazyflow simulator step performance."""
8488 sim = Sim (** sim_config )
8589 times = []
@@ -100,8 +104,9 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str)
100104 jax .block_until_ready (sim .data )
101105 times .append (time .perf_counter () - tstart )
102106
103- print ("Sim step performance:" )
104- analyze_timings (times , n_steps , sim .n_worlds , sim .freq )
107+ if print_summary :
108+ print ("Sim step performance:" )
109+ analyze_timings (times , n_steps , sim .n_worlds , sim .freq )
105110 return times
106111
107112
@@ -184,25 +189,24 @@ def main(device: str = "cpu", n_worlds_exp: int = 6):
184189 skip_sim , skip_gym = False , False
185190 # Test with increasing number of parallel environments (worlds)
186191 for n_worlds in [10 ** i for i in range (n_worlds_exp + 1 )]:
192+ sim_config .n_worlds = n_worlds
193+ print ("-" * 80 )
187194 if not skip_sim :
188- print (f"\n Testing with { n_worlds } parallel environments:" )
189- sim_config .n_worlds = n_worlds
190-
191195 # Test with a single step first to see if we should continue
192196 sim_config .freq = 500 # Test sim at 500 hz
193- single_step_time = profile_step (sim_config , 2 , device )[1 ]
197+ single_step_time = profile_step (sim_config , 2 , device , print_summary = False )[1 ]
194198
195199 # If single step takes too long, skip this and remaining tests
196200 if single_step_time > max_seconds_per_run / n_steps : # threshold for the tests
197201 print (
198- f" Skipping benchmark for { n_worlds } and higher - single step took "
202+ f" Skipping benchmark for { n_worlds } and higher - projected time "
199203 f"{ single_step_time * n_steps :.2f} s (> 1m)"
200204 )
201205 skip_sim = True
202206
203207 if not skip_sim :
204208 # Configure simulator
205- print (f" Running simulator benchmark ({ n_worlds } worlds)..." )
209+ print (f"Running simulator benchmark ({ n_worlds } worlds)..." )
206210 # Run simulator benchmark using existing function
207211 times_sim = profile_step (sim_config , n_steps , device )
208212
@@ -233,33 +237,30 @@ def main(device: str = "cpu", n_worlds_exp: int = 6):
233237 f .flush ()
234238
235239 if not skip_gym :
236- print (f" Running gym environment benchmark ({ n_worlds } worlds)..." )
240+ print (f"Running gym environment benchmark ({ n_worlds } worlds)..." )
237241 # Run gym environment benchmark using existing function
238242 sim_config .freq = 50 # Test gym at 50 hz
239243 try :
240- single_step_time = profile_gym_env_step (sim_config , 2 , device )[1 ]
244+ step_times = profile_gym_env_step (sim_config , 2 , device , print_summary = False )
245+ single_step_time = step_times [1 ]
241246 # If single step takes too long, skip this test only
242247 if single_step_time > max_seconds_per_run / n_steps : # threshold for the tests
243248 print (
244- f" Skipping benchmark for { n_worlds } - single step took "
249+ f" Skipping benchmark for { n_worlds } - projected time "
245250 f"{ single_step_time * n_steps :.2f} s (> 1m)"
246251 )
247252 skip_gym = True
248- except ValueError as e :
249- if "RESOURCE_EXHAUSTED" in str (e ):
250- print (f" Skipping benchmark for { n_worlds } - resource exhausted" )
251- skip_gym = True
252- else :
253- raise e
253+ except JaxRuntimeError :
254+ print (f" Skipping benchmark for { n_worlds } - resource exhausted" )
255+ skip_gym = True
254256
255257 if not skip_gym :
256258 try :
257259 times_gym = profile_gym_env_step (sim_config , n_steps , device )
258- except ValueError as e :
259- if "RESOURCE_EXHAUSTED" in str (e ):
260- print (f" Skipping benchmark for { n_worlds } - resource exhausted" )
261- continue # Only continue, we might still be able to benchmark sim
262- raise e
260+ except JaxRuntimeError :
261+ print (f" Skipping benchmark for { n_worlds } - resource exhausted" )
262+ skip_gym = True
263+ continue
263264
264265 # Calculate metrics for CSV
265266 total_time = sum (times_gym )
0 commit comments