Skip to content

Commit a356957

Browse files
committed
Fix example scripts
1 parent b22b648 commit a356957

6 files changed

Lines changed: 15 additions & 7 deletions

File tree

examples/cache.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
or when the cache directory is deleted.
2222
"""
2323

24+
import shutil
2425
import time
2526
from pathlib import Path
2627

@@ -32,7 +33,7 @@ def main():
3233
cache_dir = Path("/tmp/jax_cache_test")
3334
if use_cache := cache_dir.exists():
3435
print("Cache directory exists. This run will be fast.")
35-
print(f"\nTo run without cache, delete the directory {cache_dir}.")
36+
print("\nTo run without cache, run this script again.")
3637
else:
3738
print("Cache directory does not exist. This run will be slow.")
3839
print("\nTo run with cache, run this script again.")
@@ -44,6 +45,9 @@ def main():
4445
t2 = time.perf_counter()
4546
prefix = "Using cache: " if use_cache else "Not using cache: "
4647
print(f"{prefix}\n Init: {t1 - t0:.3f}s\n Step: {t2 - t1:.3f}s")
48+
if use_cache:
49+
shutil.rmtree(cache_dir) # Clean up cache so that the next run is slow again
50+
sim.close()
4751

4852

4953
if __name__ == "__main__":

examples/contacts.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ def main():
99
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.so_rpy, device="cpu")
1010
fps = 60
1111

12-
cmd = np.array([[[0.3, 0, 0, 0] for _ in range(n_drones)] for _ in range(n_worlds)])
12+
cmd = np.zeros((sim.n_worlds, sim.n_drones, 4))
13+
cmd[..., 3] = sim.data.params.mass[0, 0, 0] * 9.81 * 1.04
1314
for i in range(int(2 * sim.control_freq)):
1415
sim.attitude_control(cmd)
1516
sim.step(sim.freq // sim.control_freq)

examples/crash.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def main():
3535
if ((i * fps) % sim.control_freq) < fps:
3636
sim.render()
3737
print(f"Crash detected: {sim.contacts().any()}")
38+
sim.close()
3839

3940

4041
if __name__ == "__main__":

examples/gradient.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
def main():
1313
sim = Sim(control=Control.attitude, physics=Physics.first_principles, attitude_freq=50)
14-
sim_step = sim._step
14+
# Remove clipping floor function which kills gradients
15+
sim.step_pipeline = sim.step_pipeline[:-1]
16+
sim_step = sim.build_step_fn()
1517

1618
def step(cmd: NDArray, data: SimData) -> jax.Array:
1719
data = data.replace(
@@ -23,7 +25,7 @@ def step(cmd: NDArray, data: SimData) -> jax.Array:
2325
step_grad = jax.jit(jax.grad(step))
2426

2527
cmd = jnp.zeros((1, 1, 4), dtype=jnp.float32)
26-
cmd = cmd.at[..., 3].set(0.3)
28+
cmd = cmd.at[..., 3].set(sim.data.params.mass[0, 0, 0] * 9.81 * 1.05)
2729

2830
# Trigger jax's jit to compile the gradient function. This is not necessary, but it ensures that
2931
# the timings are not affected by the compilation time.

examples/randomize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ def main():
1717
# Randomize the inertia and mass of the drones
1818
mask = np.array([True, False, False]) # Only randomize the first world
1919
mass = sim.data.params.mass
20-
mass_rng = mass + jax.random.normal(jax.random.key(0), (sim.n_worlds, sim.n_drones, 1)) * 1e-4
20+
mass_rng = mass + jax.random.normal(jax.random.key(0), (sim.n_worlds, sim.n_drones, 1)) * 5e-3
2121
J = sim.data.params.J
22-
J_rng = J + jax.random.normal(jax.random.key(0), (sim.n_worlds, sim.n_drones, 3, 3)) * 1e-5
22+
J_rng = J + jax.random.normal(jax.random.key(0), (sim.n_worlds, sim.n_drones, 3, 3)) * 1e-6
2323

2424
randomize_mass(sim, mass_rng, mask)
2525
# Note: The mask is optional. We can also randomize all worlds at once

examples/render.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def main():
1212
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.so_rpy, device="cpu")
1313
fps = 60
1414
cmd = np.zeros((sim.n_worlds, sim.n_drones, 4))
15-
cmd[..., 3] = sim.data.params.mass[0, 0, 0] * 9.81
15+
cmd[..., 3] = sim.data.params.mass[0, 0, 0] * 9.81 * 1.05
1616
rgbas = np.random.default_rng(0).uniform(0, 1, (n_drones, 4))
1717
rgbas[..., 3] = 1.0
1818

0 commit comments

Comments
 (0)