diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index f1da96088..2f7fe6228 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -37,6 +37,7 @@ from mujoco_warp._src.collision_sdf import sdf_narrowphase as sdf_narrowphase from mujoco_warp._src.constraint import make_constraint as make_constraint from mujoco_warp._src.derivative import deriv_smooth_vel as deriv_smooth_vel +from mujoco_warp._src.derivative import transition_fd as transition_fd from mujoco_warp._src.forward import euler as euler from mujoco_warp._src.forward import forward as forward from mujoco_warp._src.forward import fwd_acceleration as fwd_acceleration diff --git a/mujoco_warp/_src/derivative.py b/mujoco_warp/_src/derivative.py index cd515e3e6..f78b5106a 100644 --- a/mujoco_warp/_src/derivative.py +++ b/mujoco_warp/_src/derivative.py @@ -15,6 +15,8 @@ import warp as wp +from mujoco_warp._src import forward +from mujoco_warp._src import math from mujoco_warp._src import util_misc from mujoco_warp._src.support import next_act from mujoco_warp._src.types import MJ_MINVAL @@ -23,6 +25,7 @@ from mujoco_warp._src.types import DisableBit from mujoco_warp._src.types import DynType from mujoco_warp._src.types import GainType +from mujoco_warp._src.types import JointType from mujoco_warp._src.types import Model from mujoco_warp._src.types import vec10f from mujoco_warp._src.warp_util import event_scope @@ -452,3 +455,540 @@ def deriv_smooth_vel(m: Model, d: Data, out: wp.array2d[float]): ) # TODO(team): rne derivative + + +@wp.kernel +def _get_state( + # Model: + nq: int, + nv: int, + na: int, + # Data in: + qpos_in: wp.array2d[float], + qvel_in: wp.array2d[float], + act_in: wp.array2d[float], + # Out: + state_out: wp.array2d[float], +): + # get state = [qpos, qvel, act] + worldid = wp.tid() + for i in range(nq): + state_out[worldid, i] = qpos_in[worldid, i] + if i < nv: + state_out[worldid, nq + i] = qvel_in[worldid, i] + for i in range(na): + state_out[worldid, nq + nv + i] = act_in[worldid, i] + + +@wp.kernel +def _set_state( + # Model: + nq: int, + nv: int, + na: int, + # In: + state_in: wp.array2d[float], + # Data out: + qpos_out: wp.array2d[float], + qvel_out: wp.array2d[float], + act_out: wp.array2d[float], +): + # set state = [qpos, qvel, act] + worldid = wp.tid() + for i in range(nq): + qpos_out[worldid, i] = state_in[worldid, i] + if i < nv: + qvel_out[worldid, i] = state_in[worldid, nq + i] + for i in range(na): + act_out[worldid, i] = state_in[worldid, nq + nv + i] + + +@wp.kernel +def _state_diff_to_col( + # Model: + nq: int, + nv: int, + na: int, + njnt: int, + jnt_type: wp.array[int], + jnt_qposadr: wp.array[int], + jnt_dofadr: wp.array[int], + # In: + state1_in: wp.array2d[float], + state2_in: wp.array2d[float], + inv_h: float, + col_idx: int, + # Out: + jac_out: wp.array3d[float], +): + # finite difference two state vectors and write to Jacobian column + worldid = wp.tid() + + # position difference via joint type + for jntid in range(njnt): + jnttype = jnt_type[jntid] + qpos_adr = jnt_qposadr[jntid] + dof_adr = jnt_dofadr[jntid] + + if jnttype == JointType.FREE: + # linear position difference + jac_out[worldid, dof_adr + 0, col_idx] = (state2_in[worldid, qpos_adr + 0] - state1_in[worldid, qpos_adr + 0]) * inv_h + jac_out[worldid, dof_adr + 1, col_idx] = (state2_in[worldid, qpos_adr + 1] - state1_in[worldid, qpos_adr + 1]) * inv_h + jac_out[worldid, dof_adr + 2, col_idx] = (state2_in[worldid, qpos_adr + 2] - state1_in[worldid, qpos_adr + 2]) * inv_h + # quaternion difference + q1 = wp.quat( + state1_in[worldid, qpos_adr + 3], + state1_in[worldid, qpos_adr + 4], + state1_in[worldid, qpos_adr + 5], + state1_in[worldid, qpos_adr + 6], + ) + q2 = wp.quat( + state2_in[worldid, qpos_adr + 3], + state2_in[worldid, qpos_adr + 4], + state2_in[worldid, qpos_adr + 5], + state2_in[worldid, qpos_adr + 6], + ) + dq = math.quat_sub(q2, q1) + jac_out[worldid, dof_adr + 3, col_idx] = dq[0] * inv_h + jac_out[worldid, dof_adr + 4, col_idx] = dq[1] * inv_h + jac_out[worldid, dof_adr + 5, col_idx] = dq[2] * inv_h + elif jnttype == JointType.BALL: + q1 = wp.quat( + state1_in[worldid, qpos_adr + 0], + state1_in[worldid, qpos_adr + 1], + state1_in[worldid, qpos_adr + 2], + state1_in[worldid, qpos_adr + 3], + ) + q2 = wp.quat( + state2_in[worldid, qpos_adr + 0], + state2_in[worldid, qpos_adr + 1], + state2_in[worldid, qpos_adr + 2], + state2_in[worldid, qpos_adr + 3], + ) + dq = math.quat_sub(q2, q1) + jac_out[worldid, dof_adr + 0, col_idx] = dq[0] * inv_h + jac_out[worldid, dof_adr + 1, col_idx] = dq[1] * inv_h + jac_out[worldid, dof_adr + 2, col_idx] = dq[2] * inv_h + else: # SLIDE, HINGE + jac_out[worldid, dof_adr, col_idx] = (state2_in[worldid, qpos_adr] - state1_in[worldid, qpos_adr]) * inv_h + + # velocity and activation difference + for i in range(nv): + jac_out[worldid, nv + i, col_idx] = (state2_in[worldid, nq + i] - state1_in[worldid, nq + i]) * inv_h + for i in range(na): + jac_out[worldid, 2 * nv + i, col_idx] = (state2_in[worldid, nq + nv + i] - state1_in[worldid, nq + nv + i]) * inv_h + + +@wp.kernel +def _perturb_position( + # Model: + nq: int, + njnt: int, + jnt_type: wp.array[int], + jnt_qposadr: wp.array[int], + jnt_dofadr: wp.array[int], + # Data in: + qpos_in: wp.array2d[float], + # In: + dof_idx: int, + eps: float, + # Data out: + qpos_out: wp.array2d[float], +): + worldid = wp.tid() + + # copy qpos_in to qpos_out + for i in range(nq): + qpos_out[worldid, i] = qpos_in[worldid, i] + + # find joint for this dof and perturb + for jntid in range(njnt): + jnttype = jnt_type[jntid] + qpos_adr = jnt_qposadr[jntid] + dof_adr = jnt_dofadr[jntid] + + if jnttype == JointType.FREE: + if dof_idx >= dof_adr and dof_idx < dof_adr + 3: + qpos_out[worldid, qpos_adr + (dof_idx - dof_adr)] += eps + elif dof_idx >= dof_adr + 3 and dof_idx < dof_adr + 6: + q = wp.quat( + qpos_in[worldid, qpos_adr + 3], + qpos_in[worldid, qpos_adr + 4], + qpos_in[worldid, qpos_adr + 5], + qpos_in[worldid, qpos_adr + 6], + ) + local_idx = dof_idx - dof_adr - 3 + if local_idx == 0: + v = wp.vec3(1.0, 0.0, 0.0) + elif local_idx == 1: + v = wp.vec3(0.0, 1.0, 0.0) + else: + v = wp.vec3(0.0, 0.0, 1.0) + q_new = math.quat_integrate(q, v, eps) + qpos_out[worldid, qpos_adr + 3] = q_new[0] + qpos_out[worldid, qpos_adr + 4] = q_new[1] + qpos_out[worldid, qpos_adr + 5] = q_new[2] + qpos_out[worldid, qpos_adr + 6] = q_new[3] + elif jnttype == JointType.BALL: + if dof_idx >= dof_adr and dof_idx < dof_adr + 3: + q = wp.quat( + qpos_in[worldid, qpos_adr + 0], + qpos_in[worldid, qpos_adr + 1], + qpos_in[worldid, qpos_adr + 2], + qpos_in[worldid, qpos_adr + 3], + ) + local_idx = dof_idx - dof_adr + if local_idx == 0: + v = wp.vec3(1.0, 0.0, 0.0) + elif local_idx == 1: + v = wp.vec3(0.0, 1.0, 0.0) + else: + v = wp.vec3(0.0, 0.0, 1.0) + q_new = math.quat_integrate(q, v, eps) + qpos_out[worldid, qpos_adr + 0] = q_new[0] + qpos_out[worldid, qpos_adr + 1] = q_new[1] + qpos_out[worldid, qpos_adr + 2] = q_new[2] + qpos_out[worldid, qpos_adr + 3] = q_new[3] + else: # SLIDE, HINGE + if dof_idx == dof_adr: + qpos_out[worldid, qpos_adr] += eps + + +@wp.kernel +def _perturb_array( + # In: + idx: int, + eps: float, + arr_in: wp.array2d[float], + # Out: + arr_out: wp.array2d[float], +): + worldid = wp.tid() + for i in range(arr_in.shape[1]): + if i == idx: + arr_out[worldid, i] = arr_in[worldid, i] + eps + else: + arr_out[worldid, i] = arr_in[worldid, i] + + +@wp.kernel +def _diff_vectors_to_col( + # In: + x1_in: wp.array2d[float], + x2_in: wp.array2d[float], + inv_h: float, + n: int, + col_idx: int, + # Out: + jac_out: wp.array3d[float], +): + # dx = (x2 - x1) / h, written to Jacobian column + worldid = wp.tid() + for i in range(n): + jac_out[worldid, i, col_idx] = (x2_in[worldid, i] - x1_in[worldid, i]) * inv_h + + +@event_scope +def transition_fd( + m: Model, + d: Data, + eps: float, + centered: bool = False, + A: wp.array3d[float] = None, + B: wp.array3d[float] = None, + C: wp.array3d[float] = None, + D: wp.array3d[float] = None, +): + """Finite differenced transition matrices (control theory notation). + + Computes: d(x_next) = A*dx + B*du, d(sensor) = C*dx + D*du + where x = [qvel_diff, qvel, act] is the state in tangent space. + + Args: + m: model + d: data + eps: finite difference epsilon + centered: if True, use centered differences + A: output state transition matrix (nworld, ndx, ndx) where ndx = 2*nv+na + B: output control transition matrix (nworld, ndx, nu) + C: output state observation matrix (nworld, nsensordata, ndx) + D: output control observation matrix (nworld, nsensordata, nu) + """ + # TODO(team): add option for scratch memory + + nq, nv, na, nu = m.nq, m.nv, m.na, m.nu + ns = m.nsensordata + ndx = 2 * nv + na + nworld = d.nworld + + # skip sensor computations if not requested + skip_sensor = C is None and D is None + + # save current state + state_size = nq + nv + na + state0 = wp.empty((nworld, state_size), dtype=float) + ctrl0 = wp.empty((nworld, nu), dtype=float) if nu > 0 else None + warmstart0 = wp.empty((nworld, nv), dtype=float) + time0 = wp.empty(nworld, dtype=float) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[state0]) + if nu > 0: + wp.copy(ctrl0, d.ctrl) + wp.copy(warmstart0, d.qacc_warmstart) + wp.copy(time0, d.time) + + # baseline step + forward.step(m, d) + + # save baseline next state and sensors + next_state = wp.empty((nworld, state_size), dtype=float) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_state]) + sensor0 = None + if not skip_sensor: + sensor0 = wp.empty((nworld, ns), dtype=float) + wp.copy(sensor0, d.sensordata) + + # restore state + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + if nu > 0: + wp.copy(d.ctrl, ctrl0) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) + + # allocate work arrays + next_plus = wp.empty((nworld, state_size), dtype=float) + next_minus = wp.empty((nworld, state_size), dtype=float) if centered else None + sensor_plus = wp.empty((nworld, ns), dtype=float) if not skip_sensor else None + sensor_minus = wp.empty((nworld, ns), dtype=float) if not skip_sensor and centered else None + + inv_eps = 1.0 / eps + inv_2eps = 1.0 / (2.0 * eps) if centered else inv_eps + + # finite difference controls + if (B is not None or D is not None) and nu > 0: + ctrl_temp = wp.empty((nworld, nu), dtype=float) + for i in range(nu): + # nudge forward + wp.launch(_perturb_array, dim=nworld, inputs=[i, eps, ctrl0], outputs=[ctrl_temp]) + wp.copy(d.ctrl, ctrl_temp) + forward.step(m, d) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_plus]) + if not skip_sensor: + wp.copy(sensor_plus, d.sensordata) + + # restore + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + wp.copy(d.ctrl, ctrl0) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) + + if centered: + wp.launch(_perturb_array, dim=nworld, inputs=[i, -eps, ctrl0], outputs=[ctrl_temp]) + wp.copy(d.ctrl, ctrl_temp) + forward.step(m, d) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_minus]) + if not skip_sensor: + wp.copy(sensor_minus, d.sensordata) + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + wp.copy(d.ctrl, ctrl0) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) + + # compute derivatives + if B is not None: + if centered: + wp.launch( + _state_diff_to_col, + dim=nworld, + inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_minus, next_plus, inv_2eps, i], + outputs=[B], + ) + else: + wp.launch( + _state_diff_to_col, + dim=nworld, + inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_state, next_plus, inv_eps, i], + outputs=[B], + ) + + if D is not None: + if centered: + wp.launch(_diff_vectors_to_col, dim=nworld, inputs=[sensor_plus, sensor_minus, inv_2eps, ns, i], outputs=[D]) + else: + wp.launch(_diff_vectors_to_col, dim=nworld, inputs=[sensor0, sensor_plus, inv_eps, ns, i], outputs=[D]) + + # finite difference activations + if (A is not None or C is not None) and na > 0: + act0 = wp.empty((nworld, na), dtype=float) + wp.copy(act0, d.act) + act_temp = wp.empty((nworld, na), dtype=float) + for i in range(na): + # nudge forward + wp.launch(_perturb_array, dim=nworld, inputs=[i, eps, act0], outputs=[act_temp]) + wp.copy(d.act, act_temp) + forward.step(m, d) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_plus]) + if not skip_sensor: + wp.copy(sensor_plus, d.sensordata) + + # restore + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) + + if centered: + wp.launch(_perturb_array, dim=nworld, inputs=[i, -eps, act0], outputs=[act_temp]) + wp.copy(d.act, act_temp) + forward.step(m, d) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_minus]) + if not skip_sensor: + wp.copy(sensor_minus, d.sensordata) + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) + + # compute derivatives + col_idx = 2 * nv + i + if A is not None: + if centered: + wp.launch( + _state_diff_to_col, + dim=nworld, + inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_minus, next_plus, inv_2eps, col_idx], + outputs=[A], + ) + else: + wp.launch( + _state_diff_to_col, + dim=nworld, + inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_state, next_plus, inv_eps, col_idx], + outputs=[A], + ) + + if C is not None: + if centered: + wp.launch(_diff_vectors_to_col, dim=nworld, inputs=[sensor_minus, sensor_plus, inv_2eps, ns, col_idx], outputs=[C]) + else: + wp.launch(_diff_vectors_to_col, dim=nworld, inputs=[sensor0, sensor_plus, inv_eps, ns, col_idx], outputs=[C]) + + # finite difference velocities + if A is not None or C is not None: + qvel0 = wp.empty((nworld, nv), dtype=float) + wp.copy(qvel0, d.qvel) + qvel_temp = wp.empty((nworld, nv), dtype=float) + for i in range(nv): + # nudge forward + wp.launch(_perturb_array, dim=nworld, inputs=[i, eps, qvel0], outputs=[qvel_temp]) + wp.copy(d.qvel, qvel_temp) + forward.step(m, d) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_plus]) + if not skip_sensor: + wp.copy(sensor_plus, d.sensordata) + + # restore + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) + + if centered: + wp.launch(_perturb_array, dim=nworld, inputs=[i, -eps, qvel0], outputs=[qvel_temp]) + wp.copy(d.qvel, qvel_temp) + forward.step(m, d) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_minus]) + if not skip_sensor: + wp.copy(sensor_minus, d.sensordata) + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) + + # compute derivatives + col_idx = nv + i + if A is not None: + if centered: + wp.launch( + _state_diff_to_col, + dim=nworld, + inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_minus, next_plus, inv_2eps, col_idx], + outputs=[A], + ) + else: + wp.launch( + _state_diff_to_col, + dim=nworld, + inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_state, next_plus, inv_eps, col_idx], + outputs=[A], + ) + + if C is not None: + if centered: + wp.launch(_diff_vectors_to_col, dim=nworld, inputs=[sensor_minus, sensor_plus, inv_2eps, ns, col_idx], outputs=[C]) + else: + wp.launch(_diff_vectors_to_col, dim=nworld, inputs=[sensor0, sensor_plus, inv_eps, ns, col_idx], outputs=[C]) + + # finite difference positions + if A is not None or C is not None: + qpos_perturbed = wp.empty((nworld, nq), dtype=float) + for i in range(nv): + # nudge position forward + wp.launch( + _perturb_position, + dim=nworld, + inputs=[nq, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, d.qpos, i, eps], + outputs=[qpos_perturbed], + ) + wp.copy(d.qpos, qpos_perturbed) + forward.step(m, d) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_plus]) + if not skip_sensor: + wp.copy(sensor_plus, d.sensordata) + + # restore + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) + + if centered: + wp.launch( + _perturb_position, + dim=nworld, + inputs=[nq, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, d.qpos, i, -eps], + outputs=[qpos_perturbed], + ) + wp.copy(d.qpos, qpos_perturbed) + forward.step(m, d) + wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_minus]) + if not skip_sensor: + wp.copy(sensor_minus, d.sensordata) + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) + + # compute derivatives + col_idx = i + if A is not None: + if centered: + wp.launch( + _state_diff_to_col, + dim=nworld, + inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_minus, next_plus, inv_2eps, col_idx], + outputs=[A], + ) + else: + wp.launch( + _state_diff_to_col, + dim=nworld, + inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_state, next_plus, inv_eps, col_idx], + outputs=[A], + ) + + if C is not None: + if centered: + wp.launch(_diff_vectors_to_col, dim=nworld, inputs=[sensor_minus, sensor_plus, inv_2eps, ns, col_idx], outputs=[C]) + else: + wp.launch(_diff_vectors_to_col, dim=nworld, inputs=[sensor0, sensor_plus, inv_eps, ns, col_idx], outputs=[C]) + + # restore final state + wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act]) + if nu > 0: + wp.copy(d.ctrl, ctrl0) + wp.copy(d.qacc_warmstart, warmstart0) + wp.copy(d.time, time0) diff --git a/mujoco_warp/_src/derivative_test.py b/mujoco_warp/_src/derivative_test.py index cc4746d29..ff6afbf10 100644 --- a/mujoco_warp/_src/derivative_test.py +++ b/mujoco_warp/_src/derivative_test.py @@ -534,6 +534,386 @@ def test_forcerange_clamped_derivative(self): "implicitfast should be more accurate than Euler at large timestep when forcerange derivatives are correctly handled", ) + @parameterized.parameters(False, True) + def test_transition_fd_linear_system(self, centered): + """Tests A and B matrices match MuJoCo mjd_transitionFD.""" + # simple linear system with 3 slide joints + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + + + + """, + keyframe=0, + ) + + # larger eps needed for float32 precision + eps = 1e-3 + ndx = 2 * mjm.nv + mjm.na + + # mujoco reference + A_mj = np.zeros((ndx, ndx)) + B_mj = np.zeros((ndx, mjm.nu)) + mujoco.mjd_transitionFD(mjm, mjd, eps, centered, A_mj, B_mj, None, None) + + # mujoco warp + A_mjw = wp.zeros((1, ndx, ndx), dtype=float) + B_mjw = wp.zeros((1, ndx, mjm.nu), dtype=float) + mjw.transition_fd(m, d, eps, centered, A_mjw, B_mjw, None, None) + + _assert_eq(A_mjw.numpy()[0], A_mj, "A") + _assert_eq(B_mjw.numpy()[0], B_mj, "B") + + @parameterized.parameters(False, True) + def test_transition_fd_sensor_derivatives(self, centered): + """Tests C and D matrices against MuJoCo mjd_transitionFD.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + """, + ) + + # larger eps needed for float32 precision + eps = 1e-3 + nv = mjm.nv + nu = mjm.nu + ns = mjm.nsensordata + ndx = 2 * nv + mjm.na + + # mujoco reference + C_mj = np.zeros((ns, ndx)) + D_mj = np.zeros((ns, nu)) + mujoco.mjd_transitionFD(mjm, mjd, eps, centered, None, None, C_mj, D_mj) + + # mujoco warp + C_mjw = wp.zeros((1, ns, ndx), dtype=float) + D_mjw = wp.zeros((1, ns, nu), dtype=float) + mjw.transition_fd(m, d, eps, centered, None, None, C_mjw, D_mjw) + + _assert_eq(C_mjw.numpy()[0], C_mj, "C") + _assert_eq(D_mjw.numpy()[0], D_mj, "D") + + @parameterized.parameters(False, True) + def test_transition_fd_clamped_ctrl(self, centered): + """Tests that B matrix is zero when ctrl is at or beyond limits.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + """, + ) + + eps = 1e-3 + nv = mjm.nv + nu = mjm.nu + ndx = 2 * nv + mjm.na + + # set ctrl beyond limits + mjd.ctrl[0] = 2.0 + d.ctrl.fill_(2.0) + + # mujoco reference - B should be zero + B_mj = np.zeros((ndx, nu)) + mujoco.mjd_transitionFD(mjm, mjd, eps, centered, None, B_mj, None, None) + + # mujoco warp + B_mjw = wp.zeros((1, ndx, nu), dtype=float) + mjw.transition_fd(m, d, eps, centered, None, B_mjw, None, None) + + # expect B to be zero since ctrl is beyond limits + _assert_eq(B_mjw.numpy()[0], B_mj, "B clamped") + np.testing.assert_allclose(B_mj, 0.0, atol=1e-10) + + def test_transition_fd_no_state_mutation(self): + """Tests that transition_fd does not mutate state.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + """, + keyframe=0, + ) + + # save state before + qpos_before = d.qpos.numpy().copy() + qvel_before = d.qvel.numpy().copy() + ctrl_before = d.ctrl.numpy().copy() + + # call transition_fd + eps = 1e-3 + ndx = 2 * m.nv + m.na + A = wp.zeros((1, ndx, ndx), dtype=float) + B = wp.zeros((1, ndx, m.nu), dtype=float) + mjw.transition_fd(m, d, eps, False, A, B, None, None) + + # check state unchanged + _assert_eq(d.qpos.numpy(), qpos_before, "qpos") + _assert_eq(d.qvel.numpy(), qvel_before, "qvel") + _assert_eq(d.ctrl.numpy(), ctrl_before, "ctrl") + + @parameterized.parameters(False, True) + def test_transition_fd_free_joint(self, centered): + """Tests A and B matrices with a free joint (quaternion perturbation).""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + """, + keyframe=0, + ) + + eps = 1e-3 + ndx = 2 * mjm.nv + mjm.na + + # mujoco reference + A_mj = np.zeros((ndx, ndx)) + B_mj = np.zeros((ndx, mjm.nu)) + mujoco.mjd_transitionFD(mjm, mjd, eps, centered, A_mj, B_mj, None, None) + + # mujoco warp + A_mjw = wp.zeros((1, ndx, ndx), dtype=float) + B_mjw = wp.zeros((1, ndx, mjm.nu), dtype=float) + mjw.transition_fd(m, d, eps, centered, A_mjw, B_mjw, None, None) + + _assert_eq(A_mjw.numpy()[0], A_mj, "A free joint") + _assert_eq(B_mjw.numpy()[0], B_mj, "B free joint") + + @parameterized.parameters(False, True) + def test_transition_fd_activations(self, centered): + """Tests A and B matrices with actuator activations (na > 0).""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + """, + keyframe=0, + ) + + self.assertGreater(mjm.na, 0, "Model should have activations") + eps = 1e-3 + ndx = 2 * mjm.nv + mjm.na + + # mujoco reference + A_mj = np.zeros((ndx, ndx)) + B_mj = np.zeros((ndx, mjm.nu)) + mujoco.mjd_transitionFD(mjm, mjd, eps, centered, A_mj, B_mj, None, None) + + # mujoco warp + A_mjw = wp.zeros((1, ndx, ndx), dtype=float) + B_mjw = wp.zeros((1, ndx, mjm.nu), dtype=float) + mjw.transition_fd(m, d, eps, centered, A_mjw, B_mjw, None, None) + + _assert_eq(A_mjw.numpy()[0], A_mj, "A activations") + _assert_eq(B_mjw.numpy()[0], B_mj, "B activations") + + @parameterized.parameters(False, True) + def test_transition_fd_ctrl_preserved(self, centered): + """Tests that ctrl values are preserved despite internal clamping.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + """, + ) + + eps = 1e-3 + nv = mjm.nv + nu = mjm.nu + ndx = 2 * nv + mjm.na + + # set ctrl beyond limits + mjd.ctrl[0] = 2.0 + d.ctrl.fill_(2.0) + + # mujoco reference - B should be zero + B_mj = np.zeros((ndx, nu)) + mujoco.mjd_transitionFD(mjm, mjd, eps, centered, None, B_mj, None, None) + + # mujoco warp + B_mjw = wp.zeros((1, ndx, nu), dtype=float) + mjw.transition_fd(m, d, eps, centered, None, B_mjw, None, None) + + # expect B to be zero since ctrl is beyond limits + _assert_eq(B_mjw.numpy()[0], B_mj, "B beyond limit") + np.testing.assert_allclose(B_mj, 0.0, atol=1e-10) + + # verify ctrl preserved despite internal clamping during FD + np.testing.assert_allclose( + mjd.ctrl[0], + 2.0, + atol=1e-10, + err_msg="MuJoCo ctrl should not be modified", + ) + np.testing.assert_allclose( + d.ctrl.numpy()[0, 0], + 2.0, + atol=1e-10, + err_msg="Warp ctrl should not be modified", + ) + + def test_transition_fd_full_no_mutation(self): + """Tests state preservation with free joints, activations, time, sensors.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + + + + """, + keyframe=0, + ) + + self.assertGreater(mjm.na, 0, "Model should have activations") + self.assertGreater(mjm.nsensordata, 0, "Model should have sensors") + + # save state before + qpos_before = d.qpos.numpy().copy() + qvel_before = d.qvel.numpy().copy() + act_before = d.act.numpy().copy() + ctrl_before = d.ctrl.numpy().copy() + time_before = d.time.numpy().copy() + + # call transition_fd requesting all matrices + eps = 1e-3 + nv = m.nv + ns = m.nsensordata + ndx = 2 * nv + m.na + A = wp.zeros((1, ndx, ndx), dtype=float) + B = wp.zeros((1, ndx, m.nu), dtype=float) + C = wp.zeros((1, ns, ndx), dtype=float) + D = wp.zeros((1, ns, m.nu), dtype=float) + mjw.transition_fd(m, d, eps, False, A, B, C, D) + + # check all state fields unchanged + _assert_eq(d.qpos.numpy(), qpos_before, "qpos") + _assert_eq(d.qvel.numpy(), qvel_before, "qvel") + _assert_eq(d.act.numpy(), act_before, "act") + _assert_eq(d.ctrl.numpy(), ctrl_before, "ctrl") + _assert_eq(d.time.numpy(), time_before, "time") + if __name__ == "__main__": wp.init()