Skip to content

Commit 0962084

Browse files
Optimization: solve_LD_sparse (google-deepmind#1260)
* Fused solve_LD_sparse kernels * Fix bug in new fused kernel * Fix formatting * Fix new kernel on CPU * Change to use wp.block_dim
1 parent b6dd9df commit 0962084

3 files changed

Lines changed: 84 additions & 42 deletions

File tree

mujoco_warp/_src/io.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,15 @@ def _check_margin(name, t1, t2, margin):
583583
Madr_ki -= 1
584584
m.qLD_updates = tuple(wp.array(qLD_updates[i], dtype=wp.vec3i) for i in sorted(qLD_updates))
585585

586+
# Build concatenated updates for fused kernel
587+
all_updates_flat = []
588+
level_offsets = [0]
589+
for level in sorted(qLD_updates):
590+
all_updates_flat.extend(qLD_updates[level])
591+
level_offsets.append(len(all_updates_flat))
592+
m.qLD_all_updates = all_updates_flat if all_updates_flat else [(0, 0, 0)]
593+
m.qLD_level_offsets = level_offsets
594+
586595
# indices for sparse qM_fullm (used in solver)
587596
m.qM_fullm_i, m.qM_fullm_j = [], []
588597
for i in range(mjm.nv):

mujoco_warp/_src/smooth.py

Lines changed: 70 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2693,43 +2693,65 @@ def transmission(m: Model, d: Data):
26932693
)
26942694

26952695

2696-
@wp.kernel
2697-
def _solve_LD_sparse_x_acc_up(
2698-
# In:
2699-
L: wp.array3d(dtype=float),
2700-
qLD_updates_: wp.array(dtype=wp.vec3i),
2701-
# Out:
2702-
x: wp.array2d(dtype=float),
2703-
):
2704-
worldid, nodeid = wp.tid()
2705-
update = qLD_updates_[nodeid]
2706-
i, k, Madr_ki = update[0], update[1], update[2]
2707-
wp.atomic_sub(x[worldid], i, L[worldid, 0, Madr_ki] * x[worldid, k])
2708-
2709-
2710-
@wp.kernel
2711-
def _solve_LD_sparse_qLDiag_mul(
2712-
# In:
2713-
D: wp.array2d(dtype=float),
2714-
# Out:
2715-
out: wp.array2d(dtype=float),
2716-
):
2717-
worldid, dofid = wp.tid()
2718-
out[worldid, dofid] *= D[worldid, dofid]
2696+
@cache_kernel
2697+
def _solve_LD_sparse_fused(nv: int, nlevels: int):
2698+
"""Fused sparse backsubstitution: UP + diag + DOWN in one kernel."""
27192699

2700+
@wp.func_native(snippet="WP_TILE_SYNC();")
2701+
def _syncthreads():
2702+
pass
27202703

2721-
@wp.kernel
2722-
def _solve_LD_sparse_x_acc_down(
2723-
# In:
2724-
L: wp.array3d(dtype=float),
2725-
qLD_updates_: wp.array(dtype=wp.vec3i),
2726-
# Out:
2727-
x: wp.array2d(dtype=float),
2728-
):
2729-
worldid, nodeid = wp.tid()
2730-
update = qLD_updates_[nodeid]
2731-
i, k, Madr_ki = update[0], update[1], update[2]
2732-
wp.atomic_sub(x[worldid], k, L[worldid, 0, Madr_ki] * x[worldid, i])
2704+
@wp.kernel(module="unique", enable_backward=False)
2705+
def kernel(
2706+
# In:
2707+
L: wp.array3d(dtype=float),
2708+
D: wp.array2d(dtype=float),
2709+
all_updates: wp.array(dtype=wp.vec3i),
2710+
level_offsets: wp.array(dtype=int),
2711+
y: wp.array2d(dtype=float),
2712+
# Out:
2713+
x_out: wp.array2d(dtype=float),
2714+
):
2715+
worldid, tid = wp.tid()
2716+
NV = wp.static(nv)
2717+
NLEVELS = wp.static(nlevels)
2718+
BLOCK_DIM = wp.block_dim()
2719+
2720+
# Copy y to x_out
2721+
for dofid in range(tid, NV, BLOCK_DIM):
2722+
x_out[worldid, dofid] = y[worldid, dofid]
2723+
_syncthreads()
2724+
2725+
# Forward substitution
2726+
for level in range(NLEVELS):
2727+
level_idx = NLEVELS - 1 - level
2728+
level_offset = level_offsets[level_idx]
2729+
level_size = level_offsets[level_idx + 1] - level_offset
2730+
2731+
for u in range(tid, level_size, BLOCK_DIM):
2732+
update = all_updates[level_offset + u]
2733+
i, k, Madr_ki = update[0], update[1], update[2]
2734+
wp.atomic_sub(x_out[worldid], i, L[worldid, 0, Madr_ki] * x_out[worldid, k])
2735+
_syncthreads()
2736+
2737+
# Diagonal multiply
2738+
for dofid in range(tid, NV, BLOCK_DIM):
2739+
x_out[worldid, dofid] *= D[worldid, dofid]
2740+
_syncthreads()
2741+
2742+
# Backward substitution
2743+
for level in range(NLEVELS):
2744+
level_idx = level
2745+
level_offset = level_offsets[level_idx]
2746+
level_size = level_offsets[level_idx + 1] - level_offset
2747+
2748+
for u in range(tid, level_size, BLOCK_DIM):
2749+
update = all_updates[level_offset + u]
2750+
i, k, Madr_ki = update[0], update[1], update[2]
2751+
wp.atomic_sub(x_out[worldid], k, L[worldid, 0, Madr_ki] * x_out[worldid, i])
2752+
_syncthreads()
2753+
2754+
return kernel
27332755

