diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index f1da96088..2600eedd5 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -47,6 +47,10 @@ from mujoco_warp._src.forward import rungekutta4 as rungekutta4 from mujoco_warp._src.forward import step1 as step1 from mujoco_warp._src.forward import step2 as step2 +from mujoco_warp._src.history import init_ctrl_history as init_ctrl_history +from mujoco_warp._src.history import init_sensor_history as init_sensor_history +from mujoco_warp._src.history import read_ctrl as read_ctrl +from mujoco_warp._src.history import read_sensor as read_sensor from mujoco_warp._src.inverse import inverse as inverse from mujoco_warp._src.io import create_render_context as create_render_context from mujoco_warp._src.io import get_data_into as get_data_into diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index df4e98097..c08fd8f09 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -20,6 +20,7 @@ from mujoco_warp._src import collision_driver from mujoco_warp._src import constraint from mujoco_warp._src import derivative +from mujoco_warp._src import history from mujoco_warp._src import island from mujoco_warp._src import math from mujoco_warp._src import passive @@ -302,6 +303,9 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) outputs=[d.qpos], ) + # advance history buffers before time advance + history.insert_ctrl_history(m, d) + wp.launch( _next_time, dim=d.nworld, @@ -1098,6 +1102,13 @@ def fwd_actuation(m: Model, d: Data): d.actuator_force.zero_() return + # read delayed ctrl (or direct copy if no delay) + if m.nhistory > 0: + ctrl = wp.empty((d.nworld, m.nu), dtype=float) + history.read_ctrl_delayed(m, d, ctrl) + else: + ctrl = d.ctrl + wp.launch( _actuator_force, dim=(d.nworld, m.nu), @@ -1122,7 +1133,7 @@ def fwd_actuation(m: Model, d: Data): m.actuator_acc0, m.actuator_lengthrange, d.act, - d.ctrl, + ctrl, d.actuator_length, d.actuator_velocity, m.opt.disableflags & DisableBit.CLAMPCTRL, diff --git a/mujoco_warp/_src/history.py b/mujoco_warp/_src/history.py new file mode 100644 index 000000000..f45efa775 --- /dev/null +++ b/mujoco_warp/_src/history.py @@ -0,0 +1,925 @@ +# Copyright 2026 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import warp as wp + +from mujoco_warp._src.types import MJ_MAXVAL +from mujoco_warp._src.types import MJ_MINVAL +from mujoco_warp._src.types import Data +from mujoco_warp._src.types import Model + +wp.set_module_options({"enable_backward": False}) + + +@wp.func +def _history_physical_index(cursor: int, n: int, logical: int) -> int: + """Convert logical index (0=oldest, n-1=newest) to physical index.""" + return (cursor + 1 + logical) % n + + +@wp.func +def _history_find_index( + # In: + buf: wp.array2d[float], + worldid: int, + buf_offset: int, + n: int, + cursor: int, + t: float, +) -> int: + """Find logical index i such that times[i-1] < t <= times[i]. + + Returns 0 if t <= times[oldest], n if t > times[newest]. + Uses circular binary search matching MuJoCo C historyFindIndex. + """ + times_offset = buf_offset + 2 + + oldest_phys = _history_physical_index(cursor, n, 0) + newest_phys = _history_physical_index(cursor, n, n - 1) + t_oldest = buf[worldid, times_offset + oldest_phys] + t_newest = buf[worldid, times_offset + newest_phys] + + # before or at first element + if t <= t_oldest: + return 0 + + # after last element + if t > t_newest: + return n + + # circular binary search: find smallest logical i such that times[phys(i)] >= t + lo = int(0) + hi = int(n - 1) + while hi - lo > 1: + mid = int((lo + hi) >> 1) + mid_phys = _history_physical_index(cursor, n, mid) + if buf[worldid, times_offset + mid_phys] < t: + lo = mid + else: + hi = mid + + return hi + + +@wp.func +def _history_read_scalar( + # In: + buf: wp.array2d[float], + worldid: int, + buf_offset: int, + n: int, + t: float, + interp: int, +) -> float: + """Read a scalar value from history buffer at time t. + + interp: 0=zero-order-hold, 1=linear, 2=cubic (Catmull-Rom spline) + """ + cursor = int(buf[worldid, buf_offset + 1]) + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + n + + oldest_phys = _history_physical_index(cursor, n, 0) + newest_phys = _history_physical_index(cursor, n, n - 1) + t_oldest = buf[worldid, times_offset + oldest_phys] + t_newest = buf[worldid, times_offset + newest_phys] + + # extrapolate before oldest + if t <= t_oldest + 1e-6: + return buf[worldid, values_offset + oldest_phys] + + # extrapolate after newest + if t >= t_newest - 1e-6: + return buf[worldid, values_offset + newest_phys] + + # find bracketing index + i = _history_find_index(buf, worldid, buf_offset, n, cursor, t) + phys_i = _history_physical_index(cursor, n, i) + + # exact match + if wp.abs(t - buf[worldid, times_offset + phys_i]) < 1e-6: + return buf[worldid, values_offset + phys_i] + + phys_lo = _history_physical_index(cursor, n, i - 1) + phys_hi = phys_i + + # zero-order hold + if interp == 0: + return buf[worldid, values_offset + phys_lo] + + dt = buf[worldid, times_offset + phys_hi] - buf[worldid, times_offset + phys_lo] + alpha = (t - buf[worldid, times_offset + phys_lo]) / dt + v_lo = buf[worldid, values_offset + phys_lo] + v_hi = buf[worldid, values_offset + phys_hi] + + # linear interpolation + if interp == 1: + return v_lo + alpha * (v_hi - v_lo) + + # cubic spline interpolation (Catmull-Rom) + alpha2 = alpha * alpha + alpha3 = alpha2 * alpha + h00 = 2.0 * alpha3 - 3.0 * alpha2 + 1.0 + h10 = alpha3 - 2.0 * alpha2 + alpha + h01 = -2.0 * alpha3 + 3.0 * alpha2 + h11 = alpha3 - alpha2 + + # finite-differenced Catmull-Rom slopes, 0 at endpoints + m_lo = 0.0 + if i > 1: + phys_lo_prev = _history_physical_index(cursor, n, i - 2) + dt_lo = buf[worldid, times_offset + phys_hi] - buf[worldid, times_offset + phys_lo_prev] + m_lo = (v_hi - buf[worldid, values_offset + phys_lo_prev]) / dt_lo + + m_hi = 0.0 + if i < n - 1: + phys_hi_next = _history_physical_index(cursor, n, i + 1) + dt_hi = buf[worldid, times_offset + phys_hi_next] - buf[worldid, times_offset + phys_lo] + m_hi = (buf[worldid, values_offset + phys_hi_next] - v_lo) / dt_hi + + return h00 * v_lo + h10 * dt * m_lo + h01 * v_hi + h11 * dt * m_hi + + +@wp.func +def _history_read_vector( + # In: + adr: int, + buf: wp.array2d[float], + worldid: int, + buf_offset: int, + n: int, + dim: int, + t: float, + interp: int, + # Data out: + sensordata_out: wp.array2d[float], +) -> int: + """Read a vector value from history buffer at time t into sensordata. + + Returns 1 on success (value written to sensordata). + interp: 0=zero-order-hold, 1=linear, 2=cubic (Catmull-Rom spline) + """ + cursor = int(buf[worldid, buf_offset + 1]) + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + n + + oldest_phys = _history_physical_index(cursor, n, 0) + newest_phys = _history_physical_index(cursor, n, n - 1) + t_oldest = buf[worldid, times_offset + oldest_phys] + t_newest = buf[worldid, times_offset + newest_phys] + + # extrapolate before oldest: copy oldest + if t <= t_oldest + 1e-6: + for d in range(dim): + sensordata_out[worldid, adr + d] = buf[worldid, values_offset + oldest_phys * dim + d] + return 1 + + # extrapolate after newest: copy newest + if t >= t_newest - 1e-6: + for d in range(dim): + sensordata_out[worldid, adr + d] = buf[worldid, values_offset + newest_phys * dim + d] + return 1 + + # find bracketing index + i = _history_find_index(buf, worldid, buf_offset, n, cursor, t) + phys_i = _history_physical_index(cursor, n, i) + + # exact match + if wp.abs(t - buf[worldid, times_offset + phys_i]) < 1e-6: + for d in range(dim): + sensordata_out[worldid, adr + d] = buf[worldid, values_offset + phys_i * dim + d] + return 1 + + phys_lo = _history_physical_index(cursor, n, i - 1) + phys_hi = phys_i + + # zero-order hold + if interp == 0: + for d in range(dim): + sensordata_out[worldid, adr + d] = buf[worldid, values_offset + phys_lo * dim + d] + return 1 + + dt = buf[worldid, times_offset + phys_hi] - buf[worldid, times_offset + phys_lo] + alpha = (t - buf[worldid, times_offset + phys_lo]) / dt + + # linear interpolation + if interp == 1: + for d in range(dim): + v_lo = buf[worldid, values_offset + phys_lo * dim + d] + v_hi = buf[worldid, values_offset + phys_hi * dim + d] + sensordata_out[worldid, adr + d] = v_lo + alpha * (v_hi - v_lo) + return 1 + + # cubic spline interpolation (Catmull-Rom) + alpha2 = alpha * alpha + alpha3 = alpha2 * alpha + h00 = 2.0 * alpha3 - 3.0 * alpha2 + 1.0 + h10 = alpha3 - 2.0 * alpha2 + alpha + h01 = -2.0 * alpha3 + 3.0 * alpha2 + h11 = alpha3 - alpha2 + + for d in range(dim): + v_lo = buf[worldid, values_offset + phys_lo * dim + d] + v_hi = buf[worldid, values_offset + phys_hi * dim + d] + + # finite-differenced Catmull-Rom slopes, 0 at endpoints + m_lo = 0.0 + if i > 1: + phys_lo_prev = _history_physical_index(cursor, n, i - 2) + dt_lo = buf[worldid, times_offset + phys_hi] - buf[worldid, times_offset + phys_lo_prev] + m_lo = (v_hi - buf[worldid, values_offset + phys_lo_prev * dim + d]) / dt_lo + + m_hi = 0.0 + if i < n - 1: + phys_hi_next = _history_physical_index(cursor, n, i + 1) + dt_hi = buf[worldid, times_offset + phys_hi_next] - buf[worldid, times_offset + phys_lo] + m_hi = (buf[worldid, values_offset + phys_hi_next * dim + d] - v_lo) / dt_hi + + sensordata_out[worldid, adr + d] = h00 * v_lo + h10 * dt * m_lo + h01 * v_hi + h11 * dt * m_hi + return 1 + + +@wp.func +def _history_insert_scalar( + # In: + worldid: int, + buf_offset: int, + n: int, + t: float, + value: float, + # Out: + buf_out: wp.array2d[float], +): + """Insert a scalar value into history buffer at time t.""" + cursor = int(buf_out[worldid, buf_offset + 1]) + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + n + + i = _history_find_index(buf_out, worldid, buf_offset, n, cursor, t) + + # exact match + if i < n: + phys_i = _history_physical_index(cursor, n, i) + if wp.abs(t - buf_out[worldid, times_offset + phys_i]) < 1e-6: + buf_out[worldid, values_offset + phys_i] = value + return + + # older than oldest: replace oldest + if i == 0: + oldest_phys = _history_physical_index(cursor, n, 0) + buf_out[worldid, times_offset + oldest_phys] = t + buf_out[worldid, values_offset + oldest_phys] = value + return + + # newer than newest: advance cursor + if i == n: + cursor = (cursor + 1) % n + buf_out[worldid, buf_offset + 1] = float(cursor) + buf_out[worldid, times_offset + cursor] = t + buf_out[worldid, values_offset + cursor] = value + return + + # out-of-order: shift [1, i-1] left, insert at i-1 + for j in range(i - 1): + src_phys = _history_physical_index(cursor, n, j + 1) + dst_phys = _history_physical_index(cursor, n, j) + buf_out[worldid, times_offset + dst_phys] = buf_out[worldid, times_offset + src_phys] + buf_out[worldid, values_offset + dst_phys] = buf_out[worldid, values_offset + src_phys] + insert_phys = _history_physical_index(cursor, n, i - 1) + buf_out[worldid, times_offset + insert_phys] = t + buf_out[worldid, values_offset + insert_phys] = value + + +@wp.func +def _history_insert_vector( + # In: + worldid: int, + buf_offset: int, + n: int, + dim: int, + t: float, + src: wp.array2d[float], + src_adr: int, + # Out: + buf_out: wp.array2d[float], +): + """Insert a vector value from src[worldid, src_adr:src_adr+dim] into history buffer at time t.""" + cursor = int(buf_out[worldid, buf_offset + 1]) + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + n + + i = _history_find_index(buf_out, worldid, buf_offset, n, cursor, t) + + slot_phys = -1 + + # exact match + if i < n: + phys_i = _history_physical_index(cursor, n, i) + if wp.abs(t - buf_out[worldid, times_offset + phys_i]) < 1e-6: + slot_phys = phys_i + + if slot_phys < 0: + if i == 0: + # older than oldest: replace oldest + slot_phys = _history_physical_index(cursor, n, 0) + buf_out[worldid, times_offset + slot_phys] = t + elif i == n: + # newer than newest: advance cursor + cursor = (cursor + 1) % n + buf_out[worldid, buf_offset + 1] = float(cursor) + slot_phys = cursor + buf_out[worldid, times_offset + slot_phys] = t + else: + # out-of-order: shift [1, i-1] left, insert at i-1 + for j in range(i - 1): + src_phys = _history_physical_index(cursor, n, j + 1) + dst_phys = _history_physical_index(cursor, n, j) + buf_out[worldid, times_offset + dst_phys] = buf_out[worldid, times_offset + src_phys] + for d in range(dim): + buf_out[worldid, values_offset + dst_phys * dim + d] = buf_out[worldid, values_offset + src_phys * dim + d] + slot_phys = _history_physical_index(cursor, n, i - 1) + buf_out[worldid, times_offset + slot_phys] = t + + # copy values + for d in range(dim): + buf_out[worldid, values_offset + slot_phys * dim + d] = src[worldid, src_adr + d] + + +@wp.kernel +def _read_ctrl_delayed_kernel( + # Model: + actuator_history: wp.array[wp.vec2i], + actuator_historyadr: wp.array[int], + actuator_delay: wp.array[float], + # Data in: + time_in: wp.array[float], + history_in: wp.array2d[float], + ctrl_in: wp.array2d[float], + # Data out: + ctrl_out: wp.array2d[float], +): + """Read delayed ctrl for each actuator.""" + worldid, uid = wp.tid() + + hist = actuator_history[uid] + nsample = hist[0] + delay = actuator_delay[uid] + + if nsample == 0 or delay == 0.0: + # no delay: direct copy + ctrl_out[worldid, uid] = ctrl_in[worldid, uid] + else: + interp = hist[1] + buf_offset = actuator_historyadr[uid] + t = time_in[worldid] - delay + ctrl_out[worldid, uid] = _history_read_scalar(history_in, worldid, buf_offset, nsample, t, interp) + + +@wp.kernel +def _insert_ctrl_history_kernel( + # Model: + actuator_history: wp.array[wp.vec2i], + actuator_historyadr: wp.array[int], + # Data in: + time_in: wp.array[float], + ctrl_in: wp.array2d[float], + # Data out: + history_out: wp.array2d[float], +): + """Insert current ctrl into history buffers.""" + worldid, uid = wp.tid() + + hist = actuator_history[uid] + nsample = hist[0] + if nsample == 0: + return + + buf_offset = actuator_historyadr[uid] + t = time_in[worldid] + value = ctrl_in[worldid, uid] + _history_insert_scalar(worldid, buf_offset, nsample, t, value, history_out) + + +@wp.kernel +def _insert_sensor_history_stage( + # Model: + sensor_dim: wp.array[int], + sensor_adr: wp.array[int], + sensor_history: wp.array[wp.vec2i], + sensor_historyadr: wp.array[int], + sensor_delay: wp.array[float], + sensor_interval: wp.array[wp.vec2], + # Data in: + time_in: wp.array[float], + sensordata_in: wp.array2d[float], + # In: + sensor_ids: wp.array[int], + # Data out: + history_out: wp.array2d[float], +): + """Insert current sensor values into history buffers for specific sensor IDs.""" + worldid, idx = wp.tid() + sid = sensor_ids[idx] + + hist = sensor_history[sid] + nsample = hist[0] + if nsample == 0: + return + + buf_offset = sensor_historyadr[sid] + dim = sensor_dim[sid] + interval_val = sensor_interval[sid] + period = interval_val[0] + t = time_in[worldid] + + if period > 0.0: + # interval mode: check if condition is satisfied + time_prev = history_out[worldid, buf_offset] # user slot stores time_prev + if time_prev + period <= t: + # advance time_prev by exact period + history_out[worldid, buf_offset] = time_prev + period + # insert sensor value + _history_insert_vector(worldid, buf_offset, nsample, dim, t, sensordata_in, sensor_adr[sid], history_out) + else: + _history_insert_vector(worldid, buf_offset, nsample, dim, t, sensordata_in, sensor_adr[sid], history_out) + + +@wp.kernel +def _apply_sensor_delay_kernel( + # Model: + sensor_dim: wp.array[int], + sensor_adr: wp.array[int], + sensor_history: wp.array[wp.vec2i], + sensor_historyadr: wp.array[int], + sensor_delay: wp.array[float], + sensor_interval: wp.array[wp.vec2], + # Data in: + time_in: wp.array[float], + history_in: wp.array2d[float], + # In: + sensor_ids: wp.array[int], + # Data out: + sensordata_out: wp.array2d[float], +): + """Apply delay/interval logic for sensors after computation. + + TODO(team): Revisit always-compute decision for computationally expensive sensors + with interval/period (e.g., raytracers) + """ + worldid, idx = wp.tid() + sid = sensor_ids[idx] + + hist = sensor_history[sid] + nsample = hist[0] + if nsample <= 0: + return + + delay = sensor_delay[sid] + dim = sensor_dim[sid] + interp = hist[1] + buf_offset = sensor_historyadr[sid] + t = time_in[worldid] + + if delay > 0.0: + # delay > 0: read delayed value from buffer + _history_read_vector(sensor_adr[sid], history_in, worldid, buf_offset, nsample, dim, t - delay, interp, sensordata_out) + else: + # interval-only (delay == 0, interval > 0): check interval condition + interval_val = sensor_interval[sid] + period = interval_val[0] + if period > 0.0: + time_prev = history_in[worldid, buf_offset] # user slot + if time_prev + period > t: + # interval condition not satisfied: read from buffer + _history_read_vector(sensor_adr[sid], history_in, worldid, buf_offset, nsample, dim, t, interp, sensordata_out) + # else: interval condition satisfied, keep computed value + + +def read_ctrl_delayed(m: Model, d: Data, ctrl: wp.array2d[float]): + """Read delayed ctrl values for all actuators.""" + if m.nhistory == 0: + wp.copy(ctrl, d.ctrl) + return + + wp.launch( + _read_ctrl_delayed_kernel, + dim=(d.nworld, m.nu), + inputs=[ + m.actuator_history, + m.actuator_historyadr, + m.actuator_delay, + d.time, + d.history, + d.ctrl, + ], + outputs=[ctrl], + ) + + +def insert_ctrl_history(m: Model, d: Data): + """Insert current ctrl values into history buffers.""" + if m.nhistory == 0 or m.nu == 0: + return + + wp.launch( + _insert_ctrl_history_kernel, + dim=(d.nworld, m.nu), + inputs=[ + m.actuator_history, + m.actuator_historyadr, + d.time, + d.ctrl, + ], + outputs=[d.history], + ) + + +def apply_sensor_delay(m: Model, d: Data, sensorid: wp.array[int]): + """Apply delay/interval logic for given sensors after computation. + + Matches MuJoCo C architecture where the delayed read (mj_sensorPos) occurs + before the fresh value insert (mj_advance). We save fresh sensordata, + overwrite with delayed values, then insert the saved fresh values. + """ + if m.nhistory == 0 or sensorid.shape[0] == 0: + return + + # Save fresh sensordata before delay overwrite + fresh_sensordata = wp.empty_like(d.sensordata) + wp.copy(fresh_sensordata, d.sensordata) + + # Read delayed values from buffer → overwrite sensordata + wp.launch( + _apply_sensor_delay_kernel, + dim=(d.nworld, sensorid.shape[0]), + inputs=[ + m.sensor_dim, + m.sensor_adr, + m.sensor_history, + m.sensor_historyadr, + m.sensor_delay, + m.sensor_interval, + d.time, + d.history, + sensorid, + ], + outputs=[d.sensordata], + ) + + # Insert saved fresh sensor values into history buffers + wp.launch( + _insert_sensor_history_stage, + dim=(d.nworld, sensorid.shape[0]), + inputs=[ + m.sensor_dim, + m.sensor_adr, + m.sensor_history, + m.sensor_historyadr, + m.sensor_delay, + m.sensor_interval, + d.time, + fresh_sensordata, + sensorid, + ], + outputs=[d.history], + ) + + +@wp.kernel +def _read_ctrl_kernel( + # Model: + actuator_history: wp.array[wp.vec2i], + actuator_historyadr: wp.array[int], + actuator_delay: wp.array[float], + # Data in: + time_in: wp.array[float], + history_in: wp.array2d[float], + ctrl_in: wp.array2d[float], + # In: + uid: int, + interp: int, + # Out: + result_out: wp.array[float], +): + """Read delayed ctrl for 1 actuator across all worlds.""" + worldid = wp.tid() + + hist = actuator_history[uid] + nsample = hist[0] + + if nsample == 0: + result_out[worldid] = ctrl_in[worldid, uid] + else: + interp_val = interp + if interp_val < 0: + interp_val = hist[1] + delay = actuator_delay[uid] + buf_offset = actuator_historyadr[uid] + t = time_in[worldid] - delay + result_out[worldid] = _history_read_scalar(history_in, worldid, buf_offset, nsample, t, interp_val) + + +def read_ctrl( + m: Model, + d: Data, + ctrlid: int, + time: wp.array[float], + interp: int, + result: wp.array2d[float], +): + """Read delayed ctrl for 1 actuator across all worlds. + + Args: + m: The model containing kinematic and dynamic information. + d: The data object containing the current state and output arrays. + ctrlid: actuator index. + time: query time per world (nworld,). + interp: interpolation mode (-1=model default, 0=ZOH, 1=linear, 2=cubic). + result: output buffer (nworld,). + """ + wp.launch( + _read_ctrl_kernel, + dim=(d.nworld,), + inputs=[ + m.actuator_history, + m.actuator_historyadr, + m.actuator_delay, + time, + d.history, + d.ctrl, + ctrlid, + interp, + ], + outputs=[result], + ) + + +@wp.kernel +def _read_sensor_kernel( + # Model: + sensor_dim: wp.array[int], + sensor_adr: wp.array[int], + sensor_history: wp.array[wp.vec2i], + sensor_historyadr: wp.array[int], + sensor_delay: wp.array[float], + # Data in: + time_in: wp.array[float], + history_in: wp.array2d[float], + sensordata_in: wp.array2d[float], + # In: + sid: int, + interp: int, + # Out: + result_out: wp.array2d[float], +): + """Read delayed sensor for 1 sensor across all worlds.""" + worldid = wp.tid() + + hist = sensor_history[sid] + nsample = hist[0] + dim = sensor_dim[sid] + adr = sensor_adr[sid] + + if nsample == 0: + for i in range(dim): + result_out[worldid, i] = sensordata_in[worldid, adr + i] + else: + interp_val = interp + if interp_val < 0: + interp_val = hist[1] + delay = sensor_delay[sid] + buf_offset = sensor_historyadr[sid] + t = time_in[worldid] - delay + _history_read_vector( + 0, # write to result_out starting at index 0 (not global sensor adr) + history_in, + worldid, + buf_offset, + nsample, + dim, + t, + interp_val, + result_out, + ) + + +def read_sensor( + m: Model, + d: Data, + sensorid: int, + time: wp.array[float], + interp: int, + result: wp.array2d[float], +): + """Read delayed sensor for 1 sensor across all worlds. + + Args: + m: The model containing kinematic and dynamic information. + d: The data object containing the current state and output arrays. + sensorid: sensor index. + time: query time per world (nworld,). + interp: interpolation mode (-1=model default, 0=ZOH, 1=linear, 2=cubic). + result: output buffer (nworld, dim). + """ + wp.launch( + _read_sensor_kernel, + dim=(d.nworld,), + inputs=[ + m.sensor_dim, + m.sensor_adr, + m.sensor_history, + m.sensor_historyadr, + m.sensor_delay, + time, + d.history, + d.sensordata, + sensorid, + interp, + ], + outputs=[result], + ) + + +@wp.kernel +def _init_ctrl_history_kernel( + # kernel_analyzer: off + # Model: + actuator_history: wp.array[wp.vec2i], + actuator_historyadr: wp.array[int], + # In: + ctrlid: int, + times: wp.array[float], + values: wp.array2d[float], + has_times: int, + # Data out: + history_out: wp.array2d[float], + # kernel_analyzer: on +): + """Initialize history buffer for 1 actuator across all worlds.""" + worldid = wp.tid() + + nsample = actuator_history[ctrlid][0] + buf_offset = actuator_historyadr[ctrlid] + + # preserve user slot + user = history_out[worldid, buf_offset] + + # cursor = 0 (samples in order, newest at index nsample-1) + history_out[worldid, buf_offset + 1] = float(nsample - 1) + + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + nsample + + for i in range(nsample): + if has_times != 0: + history_out[worldid, times_offset + i] = times[i] + else: + history_out[worldid, times_offset + i] = -MJ_MAXVAL + history_out[worldid, values_offset + i] = values[worldid, i] + + # restore user slot + history_out[worldid, buf_offset] = user + + +def init_ctrl_history( + m: Model, + d: Data, + ctrlid: int, + times: wp.array[float], + values: wp.array2d[float], +): + """Initialize history buffer for 1 actuator across all worlds. + + Args: + m: The model containing kinematic and dynamic information. + d: The data object containing the current state and output arrays. + ctrlid: actuator index. + times: timestamps or None (nsample,). + values: ctrl values (nworld, nsample). + + Raises: + ValueError: If times are not strictly increasing. + """ + has_times = 0 if times is None else 1 + if times is not None: + t_np = times.numpy() + for i in range(len(t_np) - 1): + if t_np[i + 1] - t_np[i] < MJ_MINVAL: + raise ValueError(f"times must be strictly increasing, got times[{i}]={t_np[i]} >= times[{i + 1}]={t_np[i + 1]}") + if times is None: + times = wp.empty(0, dtype=float) + + wp.launch( + _init_ctrl_history_kernel, + dim=(d.nworld,), + inputs=[ + m.actuator_history, + m.actuator_historyadr, + ctrlid, + times, + values, + has_times, + ], + outputs=[d.history], + ) + + +# kernel_analyzer: off +@wp.kernel +def _init_sensor_history_kernel( + # Model: + sensor_history: wp.array[wp.vec2i], + sensor_historyadr: wp.array[int], + sensor_dim_arr: wp.array[int], + # In: + sensorid: int, + times: wp.array[float], + values: wp.array2d[float], + phase: wp.array[float], + has_times: int, + # Data out: + history_out: wp.array2d[float], +): + # kernel_analyzer: on + """Initialize history buffer for 1 sensor across all worlds.""" + worldid = wp.tid() + + nsample = sensor_history[sensorid][0] + dim = sensor_dim_arr[sensorid] + buf_offset = sensor_historyadr[sensorid] + + # set user slot (phase = last computation time for interval sensors) + history_out[worldid, buf_offset] = phase[worldid] + + # cursor = 0 (samples in order, newest at index nsample-1) + history_out[worldid, buf_offset + 1] = float(nsample - 1) + + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + nsample + + for i in range(nsample): + if has_times != 0: + history_out[worldid, times_offset + i] = times[i] + else: + history_out[worldid, times_offset + i] = -MJ_MAXVAL + for j in range(dim): + history_out[worldid, values_offset + i * dim + j] = values[worldid, i * dim + j] + + +def init_sensor_history( + m: Model, + d: Data, + sensorid: int, + times: wp.array[float], + values: wp.array2d[float], + phase: wp.array[float], +): + """Initialize history buffer for 1 sensor across all worlds. + + Args: + m: The model containing kinematic and dynamic information. + d: The data object containing the current state and output arrays. + sensorid: sensor index. + times: timestamps or None (nsample,). + values: sensor values (nworld, nsample * dim). + phase: user slot value per world (nworld,). + + Raises: + ValueError: If times are not strictly increasing. + """ + has_times = 0 if times is None else 1 + if times is not None: + t_np = times.numpy() + for i in range(len(t_np) - 1): + if t_np[i + 1] - t_np[i] < MJ_MINVAL: + raise ValueError(f"times must be strictly increasing, got times[{i}]={t_np[i]} >= times[{i + 1}]={t_np[i + 1]}") + if times is None: + times = wp.empty(0, dtype=float) + + wp.launch( + _init_sensor_history_kernel, + dim=(d.nworld,), + inputs=[ + m.sensor_history, + m.sensor_historyadr, + m.sensor_dim, + sensorid, + times, + values, + phase, + has_times, + ], + outputs=[d.history], + ) diff --git a/mujoco_warp/_src/history_test.py b/mujoco_warp/_src/history_test.py new file mode 100644 index 000000000..e4198f0ca --- /dev/null +++ b/mujoco_warp/_src/history_test.py @@ -0,0 +1,1114 @@ +# Copyright 2026 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for actuator and sensor delay.""" + +import mujoco +import numpy as np +import warp as wp +from absl.testing import absltest +from absl.testing import parameterized + +from mujoco_warp import test_data +from mujoco_warp._src import forward +from mujoco_warp._src import history + +_TOLERANCE = 1e-8 + + +class PublicAPITest(absltest.TestCase): + """Test public delay API functions against MuJoCo C reference.""" + + def test_read_ctrl(self): + """Test read_ctrl matches mj_readCtrl.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + # step both with ctrl=10, then ctrl=20 + for ctrl_val in [10.0, 20.0, 30.0]: + mjd.ctrl[0] = ctrl_val + wp.copy(d.ctrl, wp.array(np.full((1, 1), ctrl_val), dtype=float)) + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + # compare read_ctrl at current time + time_arr = d.time + warp_result = wp.empty(d.nworld, dtype=float) + history.read_ctrl(m, d, 0, time_arr, interp=-1, result=warp_result) + mj_result = mujoco.mj_readCtrl(mjm, mjd, 0, mjd.time, -1) + np.testing.assert_allclose( + warp_result.numpy()[0], + mj_result, + atol=_TOLERANCE, + err_msg="read_ctrl mismatch", + ) + + # compare with explicit interp=0 (ZOH) + warp_result_zoh = wp.empty(d.nworld, dtype=float) + history.read_ctrl(m, d, 0, time_arr, interp=0, result=warp_result_zoh) + mj_result_zoh = mujoco.mj_readCtrl(mjm, mjd, 0, mjd.time, 0) + np.testing.assert_allclose( + warp_result_zoh.numpy()[0], + mj_result_zoh, + atol=_TOLERANCE, + err_msg="read_ctrl ZOH mismatch", + ) + + def test_read_sensor(self): + """Test read_sensor matches mj_readSensor.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + for i in range(4): + qpos_val = float((i + 1) * 10) + mjd.qpos[0] = qpos_val + wp.copy(d.qpos, wp.array(np.full((1, 1), qpos_val), dtype=float)) + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + # compare read_sensor at current time + dim = mjm.sensor_dim[0] + time_arr = d.time + result = wp.empty((d.nworld, dim), dtype=float) + history.read_sensor(m, d, 0, time_arr, interp=-1, result=result) + + mj_result_buf = np.zeros(dim) + ptr = mujoco.mj_readSensor(mjm, mjd, 0, mjd.time, mj_result_buf, -1) + mj_val = ptr if ptr is not None else mj_result_buf + + np.testing.assert_allclose( + result.numpy()[0], + mj_val, + atol=_TOLERANCE, + err_msg="read_sensor mismatch", + ) + + def test_read_sensor_arbitrary_times(self): + """Read sensor at arbitrary query times, compare against C.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + for i, qval in enumerate([1.0, 2.0, 3.0]): + mjd.qpos[0] = qval + wp.copy(d.qpos, wp.array(np.full((1, 1), qval), dtype=float)) + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + dim = mjm.sensor_dim[0] + for query_time, expected in [(0.0, 1.0), (0.01, 2.0), (0.02, 3.0)]: + time_arr = wp.array([query_time], dtype=float) + result = wp.empty((1, dim), dtype=float) + history.read_sensor(m, d, 0, time_arr, interp=0, result=result) + + mj_result_buf = np.zeros(dim) + ptr = mujoco.mj_readSensor(mjm, mjd, 0, query_time, mj_result_buf, 0) + mj_val = ptr if ptr is not None else mj_result_buf + + np.testing.assert_allclose(result.numpy()[0], mj_val, atol=_TOLERANCE, err_msg=f"read_sensor mismatch at t={query_time}") + + def test_read_sensor_second_index(self): + """Test read_sensor for sensor index > 0 (OOB bug regression test).""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + for i in range(4): + qpos_val = float((i + 1) * 10) + mjd.qpos[0] = qpos_val + mjd.qpos[1] = qpos_val * 2 + wp.copy(d.qpos, wp.array(np.array([[qpos_val, qpos_val * 2]]), dtype=float)) + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + # read sensor 1 (index > 0) at current time + dim = mjm.sensor_dim[1] + time_arr = d.time + result = wp.empty((1, dim), dtype=float) + history.read_sensor(m, d, 1, time_arr, interp=-1, result=result) + + mj_result_buf = np.zeros(dim) + ptr = mujoco.mj_readSensor(mjm, mjd, 1, mjd.time, mj_result_buf, -1) + mj_val = ptr if ptr is not None else mj_result_buf + + np.testing.assert_allclose( + result.numpy()[0], mj_val, atol=_TOLERANCE, err_msg="read_sensor mismatch for sensor index 1 (OOB regression)" + ) + + def test_init_ctrl_history(self): + """Test init_ctrl_history sets buffer correctly.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + # initialize with custom values + custom_times = np.array([0.1, 0.2, 0.3]) + custom_values = np.array([100.0, 200.0, 300.0]) + times_wp = wp.array(custom_times, dtype=float) + values_wp = wp.array(custom_values.reshape(1, -1), dtype=float) + history.init_ctrl_history(m, d, 0, times_wp, values_wp) + + # also init MuJoCo C side + mujoco.mj_initCtrlHistory(mjm, mjd, 0, custom_times, custom_values) + + # read at a time in the buffer + query_time = 0.23 # between samples → ZOH should return value at t=0.2 + time_arr = wp.array([query_time], dtype=float) + warp_result = wp.empty(d.nworld, dtype=float) + history.read_ctrl(m, d, 0, time_arr, interp=0, result=warp_result) + mj_result = mujoco.mj_readCtrl(mjm, mjd, 0, query_time, 0) + np.testing.assert_allclose( + warp_result.numpy()[0], + mj_result, + atol=_TOLERANCE, + err_msg="init_ctrl_history read mismatch", + ) + + def test_init_sensor_history(self): + """Test init_sensor_history sets buffer correctly.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + dim = mjm.sensor_dim[0] + + # initialize with custom values + custom_times = np.array([0.1, 0.2, 0.3]) + custom_values = np.array([100.0, 200.0, 300.0]) + phase = 0.05 + + times_wp = wp.array(custom_times, dtype=float) + values_wp = wp.array(custom_values.reshape(1, -1), dtype=float) + phase_wp = wp.array([phase], dtype=float) + history.init_sensor_history(m, d, 0, times_wp, values_wp, phase=phase_wp) + + # also init MuJoCo C side + mujoco.mj_initSensorHistory(mjm, mjd, 0, custom_times, custom_values, phase) + + # read at a time in the buffer + query_time = 0.23 + time_arr = wp.array([query_time], dtype=float) + result = wp.empty((1, dim), dtype=float) + history.read_sensor(m, d, 0, time_arr, interp=0, result=result) + + mj_result_buf = np.zeros(dim) + ptr = mujoco.mj_readSensor(mjm, mjd, 0, query_time, mj_result_buf, 0) + mj_val = ptr if ptr is not None else mj_result_buf + + np.testing.assert_allclose( + result.numpy()[0], + mj_val, + atol=_TOLERANCE, + err_msg="init_sensor_history read mismatch", + ) + + def test_actuator_history_only(self): + """Actuator with delay=0 records history but applies ctrl immediately.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + self.assertGreater(m.nhistory, 0) + + # ctrl applied immediately when delay=0 + for ctrl_val in [10.0, 20.0, 30.0]: + wp.copy(d.ctrl, wp.array(np.full((1, 1), ctrl_val), dtype=float)) + forward.step(m, d) + np.testing.assert_allclose(d.actuator_force.numpy()[0, 0], ctrl_val, atol=_TOLERANCE) + + # history buffer still readable + result = wp.empty(d.nworld, dtype=float) + time_arr = wp.array([0.015], dtype=float) + history.read_ctrl(m, d, 0, time_arr, interp=0, result=result) + np.testing.assert_allclose(result.numpy()[0], 20.0, atol=_TOLERANCE) + + def test_sensor_history_only(self): + """Sensor with delay=0 records history but reports current value.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + self.assertGreater(m.nhistory, 0) + + wp.copy(d.qpos, wp.array(np.full((1, 1), 10.0), dtype=float)) + forward.step(m, d) + np.testing.assert_allclose(d.sensordata.numpy()[0, 0], 10.0, atol=_TOLERANCE) + + # history buffer still readable + time_arr = wp.array([d.time.numpy()[0] - 0.005], dtype=float) + result = wp.empty((d.nworld, 1), dtype=float) + history.read_sensor(m, d, 0, time_arr, interp=0, result=result) + np.testing.assert_allclose(result.numpy()[0, 0], 10.0, atol=_TOLERANCE) + + +class MultiWorldDelayTest(parameterized.TestCase): + """Test delay with nworld > 1 and varying delay values.""" + + @parameterized.parameters(1, 2) + def test_actuator_delay(self, nworld): + for delay, nsample, nzero in ((0.02, 2, 2), (0.04, 5, 4)): + xml = f""" + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml, nworld=nworld) + + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), 10.0), dtype=float)) + + for i in range(nzero): + forward.step(m, d) + for w in range(nworld): + act_force = d.actuator_force.numpy()[w, 0] + np.testing.assert_allclose( + act_force, + 0.0, + atol=_TOLERANCE, + err_msg=f"nworld={nworld} delay={delay} world={w} step {i}", + ) + + forward.step(m, d) + for w in range(nworld): + act_force = d.actuator_force.numpy()[w, 0] + np.testing.assert_allclose( + act_force, + 10.0, + atol=_TOLERANCE, + err_msg=f"nworld={nworld} delay={delay} world={w} step {nzero}", + ) + + @parameterized.parameters(1, 2) + def test_sensor_delay(self, nworld): + for delay, nsample, expected in ( + (0.02, 3, [0.0, 0.0, 10.0, 20.0]), + (0.04, 5, [0.0, 0.0, 0.0, 0.0, 10.0]), + ): + xml = f""" + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml, nworld=nworld) + + for i in range(len(expected)): + qpos_val = float((i + 1) * 10) + wp.copy(d.qpos, wp.array(np.full((nworld, 1), qpos_val), dtype=float)) + forward.step(m, d) + for w in range(nworld): + sdata = d.sensordata.numpy()[w, 0] + np.testing.assert_allclose( + sdata, + expected[i], + atol=_TOLERANCE, + err_msg=f"nworld={nworld} delay={delay} world={w} step {i}", + ) + + +class MultiActuatorSensorDelayTest(absltest.TestCase): + """Test delay with multiple actuators/sensors with different delays.""" + + def test_multi_actuator_delay(self): + """2 actuators with different delays.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + # set both ctrl to 10 + mjd.ctrl[:] = 10.0 + wp.copy(d.ctrl, wp.array(np.full((1, 2), 10.0), dtype=float)) + + for i in range(5): + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + mj_force = mjd.actuator_force.copy() + warp_force = d.actuator_force.numpy()[0] + np.testing.assert_allclose(warp_force, mj_force, atol=_TOLERANCE, err_msg=f"actuator_force mismatch at step {i}") + + def test_multi_sensor_delay(self): + """2 sensors with different delays.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + for i in range(5): + qpos_val = float((i + 1) * 10) + mjd.qpos[:] = qpos_val + wp.copy(d.qpos, wp.array(np.full((1, mjm.nq), qpos_val), dtype=float)) + + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + mj_sdata = mjd.sensordata.copy() + warp_sdata = d.sensordata.numpy()[0, : mjm.nsensordata] + np.testing.assert_allclose(warp_sdata, mj_sdata, atol=_TOLERANCE, err_msg=f"sensordata mismatch at step {i}") + + +class InterpolationTest(parameterized.TestCase): + """Test linear and cubic interpolation against MuJoCo C reference.""" + + @parameterized.parameters( + ("linear", 1, 0.015, 3), + ("cubic", 2, 0.015, 5), + ) + def test_actuator_interp(self, interp_name, interp_val, delay, nsample): + xml = f""" + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + self.assertEqual(mjm.actuator_history[0, 1], interp_val) + + for i in range(6): + ctrl_val = float((i + 1) * 10) + mjd.ctrl[0] = ctrl_val + wp.copy(d.ctrl, wp.array(np.full((1, 1), ctrl_val), dtype=float)) + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + mj_force = mjd.actuator_force[0] + warp_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose(warp_force, mj_force, atol=_TOLERANCE, err_msg=f"{interp_name} interp mismatch at step {i}") + + @parameterized.parameters( + ("linear", 1, 0.015, 3), + ("cubic", 2, 0.015, 5), + ) + def test_sensor_interp(self, interp_name, interp_val, delay, nsample): + xml = f""" + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + self.assertEqual(mjm.sensor_history[0, 1], interp_val) + + for i in range(6): + qpos_val = float((i + 1) * 10) + mjd.qpos[0] = qpos_val + wp.copy(d.qpos, wp.array(np.full((1, 1), qpos_val), dtype=float)) + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + mj_sdata = mjd.sensordata[0] + warp_sdata = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose(warp_sdata, mj_sdata, atol=_TOLERANCE, err_msg=f"{interp_name} interp mismatch at step {i}") + + def test_read_ctrl_cubic(self): + """read_ctrl with cubic interpolation, compared to MuJoCo C.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + for ctrl_val in [10.0, 20.0, 30.0, 40.0, 50.0]: + mjd.ctrl[0] = ctrl_val + wp.copy(d.ctrl, wp.array(np.full((1, 1), ctrl_val), dtype=float)) + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + # compare read_ctrl with cubic interp at current time + time_arr = d.time + warp_result = wp.empty(d.nworld, dtype=float) + history.read_ctrl(m, d, 0, time_arr, interp=2, result=warp_result) + mj_result = mujoco.mj_readCtrl(mjm, mjd, 0, mjd.time, 2) + np.testing.assert_allclose(warp_result.numpy()[0], mj_result, atol=_TOLERANCE, err_msg="read_ctrl cubic interp mismatch") + + +class SensorFeatureTest(absltest.TestCase): + """Test sensor-specific features: interval and multi-dimensional sensors.""" + + def test_sensor_delay_interval(self): + """Combined delay + interval, matching SensorDelayInterval C test.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + self.assertEqual(mjm.sensor_history[0, 0], 5) + np.testing.assert_allclose(mjm.sensor_delay[0], 0.02, atol=1e-10) + + # set position + mjd.qpos[0] = 5.0 + wp.copy(d.qpos, wp.array(np.full((1, 1), 5.0), dtype=float)) + + for i in range(3): + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + np.testing.assert_allclose(d.sensordata.numpy()[0, 0], mjd.sensordata[0], atol=_TOLERANCE, err_msg=f"step {i}") + + def test_sensor_delay_multi_dim(self): + """3D ballangvel with delay, matching SensorDelayMultiDim C test.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + self.assertEqual(mjm.sensor_dim[0], 3) + + # set angular velocity + mjd.qvel[:] = [1.0, 2.0, 3.0] + wp.copy(d.qvel, wp.array(np.array([[1.0, 2.0, 3.0]]), dtype=float)) + + for _ in range(4): + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + warp_sdata = d.sensordata.numpy()[0, :3] + np.testing.assert_allclose(warp_sdata, mjd.sensordata[:3], atol=_TOLERANCE) + + +class MonotonicityCheckTest(absltest.TestCase): + """Test strict monotonicity check in init functions.""" + + def test_init_ctrl_history_monotonicity(self): + xml = """ + + + """ + _, _, m, d = test_data.fixture(xml=xml) + + values = wp.array(np.full((1, 3), 1.0), dtype=float) + + # non-monotonic and equal times should raise + for bad in ([0.1, 0.05, 0.2], [0.1, 0.1, 0.2]): + with self.assertRaises(ValueError): + history.init_ctrl_history(m, d, 0, wp.array(bad, dtype=float), values) + + # good times should succeed + history.init_ctrl_history(m, d, 0, wp.array([0.1, 0.2, 0.3], dtype=float), values) + + def test_init_sensor_history_monotonicity(self): + xml = """ + + + """ + _, _, m, d = test_data.fixture(xml=xml) + + values = wp.array(np.full((1, 3), 1.0), dtype=float) + phase = wp.array([0.0], dtype=float) + + with self.assertRaises(ValueError): + history.init_sensor_history(m, d, 0, wp.array([0.3, 0.2, 0.1], dtype=float), values, phase=phase) + + +class BufferMechanicsTest(parameterized.TestCase): + """Test buffer mechanics: save/restore, out-of-order insertion.""" + + @parameterized.parameters(1, 2) + def test_state_save_restore_actuator(self, nworld): + """Save/restore history state gives identical actuator delay results.""" + xml = """ + + + """ + _, _, m, d = test_data.fixture(xml=xml, nworld=nworld) + + # step 3 times with ctrl=10 + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), 10.0), dtype=float)) + for _ in range(3): + forward.step(m, d) + + # save state + saved_history = wp.empty_like(d.history) + wp.copy(saved_history, d.history) + saved_time = d.time.numpy().copy() + saved_qpos = d.qpos.numpy().copy() + saved_qvel = d.qvel.numpy().copy() + saved_qacc_warmstart = d.qacc_warmstart.numpy().copy() + + # step 3 more times with ctrl=20, record final force + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), 20.0), dtype=float)) + for _ in range(3): + forward.step(m, d) + force_after = d.actuator_force.numpy().copy() + + # restore saved state + wp.copy(d.history, saved_history) + wp.copy(d.time, wp.array(saved_time, dtype=float)) + wp.copy(d.qpos, wp.array(saved_qpos, dtype=float)) + wp.copy(d.qvel, wp.array(saved_qvel, dtype=float)) + wp.copy(d.qacc_warmstart, wp.array(saved_qacc_warmstart, dtype=float)) + + # step 3 more times with same ctrl=20 + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), 20.0), dtype=float)) + for _ in range(3): + forward.step(m, d) + force_restored = d.actuator_force.numpy().copy() + + np.testing.assert_allclose( + force_after, + force_restored, + atol=_TOLERANCE, + err_msg=f"nworld={nworld}: force mismatch after state restore", + ) + + @parameterized.parameters(1, 2) + def test_state_save_restore_sensor(self, nworld): + """Save/restore history state gives identical sensor delay results.""" + xml = """ + + + """ + _, _, m, d = test_data.fixture(xml=xml, nworld=nworld) + + # step 3 times with ctrl=5 + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), 5.0), dtype=float)) + for _ in range(3): + forward.step(m, d) + + # save state + saved_history = wp.empty_like(d.history) + wp.copy(saved_history, d.history) + saved_time = d.time.numpy().copy() + saved_qpos = d.qpos.numpy().copy() + saved_qvel = d.qvel.numpy().copy() + saved_qacc_warmstart = d.qacc_warmstart.numpy().copy() + + # step 3 more times with ctrl=10 + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), 10.0), dtype=float)) + for _ in range(3): + forward.step(m, d) + sensor_after = d.sensordata.numpy().copy() + + # restore saved state + wp.copy(d.history, saved_history) + wp.copy(d.time, wp.array(saved_time, dtype=float)) + wp.copy(d.qpos, wp.array(saved_qpos, dtype=float)) + wp.copy(d.qvel, wp.array(saved_qvel, dtype=float)) + wp.copy(d.qacc_warmstart, wp.array(saved_qacc_warmstart, dtype=float)) + + # step 3 more times with same ctrl=10 + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), 10.0), dtype=float)) + for _ in range(3): + forward.step(m, d) + sensor_restored = d.sensordata.numpy().copy() + + np.testing.assert_allclose( + sensor_after, + sensor_restored, + atol=_TOLERANCE, + err_msg=f"nworld={nworld}: sensor mismatch after state restore", + ) + + @parameterized.parameters(1, 2) + def test_insert_and_read_at_middle_time(self, nworld): + """Insert values at specific times, then read at a time between them.""" + xml = """ + + + """ + _, _, m, d = test_data.fixture(xml=xml, nworld=nworld) + + times = wp.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=float) + values = wp.array(np.tile([1.0, 2.0, 3.0, 4.0, 5.0], (nworld, 1)), dtype=float) + history.init_ctrl_history(m, d, 0, times, values) + + # read_ctrl reads at (query_time - delay). delay=0.02. + delay = 0.02 + for read_t, expected in [(0.25, 2.5), (0.35, 3.5)]: + query_time = wp.array(np.full(nworld, read_t + delay), dtype=float) + result = wp.empty((nworld,), dtype=float) + history.read_ctrl(m, d, 0, query_time, 1, result) + for w in range(nworld): + np.testing.assert_allclose( + result.numpy()[w], + expected, + atol=_TOLERANCE, + err_msg=f"nworld={nworld}, world={w}: read at t={read_t}", + ) + + def test_init_replace_on_collision(self): + """Initializing with same times replaces values (exact match path).""" + xml = """ + + + """ + _, _, m, d = test_data.fixture(xml=xml) + + times = wp.array([0.1, 0.2, 0.3], dtype=float) + delay = 0.02 + query = wp.array([0.2 + delay], dtype=float) + result = wp.empty((1,), dtype=float) + + # init with values [1, 2, 3], read at t=0.2 + history.init_ctrl_history(m, d, 0, times, wp.array(np.array([[1.0, 2.0, 3.0]]), dtype=float)) + history.read_ctrl(m, d, 0, query, 0, result) + np.testing.assert_allclose(result.numpy()[0], 2.0, atol=_TOLERANCE) + + # re-init with same times but different values [10, 20, 30] + history.init_ctrl_history(m, d, 0, times, wp.array(np.array([[10.0, 20.0, 30.0]]), dtype=float)) + history.read_ctrl(m, d, 0, query, 0, result) + np.testing.assert_allclose(result.numpy()[0], 20.0, atol=_TOLERANCE) + + +class StressWrapTest(parameterized.TestCase): + """Stress test: many steps to exercise circular buffer wrapping.""" + + @parameterized.parameters(1, 2) + def test_long_actuator_delay_wrapping(self, nworld): + """25 steps with nsample=5 to exercise full circular buffer wrapping. + + Compare each step's actuator_force against MuJoCo C reference. + Uses nsample=5 (more than minimum) to exercise wrapping with excess slots. + """ + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml, nworld=nworld) + + nsteps = 25 + for step_i in range(nsteps): + ctrl_val = float((step_i + 1) * 3.0) # varying ctrl + + # C reference + mjd.ctrl[0] = ctrl_val + mujoco.mj_step(mjm, mjd) + + # Warp + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), ctrl_val), dtype=float)) + forward.step(m, d) + + mj_force = mjd.actuator_force[0] + warp_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose( + warp_force, + mj_force, + atol=_TOLERANCE, + err_msg=f"nworld={nworld}, step {step_i}: force {warp_force} vs C {mj_force}", + ) + + @parameterized.parameters(1, 2) + def test_long_sensor_delay_wrapping(self, nworld): + """25 steps with nsample=5 to exercise full sensor circular buffer wrapping.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml, nworld=nworld) + + nsteps = 25 + for step_i in range(nsteps): + ctrl_val = float((step_i + 1) * 2.0) + + mjd.ctrl[0] = ctrl_val + mujoco.mj_step(mjm, mjd) + + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), ctrl_val), dtype=float)) + forward.step(m, d) + + mj_sensor = mjd.sensordata[0] + warp_sensor = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose( + warp_sensor, + mj_sensor, + atol=_TOLERANCE, + err_msg=f"nworld={nworld}, step {step_i}: sensor {warp_sensor} vs C {mj_sensor}", + ) + + +class ActivationDelayTest(parameterized.TestCase): + """Test combining activation dynamics with actuator delay.""" + + @parameterized.parameters(1, 2) + def test_filter_with_delay(self, nworld): + """Actuator with dyntype=filter and delay should match C reference. + + The activation dynamics da/dt = (u-a)/tau are independent of delay, + but the ctrl signal fed to dynamics is the delayed ctrl. + """ + xml = """ + + + """ + mjm = mujoco.MjModel.from_xml_string(xml) + mjd = mujoco.MjData(mjm) + _, _, m, d = test_data.fixture(xml=xml, nworld=nworld) + + nsteps = 15 + for step_i in range(nsteps): + ctrl_val = 1.0 + + mjd.ctrl[0] = ctrl_val + mujoco.mj_step(mjm, mjd) + + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), ctrl_val), dtype=float)) + forward.step(m, d) + + # compare activation + mj_act = mjd.act[0] + warp_act = d.act.numpy()[0, 0] + np.testing.assert_allclose( + warp_act, + mj_act, + atol=_TOLERANCE, + err_msg=f"nworld={nworld}, step {step_i}: act {warp_act} vs C {mj_act}", + ) + + # compare force + mj_force = mjd.actuator_force[0] + warp_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose( + warp_force, + mj_force, + atol=_TOLERANCE, + err_msg=f"nworld={nworld}, step {step_i}: force {warp_force} vs C {mj_force}", + ) + + @parameterized.parameters(1, 2) + def test_integrator_with_delay(self, nworld): + """Actuator with dyntype=integrator and delay should match C reference.""" + xml = """ + + + """ + mjm = mujoco.MjModel.from_xml_string(xml) + mjd = mujoco.MjData(mjm) + _, _, m, d = test_data.fixture(xml=xml, nworld=nworld) + + nsteps = 10 + for step_i in range(nsteps): + ctrl_val = 5.0 + + mjd.ctrl[0] = ctrl_val + mujoco.mj_step(mjm, mjd) + + wp.copy(d.ctrl, wp.array(np.full((nworld, 1), ctrl_val), dtype=float)) + forward.step(m, d) + + mj_act = mjd.act[0] + warp_act = d.act.numpy()[0, 0] + np.testing.assert_allclose( + warp_act, + mj_act, + atol=_TOLERANCE, + err_msg=f"nworld={nworld}, step {step_i}: act {warp_act} vs C {mj_act}", + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index 6ebb7c08c..a7ed46a7d 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -17,6 +17,7 @@ import warp as wp +from mujoco_warp._src import history from mujoco_warp._src import math from mujoco_warp._src import ray from mujoco_warp._src import smooth @@ -917,6 +918,10 @@ def sensor_pos(m: Model, d: Data): ], ) + # apply sensor delay/interval for position sensors + history.apply_sensor_delay(m, d, m.sensor_pos_adr) + history.apply_sensor_delay(m, d, m.sensor_limitpos_adr) + if m.callback.sensor: m.callback.sensor(m, d, Stage.POS) @@ -1459,6 +1464,10 @@ def sensor_vel(m: Model, d: Data): ], ) + # apply sensor delay/interval for velocity sensors + history.apply_sensor_delay(m, d, m.sensor_vel_adr) + history.apply_sensor_delay(m, d, m.sensor_limitvel_adr) + if m.callback.sensor: m.callback.sensor(m, d, Stage.VEL) @@ -2709,6 +2718,10 @@ def sensor_acc(m: Model, d: Data): ], ) + # apply sensor delay/interval for acceleration sensors + history.apply_sensor_delay(m, d, m.sensor_acc_adr) + history.apply_sensor_delay(m, d, m.sensor_limitfrc_adr) + if m.callback.sensor: m.callback.sensor(m, d, Stage.ACC) diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index e16f50388..2b855f392 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -614,6 +614,7 @@ class State(enum.IntEnum): QPOS: position QVEL: velocity ACT: actuator activation + HISTORY: delay/interval history buffers WARMSTART: acceleration used for warmstart CTRL: control QFRC_APPLIED: applied generalized force @@ -622,7 +623,7 @@ class State(enum.IntEnum): MOCAP_POS: positions of mocap bodies MOCAP_QUAT: orientations of mocap bodies NSTATE: number of state elements - PHYSICS: QPOS | QVEL | ACT + PHYSICS: TIME | QPOS | QVEL | ACT | HISTORY FULLPHYSICS: TIME | PHYSICS | PLUGIN USER: CTRL | QFRC_APPLIED | XFRC_APPLIED | EQ_ACTIVE | MOCAP_POS | MOCAP_QUAT | USERDATA INTEGRATION: FULLPHYSICS | USER | WARMSTART @@ -632,6 +633,7 @@ class State(enum.IntEnum): QPOS = mujoco.mjtState.mjSTATE_QPOS QVEL = mujoco.mjtState.mjSTATE_QVEL ACT = mujoco.mjtState.mjSTATE_ACT + HISTORY = mujoco.mjtState.mjSTATE_HISTORY WARMSTART = mujoco.mjtState.mjSTATE_WARMSTART CTRL = mujoco.mjtState.mjSTATE_CTRL QFRC_APPLIED = mujoco.mjtState.mjSTATE_QFRC_APPLIED @@ -640,7 +642,7 @@ class State(enum.IntEnum): MOCAP_POS = mujoco.mjtState.mjSTATE_MOCAP_POS MOCAP_QUAT = mujoco.mjtState.mjSTATE_MOCAP_QUAT NSTATE = mujoco.mjtState.mjNSTATE - PHYSICS = mujoco.mjtState.mjSTATE_PHYSICS + PHYSICS = mujoco.mjtState.mjSTATE_PHYSICS # includes HISTORY FULLPHYSICS = mujoco.mjtState.mjSTATE_FULLPHYSICS USER = mujoco.mjtState.mjSTATE_USER INTEGRATION = mujoco.mjtState.mjSTATE_INTEGRATION @@ -896,6 +898,7 @@ class Model: nJmom: number of non-zeros in actuator_moment ngravcomp: number of bodies with nonzero gravcomp nsensordata: number of elements in sensor data vector + nhistory: number of history buffer entries opt: physics options stat: model statistics qpos0: qpos values at default pose (*, nq) @@ -1127,6 +1130,9 @@ class Model: actuator_trnid: transmission id: joint, tendon, site (nu, 2) actuator_actadr: first activation address; -1: stateless (nu,) actuator_actnum: number of activation variables (nu,) + actuator_history: history buffer sizes (nu, 2) + actuator_historyadr: history buffer address (nu,) + actuator_delay: delay in seconds (nu,) actuator_ctrllimited: is control limited (nu,) actuator_forcelimited: is force limited (nu,) actuator_actlimited: is activation limited (nu,) @@ -1151,6 +1157,10 @@ class Model: sensor_dim: number of scalar outputs (nsensor,) sensor_adr: address in sensor array (nsensor,) sensor_cutoff: cutoff for real and positive; 0: ignore (nsensor,) + sensor_history: history buffer sizes (nsensor, 2) + sensor_historyadr: history buffer address (nsensor,) + sensor_delay: delay in seconds (nsensor,) + sensor_interval: sensor interval and phase (nsensor, 2) plugin: globally registered plugin slot number (nplugin,) plugin_attr: config attributes of geom plugin (nplugin, _NPLUGINATTR) M_rownnz: number of non-zeros in each row of qM (nv,) @@ -1302,6 +1312,7 @@ class Model: nJmom: int ngravcomp: int nsensordata: int + nhistory: int opt: Option stat: Statistic qpos0: array("*", "nq", float) @@ -1533,6 +1544,9 @@ class Model: actuator_trnid: array("nu", wp.vec2i) actuator_actadr: array("nu", int) actuator_actnum: array("nu", int) + actuator_history: array("nu", wp.vec2i) + actuator_historyadr: array("nu", int) + actuator_delay: array("nu", float) actuator_ctrllimited: array("nu", bool) actuator_forcelimited: array("nu", bool) actuator_actlimited: array("nu", bool) @@ -1557,6 +1571,10 @@ class Model: sensor_dim: array("nsensor", int) sensor_adr: array("nsensor", int) sensor_cutoff: array("nsensor", float) + sensor_history: array("nsensor", wp.vec2i) + sensor_historyadr: array("nsensor", int) + sensor_delay: array("nsensor", float) + sensor_interval: array("nsensor", wp.vec2) plugin: array("nplugin", int) plugin_attr: array("nplugin", vec_pluginattr) M_rownnz: array("nv", int) @@ -1764,6 +1782,7 @@ class Data: qpos: position (nworld, nq) qvel: velocity (nworld, nv) act: actuator activation (nworld, na) + history: history buffer for delays (nworld, nhistory) qacc_warmstart: acceleration used for warmstart (nworld, nv) ctrl: control (nworld, nu) qfrc_applied: applied generalized force (nworld, nv) @@ -1862,6 +1881,7 @@ class Data: qpos: array("nworld", "nq", float) qvel: array("nworld", "nv", float) act: array("nworld", "na", float) + history: array("nworld", "nhistory", float) qacc_warmstart: array("nworld", "nv", float) ctrl: array("nworld", "nu", float) qfrc_applied: array("nworld", "nv", float)