@@ -2693,43 +2693,65 @@ def transmission(m: Model, d: Data):
26932693 )
26942694
26952695
2696- @wp .kernel
2697- def _solve_LD_sparse_x_acc_up (
2698- # In:
2699- L : wp .array3d (dtype = float ),
2700- qLD_updates_ : wp .array (dtype = wp .vec3i ),
2701- # Out:
2702- x : wp .array2d (dtype = float ),
2703- ):
2704- worldid , nodeid = wp .tid ()
2705- update = qLD_updates_ [nodeid ]
2706- i , k , Madr_ki = update [0 ], update [1 ], update [2 ]
2707- wp .atomic_sub (x [worldid ], i , L [worldid , 0 , Madr_ki ] * x [worldid , k ])
2708-
2709-
2710- @wp .kernel
2711- def _solve_LD_sparse_qLDiag_mul (
2712- # In:
2713- D : wp .array2d (dtype = float ),
2714- # Out:
2715- out : wp .array2d (dtype = float ),
2716- ):
2717- worldid , dofid = wp .tid ()
2718- out [worldid , dofid ] *= D [worldid , dofid ]
2696+ @cache_kernel
2697+ def _solve_LD_sparse_fused (nv : int , nlevels : int ):
2698+ """Fused sparse backsubstitution: UP + diag + DOWN in one kernel."""
27192699
2700+ @wp .func_native (snippet = "WP_TILE_SYNC();" )
2701+ def _syncthreads ():
2702+ pass
27202703
2721- @wp .kernel
2722- def _solve_LD_sparse_x_acc_down (
2723- # In:
2724- L : wp .array3d (dtype = float ),
2725- qLD_updates_ : wp .array (dtype = wp .vec3i ),
2726- # Out:
2727- x : wp .array2d (dtype = float ),
2728- ):
2729- worldid , nodeid = wp .tid ()
2730- update = qLD_updates_ [nodeid ]
2731- i , k , Madr_ki = update [0 ], update [1 ], update [2 ]
2732- wp .atomic_sub (x [worldid ], k , L [worldid , 0 , Madr_ki ] * x [worldid , i ])
2704+ @wp .kernel (module = "unique" , enable_backward = False )
2705+ def kernel (
2706+ # In:
2707+ L : wp .array3d (dtype = float ),
2708+ D : wp .array2d (dtype = float ),
2709+ all_updates : wp .array (dtype = wp .vec3i ),
2710+ level_offsets : wp .array (dtype = int ),
2711+ y : wp .array2d (dtype = float ),
2712+ # Out:
2713+ x_out : wp .array2d (dtype = float ),
2714+ ):
2715+ worldid , tid = wp .tid ()
2716+ NV = wp .static (nv )
2717+ NLEVELS = wp .static (nlevels )
2718+ BLOCK_DIM = wp .block_dim ()
2719+
2720+ # Copy y to x_out
2721+ for dofid in range (tid , NV , BLOCK_DIM ):
2722+ x_out [worldid , dofid ] = y [worldid , dofid ]
2723+ _syncthreads ()
2724+
2725+ # Forward substitution
2726+ for level in range (NLEVELS ):
2727+ level_idx = NLEVELS - 1 - level
2728+ level_offset = level_offsets [level_idx ]
2729+ level_size = level_offsets [level_idx + 1 ] - level_offset
2730+
2731+ for u in range (tid , level_size , BLOCK_DIM ):
2732+ update = all_updates [level_offset + u ]
2733+ i , k , Madr_ki = update [0 ], update [1 ], update [2 ]
2734+ wp .atomic_sub (x_out [worldid ], i , L [worldid , 0 , Madr_ki ] * x_out [worldid , k ])
2735+ _syncthreads ()
2736+
2737+ # Diagonal multiply
2738+ for dofid in range (tid , NV , BLOCK_DIM ):
2739+ x_out [worldid , dofid ] *= D [worldid , dofid ]
2740+ _syncthreads ()
2741+
2742+ # Backward substitution
2743+ for level in range (NLEVELS ):
2744+ level_idx = level
2745+ level_offset = level_offsets [level_idx ]
2746+ level_size = level_offsets [level_idx + 1 ] - level_offset
2747+
2748+ for u in range (tid , level_size , BLOCK_DIM ):
2749+ update = all_updates [level_offset + u ]
2750+ i , k , Madr_ki = update [0 ], update [1 ], update [2 ]
2751+ wp .atomic_sub (x_out [worldid ], k , L [worldid , 0 , Madr_ki ] * x_out [worldid , i ])
2752+ _syncthreads ()
2753+
2754+ return kernel
27332755
27342756
27352757def _solve_LD_sparse (
@@ -2741,14 +2763,20 @@ def _solve_LD_sparse(
27412763 y : wp .array2d (dtype = float ),
27422764):
27432765 """Computes sparse backsubstitution: x = inv(L'*D*L)*y."""
2744- wp .copy (x , y )
2745- for qLD_updates in reversed (m .qLD_updates ):
2746- wp .launch (_solve_LD_sparse_x_acc_up , dim = (d .nworld , qLD_updates .size ), inputs = [L , qLD_updates ], outputs = [x ])
2747-
2748- wp .launch (_solve_LD_sparse_qLDiag_mul , dim = (d .nworld , m .nv ), inputs = [D ], outputs = [x ])
2766+ nlevels = len (m .qLD_updates )
2767+ if wp .get_device ().is_cuda :
2768+ dim_block = m .block_dim .solve_LD_sparse_fused
2769+ else :
2770+ # Fallback for CPU
2771+ dim_block = 1
27492772
2750- for qLD_updates in m .qLD_updates :
2751- wp .launch (_solve_LD_sparse_x_acc_down , dim = (d .nworld , qLD_updates .size ), inputs = [L , qLD_updates ], outputs = [x ])
2773+ wp .launch (
2774+ _solve_LD_sparse_fused (m .nv , nlevels ),
2775+ dim = (d .nworld , dim_block ),
2776+ inputs = [L , D , m .qLD_all_updates , m .qLD_level_offsets , y ],
2777+ outputs = [x ],
2778+ block_dim = dim_block ,
2779+ )
27522780
27532781
27542782@cache_kernel
0 commit comments