27342756

27352757
def _solve_LD_sparse(
@@ -2741,14 +2763,20 @@ def _solve_LD_sparse(
27412763
y: wp.array2d(dtype=float),
27422764
):
27432765
"""Computes sparse backsubstitution: x = inv(L'*D*L)*y."""
2744-
wp.copy(x, y)
2745-
for qLD_updates in reversed(m.qLD_updates):
2746-
wp.launch(_solve_LD_sparse_x_acc_up, dim=(d.nworld, qLD_updates.size), inputs=[L, qLD_updates], outputs=[x])
2747-
2748-
wp.launch(_solve_LD_sparse_qLDiag_mul, dim=(d.nworld, m.nv), inputs=[D], outputs=[x])
2766+
nlevels = len(m.qLD_updates)
2767+
if wp.get_device().is_cuda:
2768+
dim_block = m.block_dim.solve_LD_sparse_fused
2769+
else:
2770+
# Fallback for CPU
2771+
dim_block = 1
27492772

2750-
for qLD_updates in m.qLD_updates:
2751-
wp.launch(_solve_LD_sparse_x_acc_down, dim=(d.nworld, qLD_updates.size), inputs=[L, qLD_updates], outputs=[x])
2773+
wp.launch(
2774+
_solve_LD_sparse_fused(m.nv, nlevels),
2775+
dim=(d.nworld, dim_block),
2776+
inputs=[L, D, m.qLD_all_updates, m.qLD_level_offsets, y],
2777+
outputs=[x],
2778+
block_dim=dim_block,
2779+
)
27522780

27532781

27542782
@cache_kernel

mujoco_warp/_src/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class BlockDim:
5959
cholesky_factorize: int = 32
6060
cholesky_solve: int = 32
6161
cholesky_factorize_solve: int = 32
62+
solve_LD_sparse_fused: int = 64
6263
# solver
6364
update_gradient_cholesky: int = 64
6465
update_gradient_cholesky_blocked: int = 32
@@ -1204,6 +1205,8 @@ class Model:
12041205
taxel_sensorid: address for tactile sensors
12051206
qM_tiles: tiling configuration
12061207
qLD_updates: tuple of index triples for sparse factorization
1208+
qLD_all_updates: tuple of all levels concatenated
1209+
qLD_level_offsets: tuple of start offsets for each level
12071210
qM_fullm_i: sparse mass matrix addressing
12081211
qM_fullm_j: sparse mass matrix addressing
12091212
qM_mulm_rowadr: sparse matmul row pointers
@@ -1580,6 +1583,8 @@ class Model:
15801583
taxel_sensorid: wp.array(dtype=int)
15811584
qM_tiles: tuple[TileSet, ...]
15821585
qLD_updates: tuple[wp.array(dtype=wp.vec3i), ...]
1586+
qLD_all_updates: wp.array(dtype=wp.vec3i)
1587+
qLD_level_offsets: wp.array(dtype=int)
15831588
qM_fullm_i: wp.array(dtype=int)
15841589
qM_fullm_j: wp.array(dtype=int)
15851590
# Gather-based sparse mul_m indices (thread per DOF, no atomics)

0 commit comments

Comments
 (0)