Skip to content

Commit 30cd9c0

Browse files
committed
Add sleep state fields
- Add MJ_MINAWAKE constant, SLEEP enable bit, SleepPolicy and SleepState enums - Add sleep_tolerance to Option, tree_sleep_policy/dof_length to Model - Add tree_asleep, tree_awake, body_awake, dof_awake_ind, body_awake_ind, nv_awake, nbody_awake, ntree_awake to Data - Initialize sleep state in make_data and put_data (all trees start awake) - Add tests for sleep state initialization, sleep policy and dof_length import
1 parent 0ab0578 commit 30cd9c0

3 files changed

Lines changed: 135 additions & 1 deletion

File tree

mujoco_warp/_src/io.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,10 @@ def geom_trid_index(i, j):
594594
m.qM_madr_ij.append(madr_ij)
595595

596596
# place m on device
597-
sizes = dict({"*": 1}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int})
597+
# TODO(team): remove ntree once field is added to types.Model
598+
sizes = dict(
599+
{"*": 1, "ntree": mjm.ntree}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int}
600+
)
598601
for f in dataclasses.fields(types.Model):
599602
if isinstance(f.type, wp.array):
600603
setattr(m, f.name, _create_array(getattr(m, f.name), f.type, sizes))
@@ -691,6 +694,7 @@ def make_data(
691694
sizes["nworld"] = nworld
692695
sizes["naconmax"] = naconmax
693696
sizes["njmax"] = njmax
697+
sizes["ntree"] = mjm.ntree
694698

695699
contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)})
696700
efc = types.Constraint(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Constraint)})
@@ -728,6 +732,10 @@ def make_data(
728732
),
729733
# equality constraints
730734
"eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool),
735+
# sleep state: all trees start fully awake
736+
"tree_asleep": wp.array(np.full((nworld, mjm.ntree), -(1 + types.MJ_MINAWAKE)), dtype=int),
737+
"tree_awake": wp.array(np.ones((nworld, mjm.ntree)), dtype=int),
738+
"body_awake": wp.array(np.ones((nworld, mjm.nbody)), dtype=int),
731739
}
732740
for f in dataclasses.fields(types.Data):
733741
if f.name in d_kwargs:
@@ -805,6 +813,7 @@ def put_data(
805813
sizes["nworld"] = nworld
806814
sizes["naconmax"] = naconmax
807815
sizes["njmax"] = njmax
816+
sizes["ntree"] = mjm.ntree
808817

809818
# ensure static geom positions are computed
810819
# TODO: remove once MjData creation semantics are fixed
@@ -882,6 +891,10 @@ def put_data(
882891
"ne_jnt": None,
883892
"ne_ten": None,
884893
"ne_flex": None,
894+
# sleep state: all trees start fully awake
895+
"tree_asleep": wp.array(np.full((nworld, mjm.ntree), -(1 + types.MJ_MINAWAKE)), dtype=int),
896+
"tree_awake": wp.array(np.ones((nworld, mjm.ntree)), dtype=int),
897+
"body_awake": wp.array(np.ones((nworld, mjm.nbody)), dtype=int),
885898
}
886899
for f in dataclasses.fields(types.Data):
887900
if f.name in d_kwargs:

mujoco_warp/_src/io_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,66 @@ def test_eq_active(self, active, make_data):
645645

646646
_assert_eq(d.eq_active.numpy()[0], mjd.eq_active, "eq_active")
647647

648+
@parameterized.parameters(True, False)
649+
def test_sleep_state_initial(self, use_make_data):
650+
"""Tests that make_data and put_data initialize all trees awake."""
651+
mjm = mujoco.MjModel.from_xml_string("""
652+
<mujoco>
653+
<worldbody>
654+
<body>
655+
<joint/>
656+
<geom size=".1"/>
657+
</body>
658+
</worldbody>
659+
</mujoco>
660+
""")
661+
mjd = mujoco.MjData(mjm)
662+
663+
if use_make_data:
664+
d = mjwarp.make_data(mjm)
665+
else:
666+
d = mjwarp.put_data(mjm, mjd)
667+
668+
# All trees should be awake (tree_asleep < 0)
669+
tree_asleep = d.tree_asleep.numpy()
670+
self.assertTrue((tree_asleep < 0).all(), "tree_asleep should be < 0 (awake)")
671+
# tree_awake should all be 1
672+
tree_awake = d.tree_awake.numpy()
673+
np.testing.assert_array_equal(tree_awake, 1, "tree_awake should be 1")
674+
# body_awake should all be 1
675+
body_awake = d.body_awake.numpy()
676+
np.testing.assert_array_equal(body_awake, 1, "body_awake should be 1")
677+
678+
def test_sleep_policy_import(self):
679+
"""Tests that tree_sleep_policy matches MuJoCo."""
680+
mjm = mujoco.MjModel.from_xml_string("""
681+
<mujoco>
682+
<worldbody>
683+
<body>
684+
<joint/>
685+
<geom size=".1"/>
686+
</body>
687+
</worldbody>
688+
</mujoco>
689+
""")
690+
m = mjwarp.put_model(mjm)
691+
np.testing.assert_array_equal(m.tree_sleep_policy.numpy(), mjm.tree_sleep_policy)
692+
693+
def test_dof_length_import(self):
694+
"""Tests that dof_length matches MuJoCo."""
695+
mjm = mujoco.MjModel.from_xml_string("""
696+
<mujoco>
697+
<worldbody>
698+
<body>
699+
<joint/>
700+
<geom size=".1"/>
701+
</body>
702+
</worldbody>
703+
</mujoco>
704+
""")
705+
m = mjwarp.put_model(mjm)
706+
np.testing.assert_allclose(m.dof_length.numpy(), mjm.dof_length)
707+
648708

649709
if __name__ == "__main__":
650710
wp.init()

mujoco_warp/_src/types.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
MJ_MAXIMP = mujoco.mjMAXIMP # maximum constraint impedance
2525
MJ_MAXCONPAIR = mujoco.mjMAXCONPAIR
2626
MJ_MINMU = mujoco.mjMINMU # minimum friction
27+
# TODO(team): set with mujoco.mjMINAWAKE after mjwarp depends
28+
# on mujoco > 3.4.0 in pyproject.toml
29+
MJ_MINAWAKE = 10 # minimum number of timesteps before sleeping
2730
# maximum size (by number of edges) of an horizon in EPA algorithm
2831
MJ_MAX_EPAHORIZON = 24
2932
# maximum average number of trianglarfaces EPA can insert at each iteration
@@ -174,14 +177,50 @@ class EnableBit(enum.IntFlag):
174177
ENERGY: energy computation
175178
INVDISCRETE: discrete-time inverse dynamics
176179
MULTICCD: multiple contacts with CCD
180+
SLEEP: sleeping
177181
"""
178182

179183
ENERGY = mujoco.mjtEnableBit.mjENBL_ENERGY
180184
INVDISCRETE = mujoco.mjtEnableBit.mjENBL_INVDISCRETE
181185
MULTICCD = mujoco.mjtEnableBit.mjENBL_MULTICCD
186+
SLEEP = mujoco.mjtEnableBit.mjENBL_SLEEP
182187
# unsupported: OVERRIDE, FWDINV, ISLAND
183188

184189

190+
class SleepPolicy(enum.IntEnum):
191+
"""Per-tree sleep policy.
192+
193+
Attributes:
194+
AUTO: compiler chooses sleep policy
195+
AUTO_NEVER: compiler sleep policy: never
196+
AUTO_ALLOWED: compiler sleep policy: allowed
197+
NEVER: user sleep policy: never
198+
ALLOWED: user sleep policy: allowed
199+
INIT: user sleep policy: initialized asleep
200+
"""
201+
202+
AUTO = mujoco.mjtSleepPolicy.mjSLEEP_AUTO
203+
AUTO_NEVER = mujoco.mjtSleepPolicy.mjSLEEP_AUTO_NEVER
204+
AUTO_ALLOWED = mujoco.mjtSleepPolicy.mjSLEEP_AUTO_ALLOWED
205+
NEVER = mujoco.mjtSleepPolicy.mjSLEEP_NEVER
206+
ALLOWED = mujoco.mjtSleepPolicy.mjSLEEP_ALLOWED
207+
INIT = mujoco.mjtSleepPolicy.mjSLEEP_INIT
208+
209+
210+
class SleepState(enum.IntEnum):
211+
"""Sleep state for bodies.
212+
213+
Attributes:
214+
ASLEEP: body is asleep
215+
AWAKE: body is awake
216+
STATIC: body is static (world body or mocap)
217+
"""
218+
219+
ASLEEP = 0
220+
AWAKE = 1
221+
STATIC = 2
222+
223+
185224
class TrnType(enum.IntEnum):
186225
"""Type of actuator transmission.
187226
@@ -647,6 +686,7 @@ class Option:
647686
tolerance: main solver tolerance
648687
ls_tolerance: CG/Newton linesearch tolerance
649688
ccd_tolerance: convex collision detection tolerance
689+
sleep_tolerance: sleep velocity tolerance
650690
density: density of medium
651691
viscosity: viscosity of medium
652692
gravity: gravitational acceleration
@@ -683,6 +723,7 @@ class Option:
683723
tolerance: array("*", float)
684724
ls_tolerance: array("*", float)
685725
ccd_tolerance: array("*", float)
726+
sleep_tolerance: float
686727
density: array("*", float)
687728
viscosity: array("*", float)
688729
gravity: array("*", wp.vec3)
@@ -835,6 +876,8 @@ class Model:
835876
dof_armature: dof armature inertia/mass (*, nv)
836877
dof_damping: damping coefficient (*, nv)
837878
dof_invweight0: diag. inverse inertia in qpos0 (*, nv)
879+
dof_length: dof length for weighting velocity norm (nv,)
880+
tree_sleep_policy: tree sleep policy (SleepPolicy) (ntree,)
838881
geom_type: geometric type (GeomType) (ngeom,)
839882
geom_contype: geom contact type (ngeom,)
840883
geom_conaffinity: geom contact affinity (ngeom,)
@@ -1186,6 +1229,8 @@ class Model:
11861229
dof_armature: array("*", "nv", float)
11871230
dof_damping: array("*", "nv", float)
11881231
dof_invweight0: array("*", "nv", float)
1232+
dof_length: array("nv", float)
1233+
tree_sleep_policy: array("ntree", int)
11891234
geom_type: array("ngeom", int)
11901235
geom_contype: array("ngeom", int)
11911236
geom_conaffinity: array("ngeom", int)
@@ -1562,6 +1607,9 @@ class Data:
15621607
nf: number of friction constraints (nworld,)
15631608
nl: number of limit constraints (nworld,)
15641609
nefc: number of constraints (nworld,)
1610+
ntree_awake: number of awake trees (nworld,)
1611+
nbody_awake: number of awake bodies (nworld,)
1612+
nv_awake: number of awake dofs (nworld,)
15651613
time: simulation time (nworld,)
15661614
energy: potential, kinetic energy (nworld, 2)
15671615
qpos: position (nworld, nq)
@@ -1577,6 +1625,7 @@ class Data:
15771625
qacc: acceleration (nworld, nv)
15781626
act_dot: time-derivative of actuator activation (nworld, na)
15791627
sensordata: sensor data array (nworld, nsensordata,)
1628+
tree_asleep: tree asleep counter; >=0: asleep cycle (nworld, ntree)
15801629
xpos: Cartesian position of body frame (nworld, nbody, 3)
15811630
xquat: Cartesian orientation of body frame (nworld, nbody, 4)
15821631
xmat: Cartesian orientation of body frame (nworld, nbody, 3, 3)
@@ -1612,6 +1661,10 @@ class Data:
16121661
qLD: L'*D*L factorization of M (nworld, nv, nv) if dense
16131662
(nworld, 1, nC) if sparse
16141663
qLDiagInv: 1/diag(D) (nworld, nv)
1664+
tree_awake: is tree awake; 0: asleep; 1: awake (nworld, ntree)
1665+
body_awake: body sleep state (SleepState) (nworld, nbody)
1666+
body_awake_ind: indices of awake/static bodies (nworld, nbody)
1667+
dof_awake_ind: indices of awake dofs (nworld, nv)
16151668
flexedge_velocity: flex edge velocities (nworld, nflexedge)
16161669
ten_velocity: tendon velocities (nworld, ntendon)
16171670
actuator_velocity: actuator velocities (nworld, nu)
@@ -1660,6 +1713,9 @@ class Data:
16601713
nf: array("nworld", int)
16611714
nl: array("nworld", int)
16621715
nefc: array("nworld", int)
1716+
ntree_awake: array("nworld", int)
1717+
nbody_awake: array("nworld", int)
1718+
nv_awake: array("nworld", int)
16631719
time: array("nworld", float)
16641720
energy: array("nworld", wp.vec2)
16651721
qpos: array("nworld", "nq", float)
@@ -1675,6 +1731,7 @@ class Data:
16751731
qacc: array("nworld", "nv", float)
16761732
act_dot: array("nworld", "na", float)
16771733
sensordata: array("nworld", "nsensordata", float)
1734+
tree_asleep: array("nworld", "ntree", int)
16781735
xpos: array("nworld", "nbody", wp.vec3)
16791736
xquat: array("nworld", "nbody", wp.quat)
16801737
xmat: array("nworld", "nbody", wp.mat33)
@@ -1708,6 +1765,10 @@ class Data:
17081765
qM: wp.array3d(dtype=float)
17091766
qLD: wp.array3d(dtype=float)
17101767
qLDiagInv: array("nworld", "nv", float)
1768+
tree_awake: array("nworld", "ntree", int)
1769+
body_awake: array("nworld", "nbody", int)
1770+
body_awake_ind: array("nworld", "nbody", int)
1771+
dof_awake_ind: array("nworld", "nv", int)
17111772
flexedge_velocity: array("nworld", "nflexedge", float)
17121773
ten_velocity: array("nworld", "ntendon", float)
17131774
actuator_velocity: array("nworld", "nu", float)

0 commit comments

Comments
 (0)