Skip to content

Commit f6e7438

Browse files
committed
Revert "perf(bb): iter 9 — MAX_SLICE_ENTRIES 4096->8192 + MAX_PAIRS 2048->4096"
This reverts commit b90edcf.
1 parent b90edcf commit f6e7438

4 files changed

Lines changed: 34 additions & 158 deletions

File tree

barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -731,8 +731,8 @@ export const smvp_batch_affine_gpu = async (
731731
// bucket_active stays as init wrote it. The existing finalize stage
732732
// below consumes the populated running_x/y + bucket_active.
733733
const TREE_TPB = 128;
734-
const TREE_MAX_SLICE_ENTRIES = 8192;
735-
const TREE_MAX_PAIRS = 4096;
734+
const TREE_MAX_SLICE_ENTRIES = 4096;
735+
const TREE_MAX_PAIRS = 2048;
736736
const TREE_MAX_LAYERS = 25;
737737
const TREE_PRELUDE_WG_SIZE = 64;
738738
const TREE_SCAN_WG_SIZE = 256;

barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts

Lines changed: 16 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -5779,26 +5779,17 @@ fn main(
57795779
if (chunk_hi < slice_hi) {
57805780
next_chunk_bucket = entry_bucket_id[chunk_hi];
57815781
}
5782-
// PER_THREAD_ENTRIES can exceed 32 (iter 9: 64), so the emit/pair
5783-
// flag bitmasks split into low (off ∈ [0, 32)) and high (off ∈ [32, 64))
5784-
// u32s. A single u32 would silently overflow on \`1u << off\` for off >= 32.
57855782
var local_emit: u32 = 0u;
57865783
var local_pair: u32 = 0u;
5787-
var local_emit_mask_lo: u32 = 0u;
5788-
var local_emit_mask_hi: u32 = 0u;
5789-
var local_pair_mask_lo: u32 = 0u;
5790-
var local_pair_mask_hi: u32 = 0u;
5784+
var local_emit_mask: u32 = 0u;
5785+
var local_pair_mask: u32 = 0u;
57915786
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
57925787
let e = chunk_lo + off;
57935788
if (e >= chunk_hi) { continue; }
57945789
let p = e - local_break_pos[off];
57955790
if ((p & 1u) != 0u) { continue; }
57965791
local_emit = local_emit + 1u;
5797-
if (off < 32u) {
5798-
local_emit_mask_lo = local_emit_mask_lo | (1u << off);
5799-
} else {
5800-
local_emit_mask_hi = local_emit_mask_hi | (1u << (off - 32u));
5801-
}
5792+
local_emit_mask = local_emit_mask | (1u << off);
58025793
var next_b: u32 = UNPAIRED_SENTINEL;
58035794
if (off + 1u < PER_THREAD_ENTRIES) {
58045795
if (e + 1u < chunk_hi) {
@@ -5811,11 +5802,7 @@ fn main(
58115802
}
58125803
if (next_b == local_buckets[off]) {
58135804
local_pair = local_pair + 1u;
5814-
if (off < 32u) {
5815-
local_pair_mask_lo = local_pair_mask_lo | (1u << off);
5816-
} else {
5817-
local_pair_mask_hi = local_pair_mask_hi | (1u << (off - 32u));
5818-
}
5805+
local_pair_mask = local_pair_mask | (1u << off);
58195806
}
58205807
}
58215808
@@ -5851,21 +5838,12 @@ fn main(
58515838
var raw_w: u32 = raw_base;
58525839
var pair_w: u32 = pair_base;
58535840
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
5854-
var emit_bit: u32;
5855-
var pair_bit: u32;
5856-
if (off < 32u) {
5857-
emit_bit = local_emit_mask_lo & (1u << off);
5858-
pair_bit = local_pair_mask_lo & (1u << off);
5859-
} else {
5860-
emit_bit = local_emit_mask_hi & (1u << (off - 32u));
5861-
pair_bit = local_pair_mask_hi & (1u << (off - 32u));
5862-
}
5863-
if (emit_bit == 0u) { continue; }
5841+
if ((local_emit_mask & (1u << off)) == 0u) { continue; }
58645842
let e = chunk_lo + off;
58655843
let raw = raw_w;
58665844
raw_w = raw_w + 1u;
58675845
meta_pool[pair_idx_a_base + raw] = e;
5868-
if (pair_bit != 0u) {
5846+
if ((local_pair_mask & (1u << off)) != 0u) {
58695847
meta_pool[pair_idx_b_base + raw] = e + 1u;
58705848
let pair_rank = pair_w;
58715849
pair_w = pair_w + 1u;
@@ -5881,19 +5859,10 @@ fn main(
58815859
var raw_r: u32 = raw_base;
58825860
var pair_r: u32 = pair_base;
58835861
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
5884-
var emit_bit: u32;
5885-
var pair_bit: u32;
5886-
if (off < 32u) {
5887-
emit_bit = local_emit_mask_lo & (1u << off);
5888-
pair_bit = local_pair_mask_lo & (1u << off);
5889-
} else {
5890-
emit_bit = local_emit_mask_hi & (1u << (off - 32u));
5891-
pair_bit = local_pair_mask_hi & (1u << (off - 32u));
5892-
}
5893-
if (emit_bit == 0u) { continue; }
5862+
if ((local_emit_mask & (1u << off)) == 0u) { continue; }
58945863
let raw = raw_r;
58955864
raw_r = raw_r + 1u;
5896-
if (pair_bit != 0u) {
5865+
if ((local_pair_mask & (1u << off)) != 0u) {
58975866
let pair_rank = pair_r;
58985867
pair_r = pair_r + 1u;
58995868
if (pair_rank == 0u) {
@@ -6030,26 +5999,17 @@ fn main(
60305999
if (chunk_hi < slice_hi) {
60316000
next_chunk_bucket = input_bucket_id[chunk_hi];
60326001
}
6033-
// PER_THREAD_ENTRIES can exceed 32 (iter 9: 64), so the emit/pair
6034-
// flag bitmasks split into low (off ∈ [0, 32)) and high (off ∈ [32, 64))
6035-
// u32s. A single u32 would silently overflow on \`1u << off\` for off >= 32.
60366002
var local_emit: u32 = 0u;
60376003
var local_pair: u32 = 0u;
6038-
var local_emit_mask_lo: u32 = 0u;
6039-
var local_emit_mask_hi: u32 = 0u;
6040-
var local_pair_mask_lo: u32 = 0u;
6041-
var local_pair_mask_hi: u32 = 0u;
6004+
var local_emit_mask: u32 = 0u;
6005+
var local_pair_mask: u32 = 0u;
60426006
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
60436007
let e = chunk_lo + off;
60446008
if (e >= chunk_hi) { continue; }
60456009
let p = e - local_break_pos[off];
60466010
if ((p & 1u) != 0u) { continue; }
60476011
local_emit = local_emit + 1u;
6048-
if (off < 32u) {
6049-
local_emit_mask_lo = local_emit_mask_lo | (1u << off);
6050-
} else {
6051-
local_emit_mask_hi = local_emit_mask_hi | (1u << (off - 32u));
6052-
}
6012+
local_emit_mask = local_emit_mask | (1u << off);
60536013
var next_b: u32 = UNPAIRED_SENTINEL;
60546014
if (off + 1u < PER_THREAD_ENTRIES) {
60556015
if (e + 1u < chunk_hi) {
@@ -6062,11 +6022,7 @@ fn main(
60626022
}
60636023
if (next_b == local_buckets[off]) {
60646024
local_pair = local_pair + 1u;
6065-
if (off < 32u) {
6066-
local_pair_mask_lo = local_pair_mask_lo | (1u << off);
6067-
} else {
6068-
local_pair_mask_hi = local_pair_mask_hi | (1u << (off - 32u));
6069-
}
6025+
local_pair_mask = local_pair_mask | (1u << off);
60706026
}
60716027
}
60726028
@@ -6102,21 +6058,12 @@ fn main(
61026058
var raw_w: u32 = raw_base;
61036059
var pair_w: u32 = pair_base;
61046060
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
6105-
var emit_bit: u32;
6106-
var pair_bit: u32;
6107-
if (off < 32u) {
6108-
emit_bit = local_emit_mask_lo & (1u << off);
6109-
pair_bit = local_pair_mask_lo & (1u << off);
6110-
} else {
6111-
emit_bit = local_emit_mask_hi & (1u << (off - 32u));
6112-
pair_bit = local_pair_mask_hi & (1u << (off - 32u));
6113-
}
6114-
if (emit_bit == 0u) { continue; }
6061+
if ((local_emit_mask & (1u << off)) == 0u) { continue; }
61156062
let e = chunk_lo + off;
61166063
let raw = raw_w;
61176064
raw_w = raw_w + 1u;
61186065
meta_pool[pair_idx_a_base + raw] = e;
6119-
if (pair_bit != 0u) {
6066+
if ((local_pair_mask & (1u << off)) != 0u) {
61206067
meta_pool[pair_idx_b_base + raw] = e + 1u;
61216068
let pair_rank = pair_w;
61226069
pair_w = pair_w + 1u;
@@ -6132,19 +6079,10 @@ fn main(
61326079
var raw_r: u32 = raw_base;
61336080
var pair_r: u32 = pair_base;
61346081
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
6135-
var emit_bit: u32;
6136-
var pair_bit: u32;
6137-
if (off < 32u) {
6138-
emit_bit = local_emit_mask_lo & (1u << off);
6139-
pair_bit = local_pair_mask_lo & (1u << off);
6140-
} else {
6141-
emit_bit = local_emit_mask_hi & (1u << (off - 32u));
6142-
pair_bit = local_pair_mask_hi & (1u << (off - 32u));
6143-
}
6144-
if (emit_bit == 0u) { continue; }
6082+
if ((local_emit_mask & (1u << off)) == 0u) { continue; }
61456083
let raw = raw_r;
61466084
raw_r = raw_r + 1u;
6147-
if (pair_bit != 0u) {
6085+
if ((local_pair_mask & (1u << off)) != 0u) {
61486086
let pair_rank = pair_r;
61496087
pair_r = pair_r + 1u;
61506088
if (pair_rank == 0u) {

barretenberg/ts/src/msm_webgpu/wgsl/cuzk/smvp_tree_meta_phase1.template.wgsl

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -145,26 +145,17 @@ fn main(
145145
if (chunk_hi < slice_hi) {
146146
next_chunk_bucket = entry_bucket_id[chunk_hi];
147147
}
148-
// PER_THREAD_ENTRIES can exceed 32 (iter 9: 64), so the emit/pair
149-
// flag bitmasks split into low (off ∈ [0, 32)) and high (off ∈ [32, 64))
150-
// u32s. A single u32 would silently overflow on `1u << off` for off >= 32.
151148
var local_emit: u32 = 0u;
152149
var local_pair: u32 = 0u;
153-
var local_emit_mask_lo: u32 = 0u;
154-
var local_emit_mask_hi: u32 = 0u;
155-
var local_pair_mask_lo: u32 = 0u;
156-
var local_pair_mask_hi: u32 = 0u;
150+
var local_emit_mask: u32 = 0u;
151+
var local_pair_mask: u32 = 0u;
157152
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
158153
let e = chunk_lo + off;
159154
if (e >= chunk_hi) { continue; }
160155
let p = e - local_break_pos[off];
161156
if ((p & 1u) != 0u) { continue; }
162157
local_emit = local_emit + 1u;
163-
if (off < 32u) {
164-
local_emit_mask_lo = local_emit_mask_lo | (1u << off);
165-
} else {
166-
local_emit_mask_hi = local_emit_mask_hi | (1u << (off - 32u));
167-
}
158+
local_emit_mask = local_emit_mask | (1u << off);
168159
var next_b: u32 = UNPAIRED_SENTINEL;
169160
if (off + 1u < PER_THREAD_ENTRIES) {
170161
if (e + 1u < chunk_hi) {
@@ -177,11 +168,7 @@ fn main(
177168
}
178169
if (next_b == local_buckets[off]) {
179170
local_pair = local_pair + 1u;
180-
if (off < 32u) {
181-
local_pair_mask_lo = local_pair_mask_lo | (1u << off);
182-
} else {
183-
local_pair_mask_hi = local_pair_mask_hi | (1u << (off - 32u));
184-
}
171+
local_pair_mask = local_pair_mask | (1u << off);
185172
}
186173
}
187174

@@ -217,21 +204,12 @@ fn main(
217204
var raw_w: u32 = raw_base;
218205
var pair_w: u32 = pair_base;
219206
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
220-
var emit_bit: u32;
221-
var pair_bit: u32;
222-
if (off < 32u) {
223-
emit_bit = local_emit_mask_lo & (1u << off);
224-
pair_bit = local_pair_mask_lo & (1u << off);
225-
} else {
226-
emit_bit = local_emit_mask_hi & (1u << (off - 32u));
227-
pair_bit = local_pair_mask_hi & (1u << (off - 32u));
228-
}
229-
if (emit_bit == 0u) { continue; }
207+
if ((local_emit_mask & (1u << off)) == 0u) { continue; }
230208
let e = chunk_lo + off;
231209
let raw = raw_w;
232210
raw_w = raw_w + 1u;
233211
meta_pool[pair_idx_a_base + raw] = e;
234-
if (pair_bit != 0u) {
212+
if ((local_pair_mask & (1u << off)) != 0u) {
235213
meta_pool[pair_idx_b_base + raw] = e + 1u;
236214
let pair_rank = pair_w;
237215
pair_w = pair_w + 1u;
@@ -247,19 +225,10 @@ fn main(
247225
var raw_r: u32 = raw_base;
248226
var pair_r: u32 = pair_base;
249227
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
250-
var emit_bit: u32;
251-
var pair_bit: u32;
252-
if (off < 32u) {
253-
emit_bit = local_emit_mask_lo & (1u << off);
254-
pair_bit = local_pair_mask_lo & (1u << off);
255-
} else {
256-
emit_bit = local_emit_mask_hi & (1u << (off - 32u));
257-
pair_bit = local_pair_mask_hi & (1u << (off - 32u));
258-
}
259-
if (emit_bit == 0u) { continue; }
228+
if ((local_emit_mask & (1u << off)) == 0u) { continue; }
260229
let raw = raw_r;
261230
raw_r = raw_r + 1u;
262-
if (pair_bit != 0u) {
231+
if ((local_pair_mask & (1u << off)) != 0u) {
263232
let pair_rank = pair_r;
264233
pair_r = pair_r + 1u;
265234
if (pair_rank == 0u) {

barretenberg/ts/src/msm_webgpu/wgsl/cuzk/smvp_tree_meta_phase2.template.wgsl

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -120,26 +120,17 @@ fn main(
120120
if (chunk_hi < slice_hi) {
121121
next_chunk_bucket = input_bucket_id[chunk_hi];
122122
}
123-
// PER_THREAD_ENTRIES can exceed 32 (iter 9: 64), so the emit/pair
124-
// flag bitmasks split into low (off ∈ [0, 32)) and high (off ∈ [32, 64))
125-
// u32s. A single u32 would silently overflow on `1u << off` for off >= 32.
126123
var local_emit: u32 = 0u;
127124
var local_pair: u32 = 0u;
128-
var local_emit_mask_lo: u32 = 0u;
129-
var local_emit_mask_hi: u32 = 0u;
130-
var local_pair_mask_lo: u32 = 0u;
131-
var local_pair_mask_hi: u32 = 0u;
125+
var local_emit_mask: u32 = 0u;
126+
var local_pair_mask: u32 = 0u;
132127
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
133128
let e = chunk_lo + off;
134129
if (e >= chunk_hi) { continue; }
135130
let p = e - local_break_pos[off];
136131
if ((p & 1u) != 0u) { continue; }
137132
local_emit = local_emit + 1u;
138-
if (off < 32u) {
139-
local_emit_mask_lo = local_emit_mask_lo | (1u << off);
140-
} else {
141-
local_emit_mask_hi = local_emit_mask_hi | (1u << (off - 32u));
142-
}
133+
local_emit_mask = local_emit_mask | (1u << off);
143134
var next_b: u32 = UNPAIRED_SENTINEL;
144135
if (off + 1u < PER_THREAD_ENTRIES) {
145136
if (e + 1u < chunk_hi) {
@@ -152,11 +143,7 @@ fn main(
152143
}
153144
if (next_b == local_buckets[off]) {
154145
local_pair = local_pair + 1u;
155-
if (off < 32u) {
156-
local_pair_mask_lo = local_pair_mask_lo | (1u << off);
157-
} else {
158-
local_pair_mask_hi = local_pair_mask_hi | (1u << (off - 32u));
159-
}
146+
local_pair_mask = local_pair_mask | (1u << off);
160147
}
161148
}
162149

@@ -192,21 +179,12 @@ fn main(
192179
var raw_w: u32 = raw_base;
193180
var pair_w: u32 = pair_base;
194181
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
195-
var emit_bit: u32;
196-
var pair_bit: u32;
197-
if (off < 32u) {
198-
emit_bit = local_emit_mask_lo & (1u << off);
199-
pair_bit = local_pair_mask_lo & (1u << off);
200-
} else {
201-
emit_bit = local_emit_mask_hi & (1u << (off - 32u));
202-
pair_bit = local_pair_mask_hi & (1u << (off - 32u));
203-
}
204-
if (emit_bit == 0u) { continue; }
182+
if ((local_emit_mask & (1u << off)) == 0u) { continue; }
205183
let e = chunk_lo + off;
206184
let raw = raw_w;
207185
raw_w = raw_w + 1u;
208186
meta_pool[pair_idx_a_base + raw] = e;
209-
if (pair_bit != 0u) {
187+
if ((local_pair_mask & (1u << off)) != 0u) {
210188
meta_pool[pair_idx_b_base + raw] = e + 1u;
211189
let pair_rank = pair_w;
212190
pair_w = pair_w + 1u;
@@ -222,19 +200,10 @@ fn main(
222200
var raw_r: u32 = raw_base;
223201
var pair_r: u32 = pair_base;
224202
for (var off: u32 = 0u; off < PER_THREAD_ENTRIES; off = off + 1u) {
225-
var emit_bit: u32;
226-
var pair_bit: u32;
227-
if (off < 32u) {
228-
emit_bit = local_emit_mask_lo & (1u << off);
229-
pair_bit = local_pair_mask_lo & (1u << off);
230-
} else {
231-
emit_bit = local_emit_mask_hi & (1u << (off - 32u));
232-
pair_bit = local_pair_mask_hi & (1u << (off - 32u));
233-
}
234-
if (emit_bit == 0u) { continue; }
203+
if ((local_emit_mask & (1u << off)) == 0u) { continue; }
235204
let raw = raw_r;
236205
raw_r = raw_r + 1u;
237-
if (pair_bit != 0u) {
206+
if ((local_pair_mask & (1u << off)) != 0u) {
238207
let pair_rank = pair_r;
239208
pair_r = pair_r + 1u;
240209
if (pair_rank == 0u) {

0 commit comments

Comments
 (0)