@@ -313,6 +313,16 @@ export const smvp_batch_affine_gpu = async (
313313 // T*h ≈ 524 K threads).
314314 const finalize_collect_shader = shaderManager . gen_batch_affine_finalize_collect_shader ( f_workgroup_size , num_columns ) ;
315315 const finalize_apply_shader = shaderManager . gen_batch_affine_finalize_apply_shader ( f_workgroup_size , num_columns ) ;
316+ // Tree-reduce path uses a merged finalize that writes Jacobian
317+ // directly to bucket_x/y/z for cases 1-3 (no per-bucket delta) and
318+ // only compacts case-4 deltas (~0.4% of T*h slots at logN=16) into
319+ // a per-subtask slice for a much smaller batch inverse.
320+ const tree_finalize_shader = use_tree_reduce
321+ ? shaderManager . gen_batch_affine_tree_finalize_shader ( f_workgroup_size , num_columns )
322+ : '' ;
323+ const tree_finalize_apply_shader = use_tree_reduce
324+ ? shaderManager . gen_batch_affine_tree_finalize_apply_shader ( f_workgroup_size , num_columns )
325+ : '' ;
316326
317327 const _compile_t0 = performance . now ( ) ;
318328
@@ -432,6 +442,48 @@ export const smvp_batch_affine_gpu = async (
432442 `bn254:batch_affine_finalize_apply:v1:${ num_columns } :${ input_size } :${ f_workgroup_size } ` ,
433443 ) ;
434444
445+ const tree_finalize_pipe = use_tree_reduce
446+ ? await compile_pipeline_for (
447+ device ,
448+ [
449+ 'read-only-storage' , // 0 running_x
450+ 'read-only-storage' , // 1 running_y
451+ 'read-only-storage' , // 2 bucket_active
452+ 'storage' , // 3 bucket_x
453+ 'storage' , // 4 bucket_y
454+ 'storage' , // 5 bucket_z
455+ 'storage' , // 6 case4_delta
456+ 'storage' , // 7 case4_back_id
457+ 'storage' , // 8 case4_count (atomic)
458+ 'uniform' , // 9 params
459+ ] ,
460+ tree_finalize_shader ,
461+ context ,
462+ `bn254:batch_affine_tree_finalize:v1:${ num_columns } :${ input_size } :${ f_workgroup_size } ` ,
463+ )
464+ : undefined ;
465+
466+ const tree_finalize_apply_pipe = use_tree_reduce
467+ ? await compile_pipeline_for (
468+ device ,
469+ [
470+ 'read-only-storage' , // 0 running_x
471+ 'read-only-storage' , // 1 running_y
472+ 'read-only-storage' , // 2 bucket_active
473+ 'read-only-storage' , // 3 case4_back_id
474+ 'read-only-storage' , // 4 case4_inv
475+ 'read-only-storage' , // 5 case4_count
476+ 'storage' , // 6 bucket_x
477+ 'storage' , // 7 bucket_y
478+ 'storage' , // 8 bucket_z
479+ 'uniform' , // 9 params
480+ ] ,
481+ tree_finalize_apply_shader ,
482+ context ,
483+ `bn254:batch_affine_tree_finalize_apply:v1:${ num_columns } :${ input_size } :${ f_workgroup_size } ` ,
484+ )
485+ : undefined ;
486+
435487 cpu_timer ?. accumulate ( 'compile_smvp_batch_affine' , performance . now ( ) - _compile_t0 ) ;
436488
437489 // ----- Uniforms -----
@@ -609,6 +661,62 @@ export const smvp_batch_affine_gpu = async (
609661 finalize_ub ,
610662 ] ) ;
611663
664+ // Case-4 compaction buffers + bind groups for the tree-reduce merged
665+ // finalize. Sized for the worst case (every slot is case 4); in practice
666+ // fewer than 1% of slots are case 4 at logN=16, so most of the
667+ // allocation is unused but cheap to hold.
668+ let case4_delta_sb : GPUBuffer | undefined ;
669+ let case4_back_id_sb : GPUBuffer | undefined ;
670+ let case4_count_sb : GPUBuffer | undefined ;
671+ let tree_finalize_bg : GPUBindGroup | undefined ;
672+ let tree_finalize_inverse_bg : GPUBindGroup | undefined ;
673+ let tree_finalize_apply_bg : GPUBindGroup | undefined ;
674+ if ( use_tree_reduce ) {
675+ const case4_pool_capacity = num_subtasks * half_num_columns ;
676+ case4_delta_sb = acquire_ws ( 'case4_delta' , case4_pool_capacity * limb_byte_length ) ;
677+ case4_back_id_sb = acquire_ws ( 'case4_back_id' , case4_pool_capacity * 4 ) ;
678+ case4_count_sb = acquire_ws ( 'case4_count' , num_subtasks * 4 ) ;
679+
680+ tree_finalize_bg = acquire_bg ( 'tree_finalize_bg' , tree_finalize_pipe ! . layout , [
681+ running_x_sb ,
682+ running_y_sb ,
683+ bucket_active_sb ,
684+ bucket_sum_x_sb ,
685+ bucket_sum_y_sb ,
686+ bucket_sum_z_sb ,
687+ case4_delta_sb ,
688+ case4_back_id_sb ,
689+ case4_count_sb ,
690+ finalize_ub ,
691+ ] ) ;
692+
693+ // Reuse the per-subtask batch_inverse pipeline. count_buf =
694+ // case4_count_sb (per-subtask atomic counter populated by the
695+ // merged finalize); inputs = case4_delta_sb, outputs = pair_inv_sb
696+ // (reused as case4_inv — the per-round pool is unused during
697+ // finalize). Pitch is the same per-subtask stride as case4_*.
698+ tree_finalize_inverse_bg = acquire_bg ( 'tree_finalize_inverse_bg' , inverse_pipe . layout , [
699+ case4_delta_sb ,
700+ pair_prefix_sb ,
701+ pair_inv_sb ,
702+ case4_count_sb ,
703+ inverse_finalize_ub ,
704+ ] ) ;
705+
706+ tree_finalize_apply_bg = acquire_bg ( 'tree_finalize_apply_bg' , tree_finalize_apply_pipe ! . layout , [
707+ running_x_sb ,
708+ running_y_sb ,
709+ bucket_active_sb ,
710+ case4_back_id_sb ,
711+ pair_inv_sb ,
712+ case4_count_sb ,
713+ bucket_sum_x_sb ,
714+ bucket_sum_y_sb ,
715+ bucket_sum_z_sb ,
716+ finalize_ub ,
717+ ] ) ;
718+ }
719+
612720 // ----- Dispatch sequence -----
613721
614722 // 1. Init: ceil(total_buckets / 256) workgroups in x, 1 thread per bucket.
@@ -951,6 +1059,44 @@ export const smvp_batch_affine_gpu = async (
9511059 }
9521060 } // end if (!use_tree_reduce)
9531061
1062+ if ( use_tree_reduce ) {
1063+ // 3'. Merged finalize. Zero case4_count, dispatch the merged
1064+ // collect (writes cases 1-3 directly to bucket_x/y/z, compacts
1065+ // case 4 into case4_delta + case4_back_id), then run the per-subtask
1066+ // batch_inverse over the compacted case-4 slice, then dispatch the
1067+ // compacted apply (which returns early past case4_count[subtask_idx]).
1068+ commandEncoder . clearBuffer ( case4_count_sb ! , 0 , num_subtasks * 4 ) ;
1069+
1070+ await execute_pipeline (
1071+ commandEncoder ,
1072+ tree_finalize_pipe ! . pipeline ,
1073+ tree_finalize_bg ! ,
1074+ f_num_x_workgroups ,
1075+ f_num_y_workgroups ,
1076+ f_num_z_workgroups ,
1077+ profiler ?. stage ( 'ba_tree_finalize' ) ,
1078+ ) ;
1079+
1080+ await execute_pipeline (
1081+ commandEncoder ,
1082+ inverse_pipe . pipeline ,
1083+ tree_finalize_inverse_bg ! ,
1084+ NUM_SUB_WGS_PER_SUBTASK ,
1085+ 1 ,
1086+ num_subtasks ,
1087+ profiler ?. stage ( 'ba_tree_finalize_inverse' ) ,
1088+ ) ;
1089+
1090+ await execute_pipeline (
1091+ commandEncoder ,
1092+ tree_finalize_apply_pipe ! . pipeline ,
1093+ tree_finalize_apply_bg ! ,
1094+ f_num_x_workgroups ,
1095+ f_num_y_workgroups ,
1096+ f_num_z_workgroups ,
1097+ profiler ?. stage ( 'ba_tree_finalize_apply' ) ,
1098+ ) ;
1099+ } else {
9541100 // 3. Finalize — three single dispatches: collect → batch_inverse → apply.
9551101 //
9561102 // Pass A (collect): single dispatch over T·h threads. Each thread
@@ -993,4 +1139,5 @@ export const smvp_batch_affine_gpu = async (
9931139 f_num_z_workgroups ,
9941140 profiler ?. stage ( 'ba_finalize_apply' ) ,
9951141 ) ;
1142+ }
9961143} ;
0 commit comments