33from datetime import datetime
44from pathlib import Path
55
6+ import fire
67import gymnasium
78import jax
89import jax .numpy as jnp
910import numpy as np
11+ from jax .errors import JaxRuntimeError
1012from ml_collections import config_dict
1113
1214import crazyflow # noqa: F401, ensure gymnasium envs are registered
@@ -41,15 +43,15 @@ def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float
4143
4244
4345def profile_gym_env_step (
44- sim_config : config_dict .ConfigDict , n_steps : int , device : str
46+ sim_config : config_dict .ConfigDict , n_steps : int , device : str , print_summary : bool = True
4547) -> list [float ]:
4648 """Profile the Crazyflow gym environment step performance."""
4749 times = []
4850 device = jax .devices (device )[0 ]
4951
5052 envs = gymnasium .make_vec (
5153 "DroneReachPos-v0" ,
52- time_horizon_in_seconds = 3 ,
54+ max_episode_time = 3 ,
5355 num_envs = sim_config .n_worlds ,
5456 device = sim_config .device ,
5557 freq = sim_config .freq ,
@@ -60,7 +62,7 @@ def profile_gym_env_step(
6062 action = np .zeros ((sim_config .n_worlds , 4 ), dtype = np .float32 )
6163 action [..., 0 ] = 0.3
6264 # Step through env once to ensure JIT compilation
63- envs .reset (seed = 42 )
65+ envs .reset ()
6466 envs .step (action )
6567
6668 jax .block_until_ready (envs .unwrapped .sim .data ) # Ensure JIT compiled dynamics
@@ -73,12 +75,15 @@ def profile_gym_env_step(
7375 times .append (time .perf_counter () - tstart )
7476
7577 envs .close ()
76- print ("Gym env step performance:" )
77- analyze_timings (times , n_steps , envs .unwrapped .sim .n_worlds , envs .unwrapped .sim .freq )
78+ if print_summary :
79+ print ("Gym env step performance:" )
80+ analyze_timings (times , n_steps , envs .unwrapped .sim .n_worlds , sim_config .freq )
7881 return times
7982
8083
81- def profile_step (sim_config : config_dict .ConfigDict , n_steps : int , device : str ) -> list [float ]:
84+ def profile_step (
85+ sim_config : config_dict .ConfigDict , n_steps : int , device : str , print_summary : bool = True
86+ ) -> list [float ]:
8287 """Profile the Crazyflow simulator step performance."""
8388 sim = Sim (** sim_config )
8489 times = []
@@ -99,8 +104,9 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str)
99104 jax .block_until_ready (sim .data )
100105 times .append (time .perf_counter () - tstart )
101106
102- print ("Sim step performance:" )
103- analyze_timings (times , n_steps , sim .n_worlds , sim .freq )
107+ if print_summary :
108+ print ("Sim step performance:" )
109+ analyze_timings (times , n_steps , sim .n_worlds , sim .freq )
104110 return times
105111
106112
@@ -140,9 +146,8 @@ def profile_reset(sim_config: config_dict.ConfigDict, n_steps: int, device: str)
140146 analyze_timings (times_masked , n_steps , sim .n_worlds , sim .freq )
141147
142148
143- def main ():
149+ def main (device : str = "cpu" , n_worlds_exp : int = 6 ):
144150 """Main entry point for profiling."""
145- device = "cpu"
146151 sim_config = config_dict .ConfigDict ()
147152 sim_config .n_worlds = 1
148153 sim_config .n_drones = 1
@@ -181,93 +186,109 @@ def main():
181186 # Reopen the file in append mode for each result
182187
183188 n_steps = 1000
189+ skip_sim , skip_gym = False , False
184190 # Test with increasing number of parallel environments (worlds)
185- for n_worlds in [1 , 10 , 100 , 1000 , 10000 , 100000 , 1000000 ]:
186- print (f"\n Testing with { n_worlds } parallel environments:" )
191+ for n_worlds in [10 ** i for i in range (n_worlds_exp + 1 )]:
187192 sim_config .n_worlds = n_worlds
193+ print ("-" * 80 )
194+ if not skip_sim :
195+ # Test with a single step first to see if we should continue
196+ sim_config .freq = 500 # Test sim at 500 hz
197+ single_step_time = profile_step (sim_config , 2 , device , print_summary = False )[1 ]
198+
199+ # If single step takes too long, skip this and remaining tests
200+ if single_step_time > max_seconds_per_run / n_steps : # threshold for the tests
201+ print (
202+ f" Skipping benchmark for { n_worlds } and higher - projected time "
203+ f"{ single_step_time * n_steps :.2f} s (> 1m)"
204+ )
205+ skip_sim = True
206+
207+ if not skip_sim :
208+ # Configure simulator
209+ print (f"Running simulator benchmark ({ n_worlds } worlds)..." )
210+ # Run simulator benchmark using existing function
211+ times_sim = profile_step (sim_config , n_steps , device )
212+
213+ # Calculate metrics for CSV
214+ total_time = sum (times_sim )
215+ avg_step_time = np .mean (times_sim )
216+ n_frames = n_steps * n_worlds
217+ fps = n_frames / total_time
218+ real_time_factor = (n_steps / sim_config .freq ) * n_worlds / total_time
219+
220+ # Save simulator results
221+ # Reopen CSV writer in append mode
222+ with open (csv_file , "a" , newline = "" ) as f :
223+ csv_writer = csv .writer (f )
224+ csv_writer .writerow (
225+ [
226+ "simulator" ,
227+ 1 , # n_drones
228+ n_worlds ,
229+ n_steps ,
230+ total_time ,
231+ avg_step_time ,
232+ fps ,
233+ real_time_factor ,
234+ sim_config .device ,
235+ ]
236+ )
237+ f .flush ()
238+
239+ if not skip_gym :
240+ print (f"Running gym environment benchmark ({ n_worlds } worlds)..." )
241+ # Run gym environment benchmark using existing function
242+ sim_config .freq = 50 # Test gym at 50 hz
243+ try :
244+ step_times = profile_gym_env_step (sim_config , 2 , device , print_summary = False )
245+ single_step_time = step_times [1 ]
246+ # If single step takes too long, skip this test only
247+ if single_step_time > max_seconds_per_run / n_steps : # threshold for the tests
248+ print (
249+ f" Skipping benchmark for { n_worlds } - projected time "
250+ f"{ single_step_time * n_steps :.2f} s (> 1m)"
251+ )
252+ skip_gym = True
253+ except JaxRuntimeError :
254+ print (f" Skipping benchmark for { n_worlds } - resource exhausted" )
255+ skip_gym = True
188256
189- # Test with a single step first to see if we should continue
190- sim_config .freq = 500 # Test sim at 500 hz
191- test_times = profile_step (sim_config , 1 , device )
192-
193- single_step_time = test_times [0 ]
194- # If single step takes too long, skip this and remaining tests
195- if single_step_time > max_seconds_per_run / n_steps : # threshold for the tests
196- print (
197- f" Skipping benchmark for { n_worlds } and higher - single step took "
198- f"{ single_step_time * 1000 :.2f} s (> 1m)"
199- )
200- break
201-
202- # Configure simulator
203- print (f" Running simulator benchmark ({ n_worlds } worlds)..." )
204- # Run simulator benchmark using existing function
205- times_sim = profile_step (sim_config , n_steps , device )
206-
207- # Calculate metrics for CSV
208- total_time = sum (times_sim )
209- avg_step_time = np .mean (times_sim )
210- n_frames = n_steps * n_worlds
211- fps = n_frames / total_time
212- real_time_factor = (n_steps / sim_config .freq ) * n_worlds / total_time
213-
214- # Save simulator results
215- # Reopen CSV writer in append mode
216- with open (csv_file , "w" , newline = "" ) as f :
217- csv_writer = csv .writer (f )
218- csv_writer .writerow (
219- [
220- "simulator" ,
221- 1 , # n_drones
222- n_worlds ,
223- n_steps ,
224- total_time ,
225- avg_step_time ,
226- fps ,
227- real_time_factor ,
228- sim_config .device ,
229- ]
230- )
231- f .flush ()
232-
233- print (f" Running gym environment benchmark ({ n_worlds } worlds)..." )
234- # Run gym environment benchmark using existing function
235- sim_config .freq = 50 # Test gym at 50 hz
236- try :
237- times_gym = profile_gym_env_step (sim_config , n_steps , device )
238- except ValueError as e :
239- if "RESOURCE_EXHAUSTED" in str (e ):
257+ if not skip_gym :
258+ try :
259+ times_gym = profile_gym_env_step (sim_config , n_steps , device )
260+ except JaxRuntimeError :
240261 print (f" Skipping benchmark for { n_worlds } - resource exhausted" )
241- continue # Only continue, we might still be able to benchmark sim
242- raise e
243-
244- # Calculate metrics for CSV
245- total_time = sum (times_gym )
246- avg_step_time = np .mean (times_gym )
247- n_frames = n_steps * n_worlds
248- fps = n_frames / total_time
249- real_time_factor = (n_steps / sim_config .freq ) * sim_config .n_worlds / total_time
250-
251- # Save gym environment results
252- with open (csv_file , "a" , newline = "" ) as f :
253- csv_writer = csv .writer (f )
254- csv_writer .writerow (
255- [
256- "gym_env" ,
257- sim_config .n_drones ,
258- sim_config .n_worlds ,
259- n_steps ,
260- total_time ,
261- avg_step_time ,
262- fps ,
263- real_time_factor ,
264- sim_config .device ,
265- ]
266- )
267- f .flush ()
262+ skip_gym = True
263+ continue
264+
265+ # Calculate metrics for CSV
266+ total_time = sum (times_gym )
267+ avg_step_time = np .mean (times_gym )
268+ n_frames = n_steps * n_worlds
269+ fps = n_frames / total_time
270+ real_time_factor = (n_steps / sim_config .freq ) * sim_config .n_worlds / total_time
271+
272+ # Save gym environment results
273+ with open (csv_file , "a" , newline = "" ) as f :
274+ csv_writer = csv .writer (f )
275+ csv_writer .writerow (
276+ [
277+ "gym_env" ,
278+ sim_config .n_drones ,
279+ sim_config .n_worlds ,
280+ n_steps ,
281+ total_time ,
282+ avg_step_time ,
283+ fps ,
284+ real_time_factor ,
285+ sim_config .device ,
286+ ]
287+ )
288+ f .flush ()
268289
269290 print (f"\n Benchmark results saved to { csv_file } " )
270291
271292
272293if __name__ == "__main__" :
273- main ( )
294+ fire . Fire ( main )
0 commit comments