Skip to content

Commit b1541a1

Browse files
committed
Fix benchmark script
1 parent 02e444b commit b1541a1

3 files changed

Lines changed: 252 additions & 9 deletions

File tree

benchmark/main.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from datetime import datetime
44
from pathlib import Path
55

6+
import fire
67
import gymnasium
78
import jax
89
import jax.numpy as jnp
@@ -140,9 +141,8 @@ def profile_reset(sim_config: config_dict.ConfigDict, n_steps: int, device: str)
140141
analyze_timings(times_masked, n_steps, sim.n_worlds, sim.freq)
141142

142143

143-
def main():
144+
def main(device: str = "cpu", n_worlds_exp: int = 6):
144145
"""Main entry point for profiling."""
145-
device = "cpu"
146146
sim_config = config_dict.ConfigDict()
147147
sim_config.n_worlds = 1
148148
sim_config.n_drones = 1
@@ -182,7 +182,7 @@ def main():
182182

183183
n_steps = 1000
184184
# Test with increasing number of parallel environments (worlds)
185-
for n_worlds in [1, 10, 100, 1000, 10000, 100000, 1000000]:
185+
for n_worlds in [10**i for i in range(n_worlds_exp + 1)]:
186186
print(f"\nTesting with {n_worlds} parallel environments:")
187187
sim_config.n_worlds = n_worlds
188188

@@ -213,7 +213,7 @@ def main():
213213

214214
# Save simulator results
215215
# Reopen CSV writer in append mode
216-
with open(csv_file, "w", newline="") as f:
216+
with open(csv_file, "a", newline="") as f:
217217
csv_writer = csv.writer(f)
218218
csv_writer.writerow(
219219
[
@@ -270,4 +270,4 @@ def main():
270270

271271

272272
if __name__ == "__main__":
273-
main()
273+
fire.Fire(main)

0 commit comments

Comments
 (0)