Skip to content

Commit 60f843e

Browse files
committed
perf(bb): iter 2 — merged tree finalize for cases 1-3
On the tree-reduce path, replaces ba_finalize_collect + ba_finalize_inverse + ba_finalize_apply with a merged finalize kernel that handles cases 1-3 inline (writing Jacobian directly to bucket_sum_x/y/z) and compacts case-4 deltas (~0.4% of slots at logN=16), followed by a much smaller batch inverse over only the case-4 entries and a compacted apply pass. Today's finalize_inverse runs the Montgomery prefix product over all T*h ~= 524K slots even though 99.6% are placeholder writes. The merged kernel shrinks the per-subtask inverse chain by ~270x.
1 parent 8e15a17 commit 60f843e

5 files changed

Lines changed: 856 additions & 1 deletion

File tree

barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,16 @@ export const smvp_batch_affine_gpu = async (
313313
// T*h ≈ 524 K threads).
314314
const finalize_collect_shader = shaderManager.gen_batch_affine_finalize_collect_shader(f_workgroup_size, num_columns);
315315
const finalize_apply_shader = shaderManager.gen_batch_affine_finalize_apply_shader(f_workgroup_size, num_columns);
316+
// Tree-reduce path uses a merged finalize that writes Jacobian
317+
// directly to bucket_x/y/z for cases 1-3 (no per-bucket delta) and
318+
// only compacts case-4 deltas (~0.4% of T*h slots at logN=16) into
319+
// a per-subtask slice for a much smaller batch inverse.
320+
const tree_finalize_shader = use_tree_reduce
321+
? shaderManager.gen_batch_affine_tree_finalize_shader(f_workgroup_size, num_columns)
322+
: '';
323+
const tree_finalize_apply_shader = use_tree_reduce
324+
? shaderManager.gen_batch_affine_tree_finalize_apply_shader(f_workgroup_size, num_columns)
325+
: '';
316326

317327
const _compile_t0 = performance.now();
318328

@@ -432,6 +442,48 @@ export const smvp_batch_affine_gpu = async (
432442
`bn254:batch_affine_finalize_apply:v1:${num_columns}:${input_size}:${f_workgroup_size}`,
433443
);
434444

445+
const tree_finalize_pipe = use_tree_reduce
446+
? await compile_pipeline_for(
447+
device,
448+
[
449+
'read-only-storage', // 0 running_x
450+
'read-only-storage', // 1 running_y
451+
'read-only-storage', // 2 bucket_active
452+
'storage', // 3 bucket_x
453+
'storage', // 4 bucket_y
454+
'storage', // 5 bucket_z
455+
'storage', // 6 case4_delta
456+
'storage', // 7 case4_back_id
457+
'storage', // 8 case4_count (atomic)
458+
'uniform', // 9 params
459+
],
460+
tree_finalize_shader,
461+
context,
462+
`bn254:batch_affine_tree_finalize:v1:${num_columns}:${input_size}:${f_workgroup_size}`,
463+
)
464+
: undefined;
465+
466+
const tree_finalize_apply_pipe = use_tree_reduce
467+
? await compile_pipeline_for(
468+
device,
469+
[
470+
'read-only-storage', // 0 running_x
471+
'read-only-storage', // 1 running_y
472+
'read-only-storage', // 2 bucket_active
473+
'read-only-storage', // 3 case4_back_id
474+
'read-only-storage', // 4 case4_inv
475+
'read-only-storage', // 5 case4_count
476+
'storage', // 6 bucket_x
477+
'storage', // 7 bucket_y
478+
'storage', // 8 bucket_z
479+
'uniform', // 9 params
480+
],
481+
tree_finalize_apply_shader,
482+
context,
483+
`bn254:batch_affine_tree_finalize_apply:v1:${num_columns}:${input_size}:${f_workgroup_size}`,
484+
)
485+
: undefined;
486+
435487
cpu_timer?.accumulate('compile_smvp_batch_affine', performance.now() - _compile_t0);
436488

437489
// ----- Uniforms -----
@@ -609,6 +661,62 @@ export const smvp_batch_affine_gpu = async (
609661
finalize_ub,
610662
]);
611663

