|
3 | 3 | from datetime import datetime |
4 | 4 | from pathlib import Path |
5 | 5 |
|
| 6 | +import fire |
6 | 7 | import gymnasium |
7 | 8 | import jax |
8 | 9 | import jax.numpy as jnp |
@@ -140,9 +141,8 @@ def profile_reset(sim_config: config_dict.ConfigDict, n_steps: int, device: str) |
140 | 141 | analyze_timings(times_masked, n_steps, sim.n_worlds, sim.freq) |
141 | 142 |
|
142 | 143 |
|
143 | | -def main(): |
| 144 | +def main(device: str = "cpu", n_worlds_exp: int = 6): |
144 | 145 | """Main entry point for profiling.""" |
145 | | - device = "cpu" |
146 | 146 | sim_config = config_dict.ConfigDict() |
147 | 147 | sim_config.n_worlds = 1 |
148 | 148 | sim_config.n_drones = 1 |
@@ -182,7 +182,7 @@ def main(): |
182 | 182 |
|
183 | 183 | n_steps = 1000 |
184 | 184 | # Test with increasing number of parallel environments (worlds) |
185 | | - for n_worlds in [1, 10, 100, 1000, 10000, 100000, 1000000]: |
| 185 | + for n_worlds in [10**i for i in range(n_worlds_exp + 1)]: |
186 | 186 | print(f"\nTesting with {n_worlds} parallel environments:") |
187 | 187 | sim_config.n_worlds = n_worlds |
188 | 188 |
|
@@ -213,7 +213,7 @@ def main(): |
213 | 213 |
|
214 | 214 | # Save simulator results |
215 | 215 | # Reopen CSV writer in append mode |
216 | | - with open(csv_file, "w", newline="") as f: |
| 216 | + with open(csv_file, "a", newline="") as f: |
217 | 217 | csv_writer = csv.writer(f) |
218 | 218 | csv_writer.writerow( |
219 | 219 | [ |
@@ -270,4 +270,4 @@ def main(): |
270 | 270 |
|
271 | 271 |
|
272 | 272 | if __name__ == "__main__": |
273 | | - main() |
| 273 | + fire.Fire(main) |
0 commit comments