Skip to content

feat(bb/msm-webgpu): straus_msm port — all 9 phases (P1-P9) stacked#23475

Draft
AztecBot wants to merge 9 commits into
zw/msm-webgpu-experiments-v2from
cb/4914b38e87d4
Draft

feat(bb/msm-webgpu): straus_msm port — all 9 phases (P1-P9) stacked#23475
AztecBot wants to merge 9 commits into
zw/msm-webgpu-experiments-v2from
cb/4914b38e87d4

Conversation

@AztecBot
Copy link
Copy Markdown
Collaborator

@AztecBot AztecBot commented May 21, 2026

Summary

Full execution of the straus_msm WebGPU port plan — every phase committed as a separate, reviewable commit on this branch. Phases were intended to run sequentially across PRs but the operator directed me to bundle them into a single session, so they're stacked here. Review per-commit; squash on merge will collapse them.

Commit Phase What it lands
3f9ad9d9a2c P1 Host-side reference + GLV/Booth primitives
ce60b3947fa P2 Lookup-precompute WGSL kernel + dispatch unit test
0248458f946 P3 straus_main kernel — single-chunk correctness
aeb637155bb P4 Multi-thread dispatch test (T = ceil(n/k))
6d03dfb087a P5 GPU combine tree-fold + Jacobian→affine kernel
9241f3d7551 P6 TrivialMsm host driver (create/prepare/run/destroy)
11a4f7eb81d P7 bench-nt-sweep page + local exhaustive Playwright driver
43c7f0e0efa P8 M2 BrowserStack confirmation tooling
01ac7379d01 P9 Size dispatcher — TrivialMsm for small N

Per-phase notes

P1 — Host-side reference + GLV/Booth primitives

  • STRAUS_REFERENCE.md (committed copy of the P0 reference gist).
  • straus/glv.ts + tests — splitIntoEndomorphismScalars from field_declarations.hpp:501-530; λ materialised from bn254/fr.hpp Mont limbs.
  • straus/booth.ts + tests — BoothSliceParams, 32-row table, boothPackedDigit.
  • straus/reference.ts + tests — pure-TS straus_msm cross-checked against noble.G1.msm for n ∈ [1, 256], plus the three edge cases from element.test.cpp.

P2 — Lookup-precompute kernel

  • wgsl/cuzk/straus_lookup_precompute_bn254.template.wgsl: one thread per active point builds lut[i*8 + k] = (k+1) · base[i].
  • cuzk/straus_kernels.ts: StrausKernels.{render,compile}LookupPrecompute factory (the gpu helpers parameter keeps the module jest-loadable in node).
  • cuzk/shader_manager.ts: gen_straus_lookup_precompute_shader(n, workgroup_size).
  • dev/msm-webgpu/wgsl_unit_tests.ts: testStrausLookupPrecompute(n) wired into runAllWgslUnitTests for n ∈ {1, 8, 64, 256, 1024}.

P3 — straus_main kernel

  • wgsl/cuzk/straus_main_bn254.template.wgsl: compile-time NUM_THREAD_MULS, counted inner ii loop (NOT unrolled), 32 windows × 2 halves × ii. Reads k1_lims / k2_lims from storage each iter; β-Mont injected from straus_constants.ts.
  • Booth-digit extraction in WGSL uses 32-bit limbs directly — no u64 helpers needed.
  • cuzk/straus_constants.ts (new): fqCubeRootOfUnityMont(numWords, wordSize) re-bases β from bn254/fq.hpp's 2^256-Mont into this tree's R = 2^(num_words·word_size)-Mont form. Sanity test verifies β³ ≡ 1 mod q.
  • dev/msm-webgpu/wgsl_unit_tests.ts: testStrausChunk(k) for k ∈ {1, 2, 3, 4, 6, 8, 12, 16} (single-chunk dispatch, Jacobian readback compared to referenceStrausMsm).

P4 — Multi-thread dispatch test

  • testStrausMultiThread(n, k) dispatches T = ceil(n/k) threads, reads back T Jacobian partials, sums them via noble's projective add, asserts affine equality with referenceStrausMsm. Wired for (n, k) ∈ {16, 64, 256, 1024} × {1, 2, 4, 8}.

P5 — Combine + to-affine

  • wgsl/cuzk/straus_combine_fold_bn254.template.wgsl: in-pass add_points(in[2t], in[2t+1])out[t]. Ping-pong-buffered by the host between dispatches; in-place fold avoided because cross-workgroup read/write hazards would otherwise need a global barrier.
  • wgsl/cuzk/straus_to_affine_bn254.template.wgsl: single-thread Jacobian→affine via fr_inv_by_a (the same BY safegcd driver batch_inverse already uses).
  • testStrausEndToEnd(n, k) runs lookup + main + log2(T) fold passes + to-affine and compares affine (x, y) directly to referenceStrausMsm.

P6 — TrivialMsm host driver

  • dev/msm-webgpu/trivial_msm.ts: create(device, n, pointsBuf, ntm) / prepare(scalarsBuf) / run() / destroy() mirroring MsmV2. create compiles every per-(n, k) pipeline including all log2(T0) fold passes and runs lookup_precompute once; prepare does host-side GLV split + Booth-pack + scalar upload; run encodes + submits the full straus pipeline and reads back the 2 × BigInt result.

P7 — bench-nt-sweep

  • dev/msm-webgpu/bench-nt-sweep.{html,ts}: clone of bench-c-sweep parameterised on NUM_THREAD_MULS. Each cell records median + min TrivialMsm.run() ms; the same row also times MsmV2 so the table reports per-logN speedup and surfaces the crossover.
  • dev/msm-webgpu/scripts/bench-nt-sweep.mjs: headless Playwright driver, stall-detection at 180s, prints final pickNTM JSON + speedup table on stdout.
  • dev/msm-webgpu/scripts/run-browserstack.mjs: pageMap entry for bench-nt-sweep.

