2525from mujoco_warp ._src import types
2626from mujoco_warp ._src .block_cholesky import create_blocked_cholesky_func
2727from mujoco_warp ._src .block_cholesky import create_blocked_cholesky_solve_func
28+ from mujoco_warp ._src .warp_util import cache_kernel
2829from mujoco_warp ._src .warp_util import event_scope
2930from mujoco_warp ._src .warp_util import scoped_mathdx_gemm_disabled
3031
@@ -882,6 +883,7 @@ def _compute_efc_eval_pt_3alphas_elliptic(
882883# =============================================================================
883884
884885
886+ @cache_kernel
885887def linesearch_iterative (ls_iterations : int , cone_type : types .ConeType , fuse_jv : bool , is_sparse : bool ):
886888 """Factory for iterative linesearch kernel.
887889
@@ -1408,6 +1410,7 @@ def linesearch_zero_jv(
14081410 ctx_jv_out [worldid , efcid ] = 0.0
14091411
14101412
1413+ @cache_kernel
14111414def linesearch_jv_fused (is_sparse : bool , nv : int , dofs_per_thread : int ):
14121415 @wp .kernel (module = "unique" , enable_backward = False )
14131416 def kernel (
@@ -1468,6 +1471,7 @@ def kernel(
14681471 return kernel
14691472
14701473
1474+ @cache_kernel
14711475def linesearch_prepare_gauss (nv : int , dofs_per_thread : int ):
14721476 @wp .kernel (module = "unique" , enable_backward = False )
14731477 def kernel (
@@ -1715,6 +1719,7 @@ def solve_init_efc(
17151719 ctx_search_dot_out [worldid ] = 0.0
17161720
17171721
1722+ @cache_kernel
17181723def solve_init_jaref (is_sparse : bool , nv : int , dofs_per_thread : int ):
17191724 @wp .kernel (module = "unique" , enable_backward = False )
17201725 def kernel (
@@ -1797,6 +1802,7 @@ def update_constraint_init_cost(
17971802 ctx_cost_out [worldid ] = 0.0
17981803
17991804
1805+ @cache_kernel
18001806def update_constraint_efc (track_changes : bool ):
18011807 TRACK_CHANGES = track_changes
18021808
@@ -2004,6 +2010,7 @@ def update_constraint_init_qfrc_constraint_dense(
20042010 qfrc_constraint_out [worldid , dofid ] = sum_qfrc
20052011
20062012
2013+ @cache_kernel
20072014def update_constraint_gauss_cost (nv : int , dofs_per_thread : int ):
20082015 @wp .kernel (module = "unique" , enable_backward = False )
20092016 def kernel (
@@ -2287,6 +2294,7 @@ def active_check(tid: int, threshold: int) -> float:
22872294 return 1.0
22882295
22892296
2297+ @cache_kernel
22902298def update_gradient_JTDAJ_sparse_tiled (tile_size : int , njmax : int ):
22912299 TILE_SIZE = tile_size
22922300
@@ -2356,6 +2364,7 @@ def kernel(
23562364 return kernel
23572365
23582366
2367+ @cache_kernel
23592368def update_gradient_JTDAJ_dense_tiled (nv_pad : int , tile_size : int , njmax : int ):
23602369 if njmax < tile_size :
23612370 tile_size = njmax
@@ -2719,6 +2728,7 @@ def update_gradient_JTCJ_dense(
27192728 ctx_h_out [worldid , dof1id , dof2id ] += h
27202729
27212730
2731+ @cache_kernel
27222732def update_gradient_cholesky (tile_size : int ):
27232733 @wp .kernel (module = "unique" , enable_backward = False )
27242734 def kernel (
@@ -2744,6 +2754,7 @@ def kernel(
27442754 return kernel
27452755
27462756
2757+ @cache_kernel
27472758def update_gradient_cholesky_blocked (tile_size : int , matrix_size : int ):
27482759 @wp .kernel (module = "unique" , enable_backward = False )
27492760 def kernel (
0 commit comments