Skip to content

Commit a4eaad5

Browse files
committed
Replace range in vmap for jax > 0.7.0. Fix #37.
1 parent 65dae02 commit a4eaad5

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

crazyflow/sim/sim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def build_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]:
212212
mj_data = mujoco.MjData(mj_model)
213213
mjx_model = mjx.put_model(mj_model, device=self.device)
214214
mjx_data = mjx.put_data(mj_model, mj_data, device=self.device)
215-
mjx_data = jax.vmap(lambda _: mjx_data)(range(self.n_worlds))
215+
mjx_data = jax.vmap(lambda _: mjx_data)(jnp.ones(self.n_worlds))
216216
return mj_model, mj_data, mjx_model, mjx_data
217217

218218
def build_step_fn(self):

0 commit comments

Comments
 (0)