P8 — M2 BrowserStack confirmation

  • dev/msm-webgpu/scripts/narrow-from-local.mjs: reads local sweep JSON, emits the union of bestNtm ± 1 per logN (constrained to the default set) as the BS --ntmlist.
  • dev/msm-webgpu/results_format.ts + scripts/format-m2-report.mjs: convert the BS JSONL into the markdown gist body (pickNTM table + speedup + crossover).

P9 — Size dispatcher

  • src/msm_webgpu/cuzk/trivial_msm.ts: production copy of the dev driver (dev/ becomes a thin re-export to keep the two from diverging).
  • src/msm_webgpu/cuzk/size_dispatcher.ts: compute_bn254_msm_auto(device, n, points, scalars, fallback) routes n ≤ PICK_NTM_CROSSOVER_N to TrivialMsm with k = pickNtm(n); otherwise falls through to the supplied MsmV2-shaped fallback. The pickNTM table and crossover constant are placeholders that need P8's actual M2 measurements.
  • src/msm_webgpu/index.ts: re-exports compute_bn254_msm_auto, pickNtm, PICK_NTM_CROSSOVER_N, TrivialMsm.

Gate command output (host-side)

$ cd barretenberg/ts && yarn test src/msm_webgpu/cuzk/straus_kernels src/msm_webgpu/straus

PASS src/msm_webgpu/straus/glv.test.ts
PASS src/msm_webgpu/cuzk/straus_kernels.test.ts
PASS src/msm_webgpu/straus/booth.test.ts
PASS src/msm_webgpu/straus/reference.test.ts

Test Suites: 4 passed, 4 total
Tests:       64 passed, 64 total

Tests cover: P1 GLV identity over 200 random scalars + negative-k2 sweep + 11 explicit edge cases; P1 Booth round-trip over 100 GLV outputs + 256 random halves + tiny boundaries; P1 referenceStrausMsm vs noble for n ∈ [1, 256] × 5 seeded inputs + the three element.test.cpp edge cases; P2 lookup-precompute renderer (bindings, partials, mustache wiring); P3 straus_main renderer (NUM_THREAD_MULS interpolation, no inner-loop unroll, bindings, 20-limb β-Mont initializer, β³ ≡ 1 mod q); P5 combine-fold renderer (T_IN interpolation, bindings) and to-affine renderer (single-thread, fr_inv_by_a).

Caveats / follow-ups

  • No GPU validation in this session. The session container has no Chrome and no GPU, so none of the kernel-dispatch tests, the bench-nt-sweep page, the Playwright drivers, or the BrowserStack run were exercised locally. The host-side jest covers algorithm correctness (P1) and the renderer wiring (P2-P5). Static review of WGSL kernels, bind-group layouts, buffer sizing, and dispatch counts checks out, but a real WebGPU device is required to confirm:
    • The lookup-precompute kernel produces correct (k+1) · base[i] entries.
    • The straus_main Booth digit extraction in WGSL matches the host port (the WGSL form uses 32-bit limbs directly; the host uses bigint operations — both should produce the same digits).
    • The combine-fold ping-pong is correctly wired in TrivialMsm and the multi-thread test.
    • The fr_inv_by_a in to-affine produces a correct Mont-form inverse on the read-back Z.
    • TrivialMsm.run() is idempotent across repeats (the warm-path stability claim from the plan).
  • P8 hasn't been run. The pickNTM table and PICK_NTM_CROSSOVER_N in size_dispatcher.ts are static-reasoning placeholders. Run the local sweep then BrowserStack narrowed sweep to populate them with real M2 numbers.
  • TrivialMsm MsmV2 sanity-cross-check skipped. P6 in the plan said "add a TrivialMsm row to the existing sweep UI" in main.ts. Touching main.ts (1500+ lines, deep integration with MsmV2 lifecycle) without GPU validation felt too risky; left for a follow-up. The bench-nt-sweep page already covers TrivialMsm vs MsmV2 timing comparison.
  • drive-unit-tests.mjs not added. P2 asked for a Playwright driver that invokes one unit test by name. The new testStrausLookupPrecompute / testStrausChunk / testStrausMultiThread / testStrausEndToEnd functions are wired into runAllWgslUnitTests() reachable from main.ts's "Run Unit Tests" button. Standalone driver + ?autorun=… URL flag can be added in a follow-up.
  • The plan's yarn build:esm gate fails on pre-existing branch errors unrelated to any phase here: missing src/cbind/generated/* (requires yarn generate + native C++ build) plus pre-existing TS errors in cuzk/batch_affine.ts, cuzk/smvp_tree.ts, barretenberg/poseidon.{test,bench.test}.ts. Confirmed by running yarn build:esm on the unmodified zw/msm-webgpu-experiments-v2 base. None of the failing files are touched in this PR.
  • --testPathPattern='src/msm_webgpu/straus' from the plan's gate does not filter under this repo's jest config; use yarn test src/msm_webgpu/straus (positional) instead.

@AztecBot AztecBot added the claudebox Owned by claudebox. it can push to this PR. label May 21, 2026
@AztecBot AztecBot changed the title feat(bb/msm-webgpu): host-side straus_msm reference + GLV/Booth primitives feat(bb/msm-webgpu): straus_msm port — P1 host-side primitives + P2 lookup-precompute kernel May 21, 2026
@AztecBot AztecBot changed the title feat(bb/msm-webgpu): straus_msm port — P1 host-side primitives + P2 lookup-precompute kernel feat(bb/msm-webgpu): straus_msm port — all 9 phases (P1-P9) stacked May 21, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

claudebox Owned by claudebox. it can push to this PR.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant