@@ -576,3 +576,40 @@ def seed_sim(data: SimData, seed: int, device: Device) -> SimData:
576576 """JIT-compiled seeding function."""
577577 rng_key = jax .device_put (jax .random .key (seed ), device )
578578 return data .replace (core = data .core .replace (rng_key = rng_key ))
579+
580+
581+ def use_box_collision (sim : Sim , enable : bool = True ):
582+ """Changes the collision geometry to use boxes or spheres (default).
583+
584+ Args:
585+ sim: The simulation instance.
586+ enable: If True, use box collision geometry. If False, use sphere collision geometry.
587+
588+ Warning:
589+ Using box collision geometry is more computationally expensive than sphere collision
590+ geometry, especially for larger swarms. It is recommended to only enable box collision
591+ geometry for small swarms or when high accuracy is required.
592+ """
593+
594+ def find_geom_ids_by_prefix (model : mujoco .MjModel , prefix : str ) -> Array [int ]:
595+ geom_ids = []
596+ for i in range (model .ngeom ):
597+ name : str | None = mujoco .mj_id2name (model , mujoco .mjtObj .mjOBJ_GEOM , i )
598+ if name is not None and name .startswith (prefix ):
599+ geom_ids .append (i )
600+ return jnp .asarray (geom_ids )
601+
602+ # Get geom ids
603+ sphere_ids = find_geom_ids_by_prefix (sim .mj_model , "col_sphere" )
604+ box_ids = find_geom_ids_by_prefix (sim .mj_model , "col_box" )
605+ assert len (sphere_ids ) == len (box_ids ) and len (sphere_ids ) > 0 , (
606+ "Number of sphere and box collision geometries must be the same, check xml files"
607+ )
608+
609+ # Enable/disable geoms
610+ sim .mj_model .geom_contype [sphere_ids ] = 1 * (not enable )
611+ sim .mj_model .geom_conaffinity [sphere_ids ] = 1 * (not enable )
612+ sim .mj_model .geom_rgba [sphere_ids , 3 ] = 1 * (not enable )
613+ sim .mj_model .geom_contype [box_ids ] = 1 * enable
614+ sim .mj_model .geom_conaffinity [box_ids ] = 1 * enable
615+ sim .mj_model .geom_rgba [box_ids , 3 ] = 1 * enable
0 commit comments