Skip to content

Commit 359dd1a

Browse files
authored
Merge pull request #36 from utiasDSL/dev
Merge dev into main
2 parents 96143d4 + d07577d commit 359dd1a

31 files changed

Lines changed: 3456 additions & 1055 deletions

.gitattributes

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# SCM syntax highlighting & preventing 3-way merges
2+
pixi.lock merge=binary linguist-language=YAML linguist-generated=true

.github/workflows/testing.yml

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,27 @@
1-
name: Testing # Skips RL tests because stable-baselines3 comes with a lot of heavy-weight dependencies
2-
3-
on: [push]
1+
name: Testing
2+
on: [push, pull_request]
43

54
jobs:
65
test:
76
runs-on: ubuntu-latest
87
steps:
98
- uses: actions/checkout@v4
10-
- uses: mamba-org/setup-micromamba@v1
9+
10+
- name: Setup Pixi (installs pixi + caches envs) # https://github.com/marketplace/actions/setup-pixi
11+
uses: prefix-dev/setup-pixi@v0.9.0 # pin the action version
1112
with:
12-
micromamba-version: '2.0.2-1' # any version from https://github.com/mamba-org/micromamba-releases
13-
environment-name: test-env
14-
init-shell: bash
15-
create-args: python=3.11
16-
cache-environment: true
17-
- name: Install dependencies and package
18-
run: pip install .[test]
19-
shell: micromamba-shell {0}
20-
- name: Test with pytest
21-
run: pytest tests --cov=crazyflow
22-
shell: micromamba-shell {0}
13+
pixi-version: v0.49.0 # pin the pixi binary version (optional)
14+
cache: true # enable caching of installed envs
15+
# only write new caches on main pushes (TODO: Enable)
16+
# cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
17+
# ensure the 'test' environment(s) are installed
18+
environments: test
19+
# don't activate env (we'll call pixi run -e test explicitly)
20+
activate-environment: false
21+
# prefer using existing lockfile if present (faster, deterministic)
22+
locked: true
23+
24+
- name: Verify pixi and run tests
25+
run: |
26+
pixi --version
27+
pixi run -e test pytest

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,7 @@ build
1616
**/*.pt
1717
tutorials/ppo/wandb
1818
dist
19-
benchmark/data
19+
benchmark/data
20+
# pixi environments
21+
.pixi
22+
*.egg-info

README.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
--------------------------------------------------------------------------------
44

5-
Fast, parallelizable simulations of Crazyflies with JAX and MuJoCo.
5+
Fast, parallelizable simulations of Crazyflies with JAX.
66

77
[![Python Version]][Python Version URL] [![Ruff Check]][Ruff Check URL] [![Documentation Status]][Documentation Status URL] [![Tests]][Tests URL]
88

@@ -32,7 +32,6 @@ The simulation is built as a pipeline of functions that are composed at initiali
3232
Multiple physics models are supported:
3333
- analytical: A first-principles model based on physical equations
3434
- sys_id: A system-identified model trained on real drone data
35-
- mujoco: MuJoCo physics engine for more complex interactions
3635

3736
#### Control Modes
3837
Different control interfaces are available:
@@ -41,9 +40,10 @@ Different control interfaces are available:
4140
- thrust: Low-level control of individual motor thrusts
4241

4342
#### Integration Methods
44-
For analytical and system-identified physics:
43+
We support multiple integration schemes for additional precision:
4544
- euler: Simple first-order integration
4645
- rk4: Fourth-order Runge-Kutta integration for higher accuracy
46+
- symplectic\_euler: Symplectic integration for conservation of energy
4747

4848
### Parallelization
4949
Crazyflow supports massive parallelization across:
@@ -58,6 +58,12 @@ The framework supports domain randomization through the crazyflow/randomize modu
5858
### Functional Design
5959
The simulation follows a functional programming paradigm: All state is contained in immutable data structures. Updates create new states rather than modifying existing ones. All functions are pure, enabling JAX's transformations (JIT, grad, vmap) and thus automatic differentiation through the entire simulation, making it suitable for gradient-based optimization and reinforcement learning.
6060

61+
### Contacts and Non-Drone Models
62+
We focus on drones dynamics in free-space flight. Consequently, no models other than drones are available in the simulation and contact dynamics with external objects are not considered. However, we use MuJoCo for contact detection and visualization. Users can load their own objects into the simulation by changing the MuJoCo world spec. Drone collisions with these objects will be detected during collision checks, but they won't have an effect on the dynamics (i.e. drones will pass through objects). Similarly, the objects themselves will be static.
63+
64+
### Visualization
65+
We use `gymnasium`'s MuJoCo renderer and synchronize the simulation data with MuJoCo to either render an interactive UI or RGB arrays.
66+
6167
## Examples
6268
The repository includes several example scripts demonstrating different capabilities:
6369
| Example | Description |

benchmark/main.py

Lines changed: 112 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from datetime import datetime
44
from pathlib import Path
55

6+
import fire
67
import gymnasium
78
import jax
89
import jax.numpy as jnp
910
import numpy as np
11+
from jax.errors import JaxRuntimeError
1012
from ml_collections import config_dict
1113

1214
import crazyflow # noqa: F401, ensure gymnasium envs are registered
@@ -41,15 +43,15 @@ def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float
4143

4244

4345
def profile_gym_env_step(
44-
sim_config: config_dict.ConfigDict, n_steps: int, device: str
46+
sim_config: config_dict.ConfigDict, n_steps: int, device: str, print_summary: bool = True
4547
) -> list[float]:
4648
"""Profile the Crazyflow gym environment step performance."""
4749
times = []
4850
device = jax.devices(device)[0]
4951

5052
envs = gymnasium.make_vec(
5153
"DroneReachPos-v0",
52-
time_horizon_in_seconds=3,
54+
max_episode_time=3,
5355
num_envs=sim_config.n_worlds,
5456
device=sim_config.device,
5557
freq=sim_config.freq,
@@ -60,7 +62,7 @@ def profile_gym_env_step(
6062
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
6163
action[..., 0] = 0.3
6264
# Step through env once to ensure JIT compilation
63-
envs.reset(seed=42)
65+
envs.reset()
6466
envs.step(action)
6567

6668
jax.block_until_ready(envs.unwrapped.sim.data) # Ensure JIT compiled dynamics
@@ -73,12 +75,15 @@ def profile_gym_env_step(
7375
times.append(time.perf_counter() - tstart)
7476

7577
envs.close()
76-
print("Gym env step performance:")
77-
analyze_timings(times, n_steps, envs.unwrapped.sim.n_worlds, envs.unwrapped.sim.freq)
78+
if print_summary:
79+
print("Gym env step performance:")
80+
analyze_timings(times, n_steps, envs.unwrapped.sim.n_worlds, sim_config.freq)
7881
return times
7982

8083

81-
def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str) -> list[float]:
84+
def profile_step(
85+
sim_config: config_dict.ConfigDict, n_steps: int, device: str, print_summary: bool = True
86+
) -> list[float]:
8287
"""Profile the Crazyflow simulator step performance."""
8388
sim = Sim(**sim_config)
8489
times = []
@@ -99,8 +104,9 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str)
99104
jax.block_until_ready(sim.data)
100105
times.append(time.perf_counter() - tstart)
101106

102-
print("Sim step performance:")
103-
analyze_timings(times, n_steps, sim.n_worlds, sim.freq)
107+
if print_summary:
108+
print("Sim step performance:")
109+
analyze_timings(times, n_steps, sim.n_worlds, sim.freq)
104110
return times
105111

106112

@@ -140,9 +146,8 @@ def profile_reset(sim_config: config_dict.ConfigDict, n_steps: int, device: str)
140146
analyze_timings(times_masked, n_steps, sim.n_worlds, sim.freq)
141147

142148

143-
def main():
149+
def main(device: str = "cpu", n_worlds_exp: int = 6):
144150
"""Main entry point for profiling."""
145-
device = "cpu"
146151
sim_config = config_dict.ConfigDict()
147152
sim_config.n_worlds = 1
148153
sim_config.n_drones = 1
@@ -181,93 +186,109 @@ def main():
181186
# Reopen the file in append mode for each result
182187

183188
n_steps = 1000
189+
skip_sim, skip_gym = False, False
184190
# Test with increasing number of parallel environments (worlds)
185-
for n_worlds in [1, 10, 100, 1000, 10000, 100000, 1000000]:
186-
print(f"\nTesting with {n_worlds} parallel environments:")
191+
for n_worlds in [10**i for i in range(n_worlds_exp + 1)]:
187192
sim_config.n_worlds = n_worlds
193+
print("-" * 80)
194+
if not skip_sim:
195+
# Test with a single step first to see if we should continue
196+
sim_config.freq = 500 # Test sim at 500 hz
197+
single_step_time = profile_step(sim_config, 2, device, print_summary=False)[1]
198+
199+
# If single step takes too long, skip this and remaining tests
200+
if single_step_time > max_seconds_per_run / n_steps: # threshold for the tests
201+
print(
202+
f" Skipping benchmark for {n_worlds} and higher - projected time "
203+
f"{single_step_time * n_steps:.2f}s (> 1m)"
204+
)
205+
skip_sim = True
206+
207+
if not skip_sim:
208+
# Configure simulator
209+
print(f"Running simulator benchmark ({n_worlds} worlds)...")
210+
# Run simulator benchmark using existing function
211+
times_sim = profile_step(sim_config, n_steps, device)
212+
213+
# Calculate metrics for CSV
214+
total_time = sum(times_sim)
215+
avg_step_time = np.mean(times_sim)
216+
n_frames = n_steps * n_worlds
217+
fps = n_frames / total_time
218+
real_time_factor = (n_steps / sim_config.freq) * n_worlds / total_time
219+
220+
# Save simulator results
221+
# Reopen CSV writer in append mode
222+
with open(csv_file, "a", newline="") as f:
223+
csv_writer = csv.writer(f)
224+
csv_writer.writerow(
225+
[
226+
"simulator",
227+
1, # n_drones
228+
n_worlds,
229+
n_steps,
230+
total_time,
231+
avg_step_time,
232+
fps,
233+
real_time_factor,
234+
sim_config.device,
235+
]
236+
)
237+
f.flush()
238+
239+
if not skip_gym:
240+
print(f"Running gym environment benchmark ({n_worlds} worlds)...")
241+
# Run gym environment benchmark using existing function
242+
sim_config.freq = 50 # Test gym at 50 hz
243+
try:
244+
step_times = profile_gym_env_step(sim_config, 2, device, print_summary=False)
245+
single_step_time = step_times[1]
246+
# If single step takes too long, skip this test only
247+
if single_step_time > max_seconds_per_run / n_steps: # threshold for the tests
248+
print(
249+
f" Skipping benchmark for {n_worlds} - projected time "
250+
f"{single_step_time * n_steps:.2f}s (> 1m)"
251+
)
252+
skip_gym = True
253+
except JaxRuntimeError:
254+
print(f" Skipping benchmark for {n_worlds} - resource exhausted")
255+
skip_gym = True
188256

189-
# Test with a single step first to see if we should continue
190-
sim_config.freq = 500 # Test sim at 500 hz
191-
test_times = profile_step(sim_config, 1, device)
192-
193-
single_step_time = test_times[0]
194-
# If single step takes too long, skip this and remaining tests
195-
if single_step_time > max_seconds_per_run / n_steps: # threshold for the tests
196-
print(
197-
f" Skipping benchmark for {n_worlds} and higher - single step took "
198-
f"{single_step_time * 1000:.2f}s (> 1m)"
199-
)
200-
break
201-
202-
# Configure simulator
203-
print(f" Running simulator benchmark ({n_worlds} worlds)...")
204-
# Run simulator benchmark using existing function
205-
times_sim = profile_step(sim_config, n_steps, device)
206-
207-
# Calculate metrics for CSV
208-
total_time = sum(times_sim)
209-
avg_step_time = np.mean(times_sim)
210-
n_frames = n_steps * n_worlds
211-
fps = n_frames / total_time
212-
real_time_factor = (n_steps / sim_config.freq) * n_worlds / total_time
213-
214-
# Save simulator results
215-
# Reopen CSV writer in append mode
216-
with open(csv_file, "w", newline="") as f:
217-
csv_writer = csv.writer(f)
218-
csv_writer.writerow(
219-
[
220-
"simulator",
221-
1, # n_drones
222-
n_worlds,
223-
n_steps,
224-
total_time,
225-
avg_step_time,
226-
fps,
227-
real_time_factor,
228-
sim_config.device,
229-
]
230-
)
231-
f.flush()
232-
233-
print(f" Running gym environment benchmark ({n_worlds} worlds)...")
234-
# Run gym environment benchmark using existing function
235-
sim_config.freq = 50 # Test gym at 50 hz
236-
try:
237-
times_gym = profile_gym_env_step(sim_config, n_steps, device)
238-
except ValueError as e:
239-
if "RESOURCE_EXHAUSTED" in str(e):
257+
if not skip_gym:
258+
try:
259+
times_gym = profile_gym_env_step(sim_config, n_steps, device)
260+
except JaxRuntimeError:
240261
print(f" Skipping benchmark for {n_worlds} - resource exhausted")
241-
continue # Only continue, we might still be able to benchmark sim
242-
raise e
243-
244-
# Calculate metrics for CSV
245-
total_time = sum(times_gym)
246-
avg_step_time = np.mean(times_gym)
247-
n_frames = n_steps * n_worlds
248-
fps = n_frames / total_time
249-
real_time_factor = (n_steps / sim_config.freq) * sim_config.n_worlds / total_time
250-
251-
# Save gym environment results
252-
with open(csv_file, "a", newline="") as f:
253-
csv_writer = csv.writer(f)
254-
csv_writer.writerow(
255-
[
256-
"gym_env",
257-
sim_config.n_drones,
258-
sim_config.n_worlds,
259-
n_steps,
260-
total_time,
261-
avg_step_time,
262-
fps,
263-
real_time_factor,
264-
sim_config.device,
265-
]
266-
)
267-
f.flush()
262+
skip_gym = True
263+
continue
264+
265+
# Calculate metrics for CSV
266+
total_time = sum(times_gym)
267+
avg_step_time = np.mean(times_gym)
268+
n_frames = n_steps * n_worlds
269+
fps = n_frames / total_time
270+
real_time_factor = (n_steps / sim_config.freq) * sim_config.n_worlds / total_time
271+
272+
# Save gym environment results
273+
with open(csv_file, "a", newline="") as f:
274+
csv_writer = csv.writer(f)
275+
csv_writer.writerow(
276+
[
277+
"gym_env",
278+
sim_config.n_drones,
279+
sim_config.n_worlds,
280+
n_steps,
281+
total_time,
282+
avg_step_time,
283+
fps,
284+
real_time_factor,
285+
sim_config.device,
286+
]
287+
)
288+
f.flush()
268289

269290
print(f"\nBenchmark results saved to {csv_file}")
270291

271292

272293
if __name__ == "__main__":
273-
main()
294+
fire.Fire(main)

benchmark/performance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from crazyflow.sim import Sim
1414

1515
if TYPE_CHECKING:
16-
from crazyflow.gymnasium_envs import CrazyflowEnvReachGoal
16+
from crazyflow.envs import ReachPosEnv
1717

1818

1919
def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
@@ -44,7 +44,7 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
4444
def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
4545
device = jax.devices(device)[0]
4646

47-
envs: CrazyflowEnvReachGoal = gymnasium.make_vec(
47+
envs: ReachPosEnv = gymnasium.make_vec(
4848
"DroneReachPos-v0", time_horizon_in_seconds=2, num_envs=sim_config.n_worlds, **sim_config
4949
)
5050

0 commit comments

Comments
 (0)