664+
// Case-4 compaction buffers + bind groups for the tree-reduce merged
665+
// finalize. Sized for worst-case (every slot is case 4); in practice
666+
// typically <1% of slots are case 4 at logN=16, so most of the
667+
// allocation is unused but cheap to hold.
668+
let case4_delta_sb: GPUBuffer | undefined;
669+
let case4_back_id_sb: GPUBuffer | undefined;
670+
let case4_count_sb: GPUBuffer | undefined;
671+
let tree_finalize_bg: GPUBindGroup | undefined;
672+
let tree_finalize_inverse_bg: GPUBindGroup | undefined;
673+
let tree_finalize_apply_bg: GPUBindGroup | undefined;
674+
if (use_tree_reduce) {
675+
const case4_pool_capacity = num_subtasks * half_num_columns;
676+
case4_delta_sb = acquire_ws('case4_delta', case4_pool_capacity * limb_byte_length);
677+
case4_back_id_sb = acquire_ws('case4_back_id', case4_pool_capacity * 4);
678+
case4_count_sb = acquire_ws('case4_count', num_subtasks * 4);
679+
680+
tree_finalize_bg = acquire_bg('tree_finalize_bg', tree_finalize_pipe!.layout, [
681+
running_x_sb,
682+
running_y_sb,
683+
bucket_active_sb,
684+
bucket_sum_x_sb,
685+
bucket_sum_y_sb,
686+
bucket_sum_z_sb,
687+
case4_delta_sb,
688+
case4_back_id_sb,
689+
case4_count_sb,
690+
finalize_ub,
691+
]);
692+
693+
// Reuse the per-subtask batch_inverse pipeline. count_buf =
694+
// case4_count_sb (per-subtask atomic counter populated by the
695+
// merged finalize); inputs = case4_delta_sb, outputs = pair_inv_sb
696+
// (reused as case4_inv — the per-round pool is unused during
697+
// finalize). Pitch is the same per-subtask stride as case4_*.
698+
tree_finalize_inverse_bg = acquire_bg('tree_finalize_inverse_bg', inverse_pipe.layout, [
699+
case4_delta_sb,
700+
pair_prefix_sb,
701+
pair_inv_sb,
702+
case4_count_sb,
703+
inverse_finalize_ub,
704+
]);
705+
706+
tree_finalize_apply_bg = acquire_bg('tree_finalize_apply_bg', tree_finalize_apply_pipe!.layout, [
707+
running_x_sb,
708+
running_y_sb,
709+
bucket_active_sb,
710+
case4_back_id_sb,
711+
pair_inv_sb,
712+
case4_count_sb,
713+
bucket_sum_x_sb,
714+
bucket_sum_y_sb,
715+
bucket_sum_z_sb,
716+
finalize_ub,
717+
]);
718+
}
719+
612720
// ----- Dispatch sequence -----
613721

614722
// 1. Init: ceil(total_buckets / 256) workgroups in x, 1 thread per bucket.
@@ -951,6 +1059,44 @@ export const smvp_batch_affine_gpu = async (
9511059
}
9521060
} // end if (!use_tree_reduce)
9531061

1062+
if (use_tree_reduce) {
1063+
// 3'. Merged finalize. Zero case4_count, dispatch the merged
1064+
// collect (writes cases 1-3 directly to bucket_x/y/z, compacts
1065+
// case 4 into case4_delta + case4_back_id), then run the per-subtask
1066+
// batch_inverse over the compacted case-4 slice, then dispatch the
1067+
// compacted apply (early-returns past case4_count[subtask_idx]).
1068+
commandEncoder.clearBuffer(case4_count_sb!, 0, num_subtasks * 4);
1069+
1070+
await execute_pipeline(
1071+
commandEncoder,
1072+
tree_finalize_pipe!.pipeline,
1073+
tree_finalize_bg!,
1074+
f_num_x_workgroups,
1075+
f_num_y_workgroups,
1076+
f_num_z_workgroups,
1077+
profiler?.stage('ba_tree_finalize'),
1078+
);
1079+
1080+
await execute_pipeline(
1081+
commandEncoder,
1082+
inverse_pipe.pipeline,
1083+
tree_finalize_inverse_bg!,
1084+
NUM_SUB_WGS_PER_SUBTASK,
1085+
1,
1086+
num_subtasks,
1087+
profiler?.stage('ba_tree_finalize_inverse'),
1088+
);
1089+
1090+
await execute_pipeline(
1091+
commandEncoder,
1092+
tree_finalize_apply_pipe!.pipeline,
1093+
tree_finalize_apply_bg!,
1094+
f_num_x_workgroups,
1095+
f_num_y_workgroups,
1096+
f_num_z_workgroups,
1097+
profiler?.stage('ba_tree_finalize_apply'),
1098+
);
1099+
} else {
9541100
// 3. Finalize — three single dispatches: collect → batch_inverse → apply.
9551101
//
9561102
// Pass A (collect): single dispatch over T·h threads. Each thread
@@ -993,4 +1139,5 @@ export const smvp_batch_affine_gpu = async (
9931139
f_num_z_workgroups,
9941140
profiler?.stage('ba_finalize_apply'),
9951141
);
1142+
}
9961143
};

barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import {
99
batch_affine_finalize as batch_affine_finalize_shader,
1010
batch_affine_finalize_apply as batch_affine_finalize_apply_shader,
1111
batch_affine_finalize_collect as batch_affine_finalize_collect_shader,
12+
batch_affine_tree_finalize as batch_affine_tree_finalize_shader,
13+
batch_affine_tree_finalize_apply as batch_affine_tree_finalize_apply_shader,
1214
batch_affine_init as batch_affine_init_shader,
1315
batch_affine_schedule as batch_affine_schedule_shader,
1416
batch_inverse as batch_inverse_shader,
@@ -961,6 +963,58 @@ export class ShaderManager {
961963
);
962964
}
963965

966+
/**
 * Renders the `batch_affine_tree_finalize` shader template into shader source text.
 *
 * The template is rendered with mustache using a view built from the two
 * arguments plus the field/limb constants stored on this ShaderManager
 * instance, and with the shared `structs`, `bigint_funcs`, Montgomery-product
 * and `field_funcs` sources supplied as partials.
 *
 * NOTE(review): per the call-site comments in batch_affine.ts this is the
 * merged tree-reduce finalize (cases 1-3 written directly, case-4 deltas
 * compacted) — confirm against the template itself.
 *
 * @param workgroup_size - Value substituted for `workgroup_size` in the template.
 * @param num_csr_cols - Value substituted for `num_columns`; half of it is
 *   substituted for `half_num_columns`.
 * @returns The rendered shader source string.
 */
public gen_batch_affine_tree_finalize_shader(workgroup_size: number, num_csr_cols: number): string {
967+
return mustache.render(
968+
batch_affine_tree_finalize_shader,
969+
{
970+
workgroup_size,
971+
num_columns: num_csr_cols,
972+
// Plain numeric division: yields a fractional value if num_csr_cols is odd.
// NOTE(review): presumably callers always pass an even column count — confirm.
half_num_columns: num_csr_cols / 2,
973+
// Remaining view entries are forwarded unchanged from this instance.
word_size: this.word_size,
974+
num_words: this.num_words,
975+
n0: this.n0,
976+
p_limbs: this.p_limbs,
977+
r_limbs: this.r_limbs,
978+
mask: this.mask,
979+
two_pow_word_size: this.two_pow_word_size,
980+
p_inv_mod_2w: this.p_inv_mod_2w,
981+
recompile: this.recompile,
982+
},
983+
// Partials made available to the template.
{
984+
structs,
985+
bigint_funcs,
986+
montgomery_product_funcs: this.mont_product_src,
987+
field_funcs,
988+
},
989+
);
990+
}
991+
992+
/**
 * Renders the `batch_affine_tree_finalize_apply` shader template into shader
 * source text.
 *
 * Uses the same view and partials as `gen_batch_affine_tree_finalize_shader`;
 * only the template differs.
 *
 * NOTE(review): per the call-site comments in batch_affine.ts this is the
 * compacted apply pass that follows the case-4 batch inverse — confirm
 * against the template itself.
 *
 * @param workgroup_size - Value substituted for `workgroup_size` in the template.
 * @param num_csr_cols - Value substituted for `num_columns`; half of it is
 *   substituted for `half_num_columns`.
 * @returns The rendered shader source string.
 */
public gen_batch_affine_tree_finalize_apply_shader(workgroup_size: number, num_csr_cols: number): string {
993+
return mustache.render(
994+
batch_affine_tree_finalize_apply_shader,
995+
{
996+
workgroup_size,
997+
num_columns: num_csr_cols,
998+
// Plain numeric division: yields a fractional value if num_csr_cols is odd.
// NOTE(review): presumably callers always pass an even column count — confirm.
half_num_columns: num_csr_cols / 2,
999+
// Remaining view entries are forwarded unchanged from this instance.
word_size: this.word_size,
1000+
num_words: this.num_words,
1001+
n0: this.n0,
1002+
p_limbs: this.p_limbs,
1003+
r_limbs: this.r_limbs,
1004+
mask: this.mask,
1005+
two_pow_word_size: this.two_pow_word_size,
1006+
p_inv_mod_2w: this.p_inv_mod_2w,
1007+
recompile: this.recompile,
1008+
},
1009+
// Partials made available to the template.
{
1010+
structs,
1011+
bigint_funcs,
1012+
montgomery_product_funcs: this.mont_product_src,
1013+
field_funcs,
1014+
},
1015+
);
1016+
}
1017+
9641018
public gen_batch_affine_apply_shader(workgroup_size: number): string {
9651019
return mustache.render(
9661020
batch_affine_apply_shader,

0 commit comments

Comments
 (0)