Skip to content

Commit 400d0c1

Browse files
feat(bb): use Karatsuba+Yuval as the default Montgomery multiplier in WebGPU MSM
Routes the `montgomery_product_funcs` mustache partial through a pre-rendered Karatsuba+Yuval body in every MSM shader that performs a base-field multiply (15 callsites: convert_points, smvp, horner, batch_affine_{apply,schedule,finalize_*,init,apply_scatter}, batch_inverse{,_parallel}, bpr, decompress_g1, montgomery_parity). The Karatsuba+Yuval body benchmarks ~27% faster than the mitschabaude runtime-loop CIOS at n=2^20, k=100 on Apple GPU (80 ms vs 109 ms). It exposes the same `fn montgomery_product(x, y) -> BigInt` symbol and the same `get_p` / `conditional_reduce` helpers, and it uses the same 20×13-bit limb layout, so the swap is a drop-in change with no callsite churn. The field-mul bench keeps both options (`?variant=cios` renders the original template inline; `?variant=karat` reuses the class-level default) so the two bodies can still be compared side by side.
1 parent 8f414ef commit 400d0c1

1 file changed

Lines changed: 32 additions & 15 deletions

File tree

barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ export class ShaderManager {
9393
public num_limbs_f32_22: number;
9494
public n0_f32_22: bigint;
9595
public p_limbs_f32_22_str: string;
96+
// Pre-rendered u32 Montgomery product source used as the
97+
// `montgomery_product_funcs` mustache partial by every MSM shader that
98+
// needs a base-field multiply. Defaults to the Karatsuba + Yuval body
99+
// (see `renderKaratYuvalMont`), which benches ~27% faster than the
100+
// runtime-loop CIOS at n=2^20, k=100 on Apple GPU. Both bodies expose
101+
// the same `fn montgomery_product(x, y) -> BigInt` symbol and the same
102+
// `get_p` / `conditional_reduce` helpers, so swapping the partial is
103+
// a drop-in change at every callsite.
104+
public mont_product_src: string;
96105
public curveConfig: CurveConfig;
97106
public recompile = '';
98107

@@ -139,6 +148,11 @@ export class ShaderManager {
139148
this.n0_f32_22 = params_f32_22.n0;
140149
this.p_limbs_f32_22_str = gen_p_limbs_f32(this.p, this.num_limbs_f32_22, 22);
141150

151+
// Render the Karatsuba+Yuval Mont body once. This is the default
152+
// u32 multiplier used by every MSM shader that includes the
153+
// `montgomery_product_funcs` mustache partial.
154+
this.mont_product_src = this.renderKaratYuvalMont();
155+
142156
if (force_recompile) {
143157
const rand = Math.round(Math.random() * 100000000000000000) % 2 ** 32;
144158
this.recompile = `
@@ -199,7 +213,7 @@ export class ShaderManager {
199213
bigint_funcs,
200214
field_funcs,
201215
barrett_funcs,
202-
montgomery_product_funcs,
216+
montgomery_product_funcs: this.mont_product_src,
203217
extract_word_from_bytes_le_funcs,
204218
},
205219
);
@@ -236,7 +250,7 @@ export class ShaderManager {
236250
bigint_funcs,
237251
field_funcs,
238252
barrett_funcs,
239-
montgomery_product_funcs,
253+
montgomery_product_funcs: this.mont_product_src,
240254
extract_word_from_bytes_le_funcs,
241255
},
242256
);
@@ -268,7 +282,7 @@ export class ShaderManager {
268282
bigint_funcs,
269283
field_funcs,
270284
barrett_funcs,
271-
montgomery_product_funcs,
285+
montgomery_product_funcs: this.mont_product_src,
272286
fr_pow_funcs,
273287
},
274288
);
@@ -354,7 +368,7 @@ export class ShaderManager {
354368
{
355369
structs,
356370
bigint_funcs,
357-
montgomery_product_funcs,
371+
montgomery_product_funcs: this.mont_product_src,
358372
field_funcs,
359373
ec_funcs: ec_bn254_funcs,
360374
},
@@ -379,7 +393,7 @@ export class ShaderManager {
379393
{
380394
structs,
381395
bigint_funcs,
382-
montgomery_product_funcs,
396+
montgomery_product_funcs: this.mont_product_src,
383397
field_funcs,
384398
fr_pow_funcs,
385399
},
@@ -405,7 +419,7 @@ export class ShaderManager {
405419
{
406420
structs,
407421
bigint_funcs,
408-
montgomery_product_funcs,
422+
montgomery_product_funcs: this.mont_product_src,
409423
field_funcs,
410424
fr_pow_funcs,
411425
},
@@ -446,7 +460,7 @@ export class ShaderManager {
446460
{
447461
structs,
448462
bigint_funcs,
449-
montgomery_product_funcs,
463+
montgomery_product_funcs: this.mont_product_src,
450464
field_funcs,
451465
},
452466
);
@@ -470,7 +484,7 @@ export class ShaderManager {
470484
{
471485
structs,
472486
bigint_funcs,
473-
montgomery_product_funcs,
487+
montgomery_product_funcs: this.mont_product_src,
474488
field_funcs,
475489
},
476490
);
@@ -497,7 +511,7 @@ export class ShaderManager {
497511
{
498512
structs,
499513
bigint_funcs,
500-
montgomery_product_funcs,
514+
montgomery_product_funcs: this.mont_product_src,
501515
field_funcs,
502516
fr_pow_funcs,
503517
ec_funcs: ec_bn254_funcs,
@@ -525,7 +539,7 @@ export class ShaderManager {
525539
{
526540
structs,
527541
bigint_funcs,
528-
montgomery_product_funcs,
542+
montgomery_product_funcs: this.mont_product_src,
529543
field_funcs,
530544
},
531545
);
@@ -551,7 +565,7 @@ export class ShaderManager {
551565
{
552566
structs,
553567
bigint_funcs,
554-
montgomery_product_funcs,
568+
montgomery_product_funcs: this.mont_product_src,
555569
field_funcs,
556570
},
557571
);
@@ -575,7 +589,7 @@ export class ShaderManager {
575589
{
576590
structs,
577591
bigint_funcs,
578-
montgomery_product_funcs,
592+
montgomery_product_funcs: this.mont_product_src,
579593
field_funcs,
580594
},
581595
);
@@ -630,7 +644,7 @@ export class ShaderManager {
630644
{
631645
structs,
632646
bigint_funcs,
633-
montgomery_product_funcs,
647+
montgomery_product_funcs: this.mont_product_src,
634648
field_funcs,
635649
ec_funcs: ec_bn254_funcs,
636650
},
@@ -657,7 +671,7 @@ export class ShaderManager {
657671
{
658672
structs,
659673
bigint_funcs,
660-
montgomery_product_funcs,
674+
montgomery_product_funcs: this.mont_product_src,
661675
field_funcs,
662676
ec_funcs: ec_bn254_funcs,
663677
},
@@ -674,9 +688,12 @@ export class ShaderManager {
674688
): string {
675689
const structs_src = mustache.render(structs, { num_words: this.num_words });
676690
const bigint_src = mustache.render(bigint_funcs, {});
691+
// 'karat' reuses the pre-rendered class-level default; 'cios' renders
692+
// the original mitschabaude template inline so the bench can compare
693+
// both bodies even though karat is the production default.
677694
const mont_src =
678695
variant === 'karat'
679-
? this.renderKaratYuvalMont()
696+
? this.mont_product_src
680697
: mustache.render(montgomery_product_funcs, {
681698
num_words: this.num_words,
682699
word_size: this.word_size,

0 commit comments

Comments
 (0)