Skip to content

Commit 1686d14

Browse files
committed
Add sleep state fields
- Add MJ_MINAWAKE constant, SLEEP enable bit, SleepPolicy and SleepState enums - Add sleep_tolerance to Option, tree_sleep_policy/dof_length to Model - Add tree_asleep, tree_awake, body_awake, dof_awake_ind, body_awake_ind, nv_awake, nbody_awake, ntree_awake to Data - Initialize sleep state in make_data and put_data (all trees start awake) - Add tests for sleep state initialization, sleep policy and dof_length import
1 parent 70533bf commit 1686d14

3 files changed

Lines changed: 137 additions & 1 deletion

File tree

mujoco_warp/_src/io.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,10 @@ def _check_margin(name, t1, t2, margin):
643643
m.flexedge_J_colind = mjm.flexedge_J_colind.reshape(-1)
644644

645645
# place m on device
646-
sizes = dict({"*": 1}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int})
646+
# TODO(team): remove ntree once field is added to types.Model
647+
sizes = dict(
648+
{"*": 1, "ntree": mjm.ntree}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int}
649+
)
647650
for f in dataclasses.fields(types.Model):
648651
if _is_array_spec(f.type):
649652
setattr(m, f.name, _create_array(getattr(m, f.name), f.type, sizes))
@@ -929,6 +932,7 @@ def make_data(
929932
sizes["nworld"] = nworld
930933
sizes["naconmax"] = naconmax
931934
sizes["njmax"] = njmax
935+
sizes["ntree"] = mjm.ntree
932936

933937
if njmax_nnz is None:
934938
if is_sparse(mjm):
@@ -995,6 +999,10 @@ def make_data(
995999
# island arrays
9961000
"nisland": None,
9971001
"tree_island": None,
1002+
# sleep state: all trees start fully awake
1003+
"tree_asleep": wp.array(np.full((nworld, mjm.ntree), -(1 + types.MJ_MINAWAKE)), dtype=int),
1004+
"tree_awake": wp.array(np.ones((nworld, mjm.ntree)), dtype=int),
1005+
"body_awake": wp.array(np.ones((nworld, mjm.nbody)), dtype=int),
9981006
}
9991007
for f in dataclasses.fields(types.Data):
10001008
if f.name in d_kwargs:
@@ -1101,6 +1109,7 @@ def put_data(
11011109
sizes["nworld"] = nworld
11021110
sizes["naconmax"] = naconmax
11031111
sizes["njmax"] = njmax
1112+
sizes["ntree"] = mjm.ntree
11041113

11051114
if njmax_nnz is None:
11061115
if is_sparse(mjm):
@@ -1210,6 +1219,10 @@ def put_data(
12101219
# island arrays
12111220
"nisland": None,
12121221
"tree_island": None,
1222+
# sleep state: all trees start fully awake
1223+
"tree_asleep": wp.array(np.full((nworld, mjm.ntree), -(1 + types.MJ_MINAWAKE)), dtype=int),
1224+
"tree_awake": wp.array(np.ones((nworld, mjm.ntree)), dtype=int),
1225+
"body_awake": wp.array(np.ones((nworld, mjm.nbody)), dtype=int),
12131226
}
12141227
for f in dataclasses.fields(types.Data):
12151228
if f.name in d_kwargs:

mujoco_warp/_src/io_test.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2423,6 +2423,68 @@ def test_margin_pair_box_box(self):
24232423
""")
24242424
)
24252425

2426+
# --- sleep state tests (from incoming branch) ---
2427+
2428+
@parameterized.parameters(True, False)
2429+
def test_sleep_state_initial(self, use_make_data):
2430+
"""Tests that make_data and put_data initialize all trees awake."""
2431+
mjm = mujoco.MjModel.from_xml_string("""
2432+
<mujoco>
2433+
<worldbody>
2434+
<body>
2435+
<joint/>
2436+
<geom size=".1"/>
2437+
</body>
2438+
</worldbody>
2439+
</mujoco>
2440+
""")
2441+
mjd = mujoco.MjData(mjm)
2442+
2443+
if use_make_data:
2444+
d = mjwarp.make_data(mjm)
2445+
else:
2446+
d = mjwarp.put_data(mjm, mjd)
2447+
2448+
# All trees should be awake (tree_asleep < 0)
2449+
tree_asleep = d.tree_asleep.numpy()
2450+
self.assertTrue((tree_asleep < 0).all(), "tree_asleep should be < 0 (awake)")
2451+
# tree_awake should all be 1
2452+
tree_awake = d.tree_awake.numpy()
2453+
np.testing.assert_array_equal(tree_awake, 1, "tree_awake should be 1")
2454+
# body_awake should all be 1
2455+
body_awake = d.body_awake.numpy()
2456+
np.testing.assert_array_equal(body_awake, 1, "body_awake should be 1")
2457+
2458+
def test_sleep_policy_import(self):
2459+
"""Tests that tree_sleep_policy matches MuJoCo."""
2460+
mjm = mujoco.MjModel.from_xml_string("""
2461+
<mujoco>
2462+
<worldbody>
2463+
<body>
2464+
<joint/>
2465+
<geom size=".1"/>
2466+
</body>
2467+
</worldbody>
2468+
</mujoco>
2469+
""")
2470+
m = mjwarp.put_model(mjm)
2471+
np.testing.assert_array_equal(m.tree_sleep_policy.numpy(), mjm.tree_sleep_policy)
2472+
2473+
def test_dof_length_import(self):
2474+
"""Tests that dof_length matches MuJoCo."""
2475+
mjm = mujoco.MjModel.from_xml_string("""
2476+
<mujoco>
2477+
<worldbody>
2478+
<body>
2479+
<joint/>
2480+
<geom size=".1"/>
2481+
</body>
2482+
</worldbody>
2483+
</mujoco>
2484+
""")
2485+
m = mjwarp.put_model(mjm)
2486+
np.testing.assert_allclose(m.dof_length.numpy(), mjm.dof_length)
2487+
24262488

24272489
# TODO(team): test set_const_0 sparse
24282490

mujoco_warp/_src/types.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
MJ_MAXIMP = mujoco.mjMAXIMP # maximum constraint impedance
2727
MJ_MAXCONPAIR = mujoco.mjMAXCONPAIR
2828
MJ_MINMU = mujoco.mjMINMU # minimum friction
29+
# TODO(team): set with mujoco.mjMINAWAKE after mjwarp depends
30+
# on mujoco > 3.4.0 in pyproject.toml
31+
MJ_MINAWAKE = 10 # minimum number of timesteps before sleeping
2932
# maximum size (by number of edges) of an horizon in EPA algorithm
3033
MJ_MAX_EPAHORIZON = 24
3134
# maximum average number of trianglarfaces EPA can insert at each iteration
@@ -213,14 +216,50 @@ class EnableBit(enum.IntFlag):
213216
ENERGY: energy computation
214217
INVDISCRETE: discrete-time inverse dynamics
215218
MULTICCD: multiple contacts with CCD
219+
SLEEP: sleeping
216220
"""
217221

218222
ENERGY = mujoco.mjtEnableBit.mjENBL_ENERGY
219223
INVDISCRETE = mujoco.mjtEnableBit.mjENBL_INVDISCRETE
220224
MULTICCD = mujoco.mjtEnableBit.mjENBL_MULTICCD
225+
SLEEP = mujoco.mjtEnableBit.mjENBL_SLEEP
221226
# unsupported: OVERRIDE, FWDINV, ISLAND
222227

223228

229+
class SleepPolicy(enum.IntEnum):
230+
"""Per-tree sleep policy.
231+
232+
Attributes:
233+
AUTO: compiler chooses sleep policy
234+
AUTO_NEVER: compiler sleep policy: never
235+
AUTO_ALLOWED: compiler sleep policy: allowed
236+
NEVER: user sleep policy: never
237+
ALLOWED: user sleep policy: allowed
238+
INIT: user sleep policy: initialized asleep
239+
"""
240+
241+
AUTO = mujoco.mjtSleepPolicy.mjSLEEP_AUTO
242+
AUTO_NEVER = mujoco.mjtSleepPolicy.mjSLEEP_AUTO_NEVER
243+
AUTO_ALLOWED = mujoco.mjtSleepPolicy.mjSLEEP_AUTO_ALLOWED
244+
NEVER = mujoco.mjtSleepPolicy.mjSLEEP_NEVER
245+
ALLOWED = mujoco.mjtSleepPolicy.mjSLEEP_ALLOWED
246+
INIT = mujoco.mjtSleepPolicy.mjSLEEP_INIT
247+
248+
249+
class SleepState(enum.IntEnum):
250+
"""Sleep state for bodies.
251+
252+
Attributes:
253+
ASLEEP: body is asleep
254+
AWAKE: body is awake
255+
STATIC: body is static (world body or mocap)
256+
"""
257+
258+
ASLEEP = 0
259+
AWAKE = 1
260+
STATIC = 2
261+
262+
224263
class TrnType(enum.IntEnum):
225264
"""Type of actuator transmission.
226265
@@ -722,6 +761,7 @@ class Option:
722761
tolerance: main solver tolerance
723762
ls_tolerance: CG/Newton linesearch tolerance
724763
ccd_tolerance: convex collision detection tolerance
764+
sleep_tolerance: sleep velocity tolerance
725765
gravity: gravitational acceleration
726766
wind: wind (for lift, drag, and viscosity)
727767
magnetic: global magnetic flux
@@ -756,6 +796,7 @@ class Option:
756796
tolerance: array("*", float)
757797
ls_tolerance: array("*", float)
758798
ccd_tolerance: array("*", float)
799+
sleep_tolerance: float
759800
gravity: array("*", wp.vec3)
760801
wind: array("*", wp.vec3)
761802
magnetic: array("*", wp.vec3)
@@ -947,9 +988,11 @@ class Model:
947988
dof_damping: damping coefficient (*, nv)
948989
dof_dampingpoly: high-order damping coefficients (*, nv, 2)
949990
dof_invweight0: diag. inverse inertia in qpos0 (*, nv)
991+
dof_length: dof length for weighting velocity norm (nv,)
950992
tree_bodynum: number of bodies in tree (incl. root) (ntree,)
951993
tree_dofadr: start address of tree's dofs (ntree,)
952994
tree_dofnum: number of dofs in tree (ntree,)
995+
tree_sleep_policy: tree sleep policy (SleepPolicy) (ntree,)
953996
geom_type: geometric type (GeomType) (ngeom,)
954997
geom_contype: geom contact type (ngeom,)
955998
geom_conaffinity: geom contact affinity (ngeom,)
@@ -1340,9 +1383,11 @@ class Model:
13401383
dof_damping: array("*", "nv", float)
13411384
dof_dampingpoly: array("*", "nv", wp.vec2)
13421385
dof_invweight0: array("*", "nv", float)
1386+
dof_length: array("nv", float)
13431387
tree_bodynum: array("ntree", int)
13441388
tree_dofadr: array("ntree", int)
13451389
tree_dofnum: array("ntree", int)
1390+
tree_sleep_policy: array("ntree", int)
13461391
geom_type: array("ngeom", int)
13471392
geom_contype: array("ngeom", int)
13481393
geom_conaffinity: array("ngeom", int)
@@ -1728,6 +1773,9 @@ class Data:
17281773
nl: number of limit constraints (nworld,)
17291774
nefc: number of constraints (nworld,)
17301775
nisland: number of constraint islands (nworld,)
1776+
ntree_awake: number of awake trees (nworld,)
1777+
nbody_awake: number of awake bodies (nworld,)
1778+
nv_awake: number of awake dofs (nworld,)
17311779
time: simulation time (nworld,)
17321780
energy: potential, kinetic energy (nworld, 2)
17331781
qpos: position (nworld, nq)
@@ -1743,6 +1791,7 @@ class Data:
17431791
qacc: acceleration (nworld, nv)
17441792
act_dot: time-derivative of actuator activation (nworld, na)
17451793
sensordata: sensor data array (nworld, nsensordata,)
1794+
tree_asleep: tree asleep counter; >=0: asleep cycle (nworld, ntree)
17461795
xpos: Cartesian position of body frame (nworld, nbody, 3)
17471796
xquat: Cartesian orientation of body frame (nworld, nbody, 4)
17481797
xmat: Cartesian orientation of body frame (nworld, nbody, 3, 3)
@@ -1781,6 +1830,10 @@ class Data:
17811830
qLD: L'*D*L factorization of M (nworld, nv, nv) if dense
17821831
(nworld, 1, nC) if sparse
17831832
qLDiagInv: 1/diag(D) (nworld, nv)
1833+
tree_awake: is tree awake; 0: asleep; 1: awake (nworld, ntree)
1834+
body_awake: body sleep state (SleepState) (nworld, nbody)
1835+
body_awake_ind: indices of awake/static bodies (nworld, nbody)
1836+
dof_awake_ind: indices of awake dofs (nworld, nv)
17841837
flexedge_velocity: flex edge velocities (nworld, nflexedge)
17851838
ten_velocity: tendon velocities (nworld, ntendon)
17861839
actuator_velocity: actuator velocities (nworld, nu)
@@ -1826,6 +1879,9 @@ class Data:
18261879
nl: array("nworld", int)
18271880
nefc: array("nworld", int)
18281881
nisland: array("nworld", int)
1882+
ntree_awake: array("nworld", int)
1883+
nbody_awake: array("nworld", int)
1884+
nv_awake: array("nworld", int)
18291885
time: array("nworld", float)
18301886
energy: array("nworld", wp.vec2)
18311887
qpos: array("nworld", "nq", float)
@@ -1841,6 +1897,7 @@ class Data:
18411897
qacc: array("nworld", "nv", float)
18421898
act_dot: array("nworld", "na", float)
18431899
sensordata: array("nworld", "nsensordata", float)
1900+
tree_asleep: array("nworld", "ntree", int)
18441901
xpos: array("nworld", "nbody", wp.vec3)
18451902
xquat: array("nworld", "nbody", wp.quat)
18461903
xmat: array("nworld", "nbody", wp.mat33)
@@ -1877,6 +1934,10 @@ class Data:
18771934
qM: wp.array3d[float]
18781935
qLD: wp.array3d[float]
18791936
qLDiagInv: array("nworld", "nv", float)
1937+
tree_awake: array("nworld", "ntree", int)
1938+
body_awake: array("nworld", "nbody", int)
1939+
body_awake_ind: array("nworld", "nbody", int)
1940+
dof_awake_ind: array("nworld", "nv", int)
18801941
flexedge_velocity: array("nworld", "nflexedge", float)
18811942
ten_velocity: array("nworld", "ntendon", float)
18821943
actuator_velocity: array("nworld", "nu", float)

0 commit comments

Comments
 (0)