Skip to content

Commit 6f235d4

Browse files
authored
restore cache_kernel (#1318)
1 parent fc8b0d6 commit 6f235d4

10 files changed

Lines changed: 57 additions & 0 deletions

File tree

mujoco_warp/_src/collision_convex.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from mujoco_warp._src.types import mat43
4343
from mujoco_warp._src.types import mat63
4444
from mujoco_warp._src.types import vec5
45+
from mujoco_warp._src.warp_util import cache_kernel
4546
from mujoco_warp._src.warp_util import event_scope
4647

4748
# TODO(team): improve compile time to enable backward pass
@@ -153,6 +154,7 @@ def _hfield_filter(
153154
return False, xmin, xmax, ymin, ymax, zmin, zmax
154155

155156

157+
@cache_kernel
156158
def ccd_hfield_kernel_builder(
157159
geomtype1: int,
158160
geomtype2: int,
@@ -695,6 +697,7 @@ def ccd_hfield_kernel(
695697
return ccd_hfield_kernel
696698

697699

700+
@cache_kernel
698701
def ccd_kernel_builder(
699702
geomtype1: int,
700703
geomtype2: int,

mujoco_warp/_src/collision_driver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from mujoco_warp._src.types import Model
3535
from mujoco_warp._src.types import mat23
3636
from mujoco_warp._src.types import mat63
37+
from mujoco_warp._src.warp_util import cache_kernel
3738
from mujoco_warp._src.warp_util import event_scope
3839

3940
wp.set_module_options({"enable_backward": False})
@@ -441,6 +442,7 @@ def _sap_range(
441442
range_out[worldid, geomid] = limit - geomid
442443

443444

445+
@cache_kernel
444446
def _sap_broadphase(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int):
445447
@wp.kernel(module="unique", enable_backward=False)
446448
def kernel(
@@ -641,6 +643,7 @@ def sap_broadphase(m: Model, d: Data, ctx: CollisionContext):
641643
)
642644

643645

646+
@cache_kernel
644647
def _nxn_broadphase(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int):
645648
@wp.kernel(module="unique", enable_backward=False)
646649
def kernel(

mujoco_warp/_src/collision_primitive.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from mujoco_warp._src.types import Model
4343
from mujoco_warp._src.types import mat43
4444
from mujoco_warp._src.types import vec5
45+
from mujoco_warp._src.warp_util import cache_kernel
4546
from mujoco_warp._src.warp_util import event_scope
4647

4748
wp.set_module_options({"enable_backward": False})
@@ -1296,6 +1297,7 @@ def box_box_wrapper(
12961297
}
12971298

12981299

1300+
@cache_kernel
12991301
def _primitive_narrowphase(primitive_collisions_types, primitive_collisions_func):
13001302
@wp.kernel(module="unique", enable_backward=False)
13011303
def primitive_narrowphase(

mujoco_warp/_src/constraint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from mujoco_warp._src.types import DisableBit
2424
from mujoco_warp._src.types import vec5
2525
from mujoco_warp._src.types import vec11
26+
from mujoco_warp._src.warp_util import cache_kernel
2627
from mujoco_warp._src.warp_util import event_scope
2728

2829
wp.set_module_options({"enable_backward": False})
@@ -672,6 +673,7 @@ def _equality_tendon(
672673
)
673674

674675

676+
@cache_kernel
675677
def _equality_flex(is_sparse: bool):
676678
@wp.kernel(module="unique", enable_backward=False)
677679
def kernel(

mujoco_warp/_src/forward.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from mujoco_warp._src.types import TileSet
4343
from mujoco_warp._src.types import TrnType
4444
from mujoco_warp._src.types import vec10f
45+
from mujoco_warp._src.warp_util import cache_kernel
4546
from mujoco_warp._src.warp_util import event_scope
4647

4748
wp.set_module_options({"enable_backward": False})
@@ -358,6 +359,7 @@ def _euler_damp_qfrc_sparse(
358359
qM_integration_out[worldid, 0, adr] += timestep * damp_deriv[worldid, tid]
359360

360361

362+
@cache_kernel
361363
def _tile_euler_dense(tile: TileSet):
362364
@wp.kernel(module="unique", enable_backward=False)
363365
def euler_dense(

mujoco_warp/_src/sensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from mujoco_warp._src.types import vec_pluginattr
4646
from mujoco_warp._src.util_misc import inside_geom
4747
from mujoco_warp._src.util_misc import poly_potential
48+
from mujoco_warp._src.warp_util import cache_kernel
4849
from mujoco_warp._src.warp_util import event_scope
4950

5051
wp.set_module_options({"enable_backward": False})
@@ -2406,6 +2407,7 @@ def _contact_match(
24062407
sensor_contact_direction_out[worldid, contactsensorid, contactmatchid] = dir
24072408

24082409

2410+
@cache_kernel
24092411
def _contact_sort(maxmatch: int):
24102412
@wp.kernel(module="unique", enable_backward=False)
24112413
def contact_sort(
@@ -2901,6 +2903,7 @@ def energy_pos(m: Model, d: Data):
29012903
# TODO(team): flex
29022904

29032905

2906+
@cache_kernel
29042907
def _energy_vel_kinetic(nv: int):
29052908
@wp.kernel(module="unique", enable_backward=False)
29062909
def energy_vel_kinetic(

mujoco_warp/_src/smooth.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from mujoco_warp._src.types import vec5
3636
from mujoco_warp._src.types import vec10
3737
from mujoco_warp._src.types import vec11
38+
from mujoco_warp._src.warp_util import cache_kernel
3839
from mujoco_warp._src.warp_util import event_scope
3940

4041
wp.set_module_options({"enable_backward": False})
@@ -1063,6 +1064,7 @@ def _factor_i_sparse(m: Model, d: Data, M: wp.array3d[float], L: wp.array3d[floa
10631064
wp.launch(_qLDiag_div, dim=(d.nworld, m.nv), inputs=[m.M_rownnz, m.M_rowadr, L], outputs=[D])
10641065

10651066

1067+
@cache_kernel
10661068
def _tile_cholesky_factorize(tile: TileSet):
10671069
"""Returns a kernel for dense Cholesky factorization of a tile."""
10681070

@@ -2694,6 +2696,7 @@ def transmission(m: Model, d: Data):
26942696
)
26952697

26962698

2699+
@cache_kernel
26972700
def _solve_LD_sparse_fused(nv: int, nlevels: int):
26982701
"""Fused sparse backsubstitution: UP + diag + DOWN in one kernel."""
26992702

@@ -2779,6 +2782,7 @@ def _solve_LD_sparse(
27792782
)
27802783

27812784

2785+
@cache_kernel
27822786
def _tile_cholesky_solve(tile: TileSet):
27832787
"""Returns a kernel for dense Cholesky backsubstitution of a tile."""
27842788

@@ -2856,6 +2860,7 @@ def solve_m(m: Model, d: Data, x: wp.array2d[float], y: wp.array2d[float]):
28562860
solve_LD(m, d, d.qLD, d.qLDiagInv, x, y)
28572861

28582862

2863+
@cache_kernel
28592864
def _tile_cholesky_factorize_solve(tile: TileSet):
28602865
"""Returns a kernel for dense Cholesky factorization and backsubstitution of a tile."""
28612866

mujoco_warp/_src/solver.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from mujoco_warp._src import types
2626
from mujoco_warp._src.block_cholesky import create_blocked_cholesky_func
2727
from mujoco_warp._src.block_cholesky import create_blocked_cholesky_solve_func
28+
from mujoco_warp._src.warp_util import cache_kernel
2829
from mujoco_warp._src.warp_util import event_scope
2930
from mujoco_warp._src.warp_util import scoped_mathdx_gemm_disabled
3031

@@ -882,6 +883,7 @@ def _compute_efc_eval_pt_3alphas_elliptic(
882883
# =============================================================================
883884

884885

886+
@cache_kernel
885887
def linesearch_iterative(ls_iterations: int, cone_type: types.ConeType, fuse_jv: bool, is_sparse: bool):
886888
"""Factory for iterative linesearch kernel.
887889
@@ -1408,6 +1410,7 @@ def linesearch_zero_jv(
14081410
ctx_jv_out[worldid, efcid] = 0.0
14091411

14101412

1413+
@cache_kernel
14111414
def linesearch_jv_fused(is_sparse: bool, nv: int, dofs_per_thread: int):
14121415
@wp.kernel(module="unique", enable_backward=False)
14131416
def kernel(
@@ -1468,6 +1471,7 @@ def kernel(
14681471
return kernel
14691472

14701473

1474+
@cache_kernel
14711475
def linesearch_prepare_gauss(nv: int, dofs_per_thread: int):
14721476
@wp.kernel(module="unique", enable_backward=False)
14731477
def kernel(
@@ -1715,6 +1719,7 @@ def solve_init_efc(
17151719
ctx_search_dot_out[worldid] = 0.0
17161720

17171721

1722+
@cache_kernel
17181723
def solve_init_jaref(is_sparse: bool, nv: int, dofs_per_thread: int):
17191724
@wp.kernel(module="unique", enable_backward=False)
17201725
def kernel(
@@ -1797,6 +1802,7 @@ def update_constraint_init_cost(
17971802
ctx_cost_out[worldid] = 0.0
17981803

17991804

1805+
@cache_kernel
18001806
def update_constraint_efc(track_changes: bool):
18011807
TRACK_CHANGES = track_changes
18021808

@@ -2004,6 +2010,7 @@ def update_constraint_init_qfrc_constraint_dense(
20042010
qfrc_constraint_out[worldid, dofid] = sum_qfrc
20052011

20062012

2013+
@cache_kernel
20072014
def update_constraint_gauss_cost(nv: int, dofs_per_thread: int):
20082015
@wp.kernel(module="unique", enable_backward=False)
20092016
def kernel(
@@ -2287,6 +2294,7 @@ def active_check(tid: int, threshold: int) -> float:
22872294
return 1.0
22882295

22892296

2297+
@cache_kernel
22902298
def update_gradient_JTDAJ_sparse_tiled(tile_size: int, njmax: int):
22912299
TILE_SIZE = tile_size
22922300

@@ -2356,6 +2364,7 @@ def kernel(
23562364
return kernel
23572365

23582366

2367+
@cache_kernel
23592368
def update_gradient_JTDAJ_dense_tiled(nv_pad: int, tile_size: int, njmax: int):
23602369
if njmax < tile_size:
23612370
tile_size = njmax
@@ -2719,6 +2728,7 @@ def update_gradient_JTCJ_dense(
27192728
ctx_h_out[worldid, dof1id, dof2id] += h
27202729

27212730

2731+
@cache_kernel
27222732
def update_gradient_cholesky(tile_size: int):
27232733
@wp.kernel(module="unique", enable_backward=False)
27242734
def kernel(
@@ -2744,6 +2754,7 @@ def kernel(
27442754
return kernel
27452755

27462756

2757+
@cache_kernel
27472758
def update_gradient_cholesky_blocked(tile_size: int, matrix_size: int):
27482759
@wp.kernel(module="unique", enable_backward=False)
27492760
def kernel(

mujoco_warp/_src/support.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from mujoco_warp._src.types import State
2828
from mujoco_warp._src.types import vec5
2929
from mujoco_warp._src.types import vec10f
30+
from mujoco_warp._src.warp_util import cache_kernel
3031
from mujoco_warp._src.warp_util import event_scope
3132

3233
wp.set_module_options({"enable_backward": False})
@@ -63,6 +64,7 @@ def next_act(
6364
return act
6465

6566

67+
@cache_kernel
6668
def mul_m_sparse(check_skip: bool):
6769
@wp.kernel(module="unique")
6870
def _mul_m_sparse(
@@ -99,6 +101,7 @@ def _mul_m_sparse(
99101
return _mul_m_sparse
100102

101103

104+
@cache_kernel
102105
def mul_m_dense(nv: int, check_skip: bool):
103106
"""Simple SIMT dense matmul: one thread per output element."""
104107

@@ -429,6 +432,7 @@ def jac_dof(
429432
return jacp, jacr
430433

431434

435+
@cache_kernel
432436
def _make_jac_kernel(has_jacp: bool, has_jacr: bool):
433437
@wp.kernel(module="unique", enable_backward=False)
434438
def _jac(

mujoco_warp/_src/warp_util.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,28 @@ def wrapper(*args, **kwargs):
119119
return wrapper
120120

121121

122+
_KERNEL_CACHE = {}
123+
124+
125+
def cache_kernel(func):
126+
# caching kernels to avoid crashes in graph_conditional code
127+
@functools.wraps(func)
128+
def wrapper(*args):
129+
def _hash_arg(a):
130+
if hasattr(a, "size"):
131+
return a.size
132+
if isinstance(a, list):
133+
return hash(tuple(a))
134+
return hash(a)
135+
136+
key = tuple(_hash_arg(a) for a in args) + (hash(func.__name__),)
137+
if key not in _KERNEL_CACHE:
138+
_KERNEL_CACHE[key] = func(*args)
139+
return _KERNEL_CACHE[key]
140+
141+
return wrapper
142+
143+
122144
def check_toolkit_driver():
123145
wp.init()
124146
if wp.get_device().is_cuda:

0 commit comments

Comments
 (0)