Skip to content

Commit b762c15

Browse files
Blake3 refactor (#420)
* Inline `u32_add` and `u32_xor` in `blake3_g_function`
* Refactor `blake3_compress_chunks`
* Inline `blake3_compress_inner_perm`
1 parent 51f270b commit b762c15

1 file changed

Lines changed: 210 additions & 60 deletions

File tree

Ix/IxVM/Blake3.lean

Lines changed: 210 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def blake3 := ⟦
4242

4343
fn blake3(input: ByteStream) -> [[G; 4]; 8] {
4444
let IV = [[103, 230, 9, 106], [133, 174, 103, 187], [114, 243, 110, 60], [58, 245, 79, 165], [127, 82, 14, 81], [140, 104, 5, 155], [171, 217, 131, 31], [25, 205, 224, 91]];
45-
blake3_compress_layer(load(blake3_compress_chunks(input, store(ListNode.Nil), 0, 0, [0; 8], IV, store(Layer.Nil))))
45+
blake3_compress_layer(load(blake3_compress_chunks(input, store(ListNode.Nil), 0, 0, store([0; 8]), store(IV), store(Layer.Nil))))
4646
}
4747

4848
fn blake3_next_layer(layer: Layer, digest: [[G; 4]; 8], root: G) -> (MaybeDigest, Layer) {
@@ -95,8 +95,8 @@ def blake3 := ⟦
9595
byte_acc: ByteStream,
9696
block_index: G,
9797
chunk_index: G,
98-
chunk_count: U64,
99-
block_digest: [[G; 4]; 8],
98+
chunk_count: &U64,
99+
block_digest: &[[G; 4]; 8],
100100
layer: &Layer
101101
) -> &Layer {
102102
match load(input) {
@@ -205,26 +205,26 @@ def blake3 := ⟦
205205
byte_acc: ByteStream,
206206
block_index: G,
207207
chunk_index: G,
208-
chunk_count: U64,
209-
block_digest: [[G; 4]; 8],
208+
chunk_count: &U64,
209+
block_digest: &[[G; 4]; 8],
210210
layer: &Layer
211211
) -> &Layer {
212212
let CHUNK_START = 1;
213213
let CHUNK_END = 2;
214214
let ROOT = 8;
215215
match (block_index, chunk_index) {
216216
(0, 0) =>
217-
match chunk_count {
217+
match load(chunk_count) {
218218
[0, 0, 0, 0, 0, 0, 0, 0] =>
219219
let flags = ROOT + CHUNK_START + CHUNK_END;
220-
store(Layer.Push(layer, blake3_compress(block_digest, [[0; 4]; 16], chunk_count, 0, flags))),
220+
store(Layer.Push(layer, blake3_compress(load(block_digest), [[0; 4]; 16], load(chunk_count), 0, flags))),
221221
_ => layer,
222222
},
223-
(0, _) => store(Layer.Push(layer, block_digest)),
223+
(0, _) => store(Layer.Push(layer, load(block_digest))),
224224
(_, _) =>
225-
let flags = CHUNK_END + u64_is_zero(chunk_count) * ROOT + eq_zero(chunk_index - block_index) * CHUNK_START;
225+
let flags = CHUNK_END + u64_is_zero(load(chunk_count)) * ROOT + eq_zero(chunk_index - block_index) * CHUNK_START;
226226
let block = bytes_to_block(pad_block(byte_acc, 64 - block_index));
227-
store(Layer.Push(layer, blake3_compress(block_digest, block, chunk_count, block_index, flags))),
227+
store(Layer.Push(layer, blake3_compress(load(block_digest), block, load(chunk_count), block_index, flags))),
228228
}
229229
}
230230

@@ -235,8 +235,8 @@ def blake3 := ⟦
235235
input: ByteStream,
236236
byte_acc: ByteStream,
237237
chunk_index: G,
238-
chunk_count: U64,
239-
block_digest: [[G; 4]; 8],
238+
chunk_count: &U64,
239+
block_digest: &[[G; 4]; 8],
240240
layer: &Layer
241241
) -> &Layer {
242242
let CHUNK_START = 1;
@@ -245,17 +245,17 @@ def blake3 := ⟦
245245
let block = bytes_to_block(byte_acc);
246246
match chunk_index {
247247
1023 =>
248-
let flags = ROOT * list_is_empty(input) * u64_is_zero(chunk_count) + CHUNK_END;
248+
let flags = ROOT * list_is_empty(input) * u64_is_zero(load(chunk_count)) + CHUNK_END;
249249
let IV = [[103, 230, 9, 106], [133, 174, 103, 187], [114, 243, 110, 60], [58, 245, 79, 165], [127, 82, 14, 81], [140, 104, 5, 155], [171, 217, 131, 31], [25, 205, 224, 91]];
250-
let layer = store(Layer.Push(layer, blake3_compress(block_digest, block, chunk_count, 64, flags)));
251-
blake3_compress_chunks(input, store(ListNode.Nil), 0, 0, relaxed_u64_succ(chunk_count), IV, layer),
250+
let layer = store(Layer.Push(layer, blake3_compress(load(block_digest), block, load(chunk_count), 64, flags)));
251+
blake3_compress_chunks(input, store(ListNode.Nil), 0, 0, store(relaxed_u64_succ(load(chunk_count))), store(IV), layer),
252252
_ =>
253253
let chunk_end_flag = list_is_empty(input) * CHUNK_END;
254-
let root_flag = list_is_empty(input) * u64_is_zero(chunk_count) * ROOT;
254+
let root_flag = list_is_empty(input) * u64_is_zero(load(chunk_count)) * ROOT;
255255
let chunk_start_flag = eq_zero(chunk_index - 63) * CHUNK_START;
256256
let flags = chunk_end_flag + root_flag + chunk_start_flag;
257-
let block_digest = blake3_compress(block_digest, block, chunk_count, 64, flags);
258-
blake3_compress_chunks(input, store(ListNode.Nil), 0, chunk_index + 1, chunk_count, block_digest, layer),
257+
let block_digest = blake3_compress(load(block_digest), block, load(chunk_count), 64, flags);
258+
blake3_compress_chunks(input, store(ListNode.Nil), 0, chunk_index + 1, chunk_count, store(block_digest), layer),
259259
}
260260
}
261261

@@ -267,12 +267,49 @@ def blake3 := ⟦
267267
x: [G; 4],
268268
y: [G; 4]
269269
) -> [[G; 4]; 4] {
270-
let a = u32_add(u32_add(a, b), x);
271-
let [d0, d1, d2, d3] = u32_xor(d, a);
270+
-- a = (a + b) + x
271+
let (r1_0, r1_c1) = u8_add(a[0], b[0]);
272+
let (r1_s1, r1_o1) = u8_add(a[1], b[1]);
273+
let (r1_1, r1_c1a) = u8_add(r1_s1, r1_c1);
274+
let r1_c2 = r1_o1 + r1_c1a;
275+
let (r1_s2, r1_o2) = u8_add(a[2], b[2]);
276+
let (r1_2, r1_c2a) = u8_add(r1_s2, r1_c2);
277+
let r1_c3 = r1_o2 + r1_c2a;
278+
let (r1_s3, _z) = u8_add(a[3], b[3]);
279+
let (r1_3, _z) = u8_add(r1_s3, r1_c3);
280+
let (a0, r2_c1) = u8_add(r1_0, x[0]);
281+
let (r2_s1, r2_o1) = u8_add(r1_1, x[1]);
282+
let (a1, r2_c1a) = u8_add(r2_s1, r2_c1);
283+
let r2_c2 = r2_o1 + r2_c1a;
284+
let (r2_s2, r2_o2) = u8_add(r1_2, x[2]);
285+
let (a2, r2_c2a) = u8_add(r2_s2, r2_c2);
286+
let r2_c3 = r2_o2 + r2_c2a;
287+
let (r2_s3, _z) = u8_add(r1_3, x[3]);
288+
let (a3, _z) = u8_add(r2_s3, r2_c3);
289+
let a = [a0, a1, a2, a3];
290+
291+
let d0 = u8_xor(d[0], a[0]);
292+
let d1 = u8_xor(d[1], a[1]);
293+
let d2 = u8_xor(d[2], a[2]);
294+
let d3 = u8_xor(d[3], a[3]);
272295
let d = [d2, d3, d0, d1]; -- Right-rotated 16
273296

274-
let c = u32_add(c, d);
275-
let [b0, b1, b2, b3] = u32_xor(b, c);
297+
-- c = c + d
298+
let (nc0, r3_c1) = u8_add(c[0], d[0]);
299+
let (r3_s1, r3_o1) = u8_add(c[1], d[1]);
300+
let (nc1, r3_c1a) = u8_add(r3_s1, r3_c1);
301+
let r3_c2 = r3_o1 + r3_c1a;
302+
let (r3_s2, r3_o2) = u8_add(c[2], d[2]);
303+
let (nc2, r3_c2a) = u8_add(r3_s2, r3_c2);
304+
let r3_c3 = r3_o2 + r3_c2a;
305+
let (r3_s3, _z) = u8_add(c[3], d[3]);
306+
let (nc3, _z) = u8_add(r3_s3, r3_c3);
307+
let c = [nc0, nc1, nc2, nc3];
308+
309+
let b0 = u8_xor(b[0], c[0]);
310+
let b1 = u8_xor(b[1], c[1]);
311+
let b2 = u8_xor(b[2], c[2]);
312+
let b3 = u8_xor(b[3], c[3]);
276313
let [b00, b01, b02, b03, b04, b05, b06, b07] = u8_bit_decomposition(b0);
277314
let [b10, b11, b12, b13, b14, b15, b16, b17] = u8_bit_decomposition(b1);
278315
let [b20, b21, b22, b23, b24, b25, b26, b27] = u8_bit_decomposition(b2);
@@ -285,12 +322,49 @@ def blake3 := ⟦
285322
let b3 = b04 + 2 * b05 + 4 * b06 + 8 * b07 + 16 * b10 + 32 * b11 + 64 * b12 + 128 * b13;
286323
let b = [b0, b1, b2, b3]; -- Right-rotated 12
287324

288-
let a = u32_add(u32_add(a, b), y);
289-
let [d0, d1, d2, d3] = u32_xor(d, a);
325+
-- a = (a + b) + y
326+
let (r4_0, r4_c1) = u8_add(a[0], b[0]);
327+
let (r4_s1, r4_o1) = u8_add(a[1], b[1]);
328+
let (r4_1, r4_c1a) = u8_add(r4_s1, r4_c1);
329+
let r4_c2 = r4_o1 + r4_c1a;
330+
let (r4_s2, r4_o2) = u8_add(a[2], b[2]);
331+
let (r4_2, r4_c2a) = u8_add(r4_s2, r4_c2);
332+
let r4_c3 = r4_o2 + r4_c2a;
333+
let (r4_s3, _z) = u8_add(a[3], b[3]);
334+
let (r4_3, _z) = u8_add(r4_s3, r4_c3);
335+
let (a0, r5_c1) = u8_add(r4_0, y[0]);
336+
let (r5_s1, r5_o1) = u8_add(r4_1, y[1]);
337+
let (a1, r5_c1a) = u8_add(r5_s1, r5_c1);
338+
let r5_c2 = r5_o1 + r5_c1a;
339+
let (r5_s2, r5_o2) = u8_add(r4_2, y[2]);
340+
let (a2, r5_c2a) = u8_add(r5_s2, r5_c2);
341+
let r5_c3 = r5_o2 + r5_c2a;
342+
let (r5_s3, _z) = u8_add(r4_3, y[3]);
343+
let (a3, _z) = u8_add(r5_s3, r5_c3);
344+
let a = [a0, a1, a2, a3];
345+
346+
let d0 = u8_xor(d[0], a[0]);
347+
let d1 = u8_xor(d[1], a[1]);
348+
let d2 = u8_xor(d[2], a[2]);
349+
let d3 = u8_xor(d[3], a[3]);
290350
let d = [d1, d2, d3, d0]; -- Right-rotated 8
291351

292-
let c = u32_add(c, d);
293-
let [b0, b1, b2, b3] = u32_xor(b, c);
352+
-- c = c + d
353+
let (nc0, r6_c1) = u8_add(c[0], d[0]);
354+
let (r6_s1, r6_o1) = u8_add(c[1], d[1]);
355+
let (nc1, r6_c1a) = u8_add(r6_s1, r6_c1);
356+
let r6_c2 = r6_o1 + r6_c1a;
357+
let (r6_s2, r6_o2) = u8_add(c[2], d[2]);
358+
let (nc2, r6_c2a) = u8_add(r6_s2, r6_c2);
359+
let r6_c3 = r6_o2 + r6_c2a;
360+
let (r6_s3, _z) = u8_add(c[3], d[3]);
361+
let (nc3, _z) = u8_add(r6_s3, r6_c3);
362+
let c = [nc0, nc1, nc2, nc3];
363+
364+
let b0 = u8_xor(b[0], c[0]);
365+
let b1 = u8_xor(b[1], c[1]);
366+
let b2 = u8_xor(b[2], c[2]);
367+
let b3 = u8_xor(b[3], c[3]);
294368
let [b00, b01, b02, b03, b04, b05, b06, b07] = u8_bit_decomposition(b0);
295369
let [b10, b11, b12, b13, b14, b15, b16, b17] = u8_bit_decomposition(b1);
296370
let [b20, b21, b22, b23, b24, b25, b26, b27] = u8_bit_decomposition(b2);
@@ -380,26 +454,6 @@ def blake3 := ⟦
380454
state
381455
}
382456

383-
fn blake3_compress_inner_perm(state: [[G; 4]; 32]) -> [[G; 4]; 32] {
384-
let new_state = set(state, 16, state[18]);
385-
let new_state = set(new_state, 17, state[22]);
386-
let new_state = set(new_state, 18, state[19]);
387-
let new_state = set(new_state, 19, state[26]);
388-
let new_state = set(new_state, 20, state[23]);
389-
let new_state = set(new_state, 21, state[16]);
390-
let new_state = set(new_state, 22, state[20]);
391-
let new_state = set(new_state, 23, state[29]);
392-
let new_state = set(new_state, 24, state[17]);
393-
let new_state = set(new_state, 25, state[27]);
394-
let new_state = set(new_state, 26, state[28]);
395-
let new_state = set(new_state, 27, state[21]);
396-
let new_state = set(new_state, 28, state[25]);
397-
let new_state = set(new_state, 29, state[30]);
398-
let new_state = set(new_state, 30, state[31]);
399-
let new_state = set(new_state, 31, state[24]);
400-
new_state
401-
}
402-
403457
-- TODO:
404458
-- `block_words` could be two arguments of type [[G; 4]; 8]
405459
fn blake3_compress(
@@ -432,40 +486,136 @@ def blake3 := ⟦
432486

433487
-- Round 0
434488
let state = blake3_compress_inner_j(state);
435-
let state = blake3_compress_inner_perm(state);
489+
let new_state = set(state, 16, state[18]);
490+
let new_state = set(new_state, 17, state[22]);
491+
let new_state = set(new_state, 18, state[19]);
492+
let new_state = set(new_state, 19, state[26]);
493+
let new_state = set(new_state, 20, state[23]);
494+
let new_state = set(new_state, 21, state[16]);
495+
let new_state = set(new_state, 22, state[20]);
496+
let new_state = set(new_state, 23, state[29]);
497+
let new_state = set(new_state, 24, state[17]);
498+
let new_state = set(new_state, 25, state[27]);
499+
let new_state = set(new_state, 26, state[28]);
500+
let new_state = set(new_state, 27, state[21]);
501+
let new_state = set(new_state, 28, state[25]);
502+
let new_state = set(new_state, 29, state[30]);
503+
let new_state = set(new_state, 30, state[31]);
504+
let new_state = set(new_state, 31, state[24]);
505+
let state = new_state;
436506

437507
-- Round 1
438508
let state = blake3_compress_inner_j(state);
439-
let state = blake3_compress_inner_perm(state);
509+
let new_state = set(state, 16, state[18]);
510+
let new_state = set(new_state, 17, state[22]);
511+
let new_state = set(new_state, 18, state[19]);
512+
let new_state = set(new_state, 19, state[26]);
513+
let new_state = set(new_state, 20, state[23]);
514+
let new_state = set(new_state, 21, state[16]);
515+
let new_state = set(new_state, 22, state[20]);
516+
let new_state = set(new_state, 23, state[29]);
517+
let new_state = set(new_state, 24, state[17]);
518+
let new_state = set(new_state, 25, state[27]);
519+
let new_state = set(new_state, 26, state[28]);
520+
let new_state = set(new_state, 27, state[21]);
521+
let new_state = set(new_state, 28, state[25]);
522+
let new_state = set(new_state, 29, state[30]);
523+
let new_state = set(new_state, 30, state[31]);
524+
let new_state = set(new_state, 31, state[24]);
525+
let state = new_state;
440526

441527
-- Round 2
442528
let state = blake3_compress_inner_j(state);
443-
let state = blake3_compress_inner_perm(state);
529+
let new_state = set(state, 16, state[18]);
530+
let new_state = set(new_state, 17, state[22]);
531+
let new_state = set(new_state, 18, state[19]);
532+
let new_state = set(new_state, 19, state[26]);
533+
let new_state = set(new_state, 20, state[23]);
534+
let new_state = set(new_state, 21, state[16]);
535+
let new_state = set(new_state, 22, state[20]);
536+
let new_state = set(new_state, 23, state[29]);
537+
let new_state = set(new_state, 24, state[17]);
538+
let new_state = set(new_state, 25, state[27]);
539+
let new_state = set(new_state, 26, state[28]);
540+
let new_state = set(new_state, 27, state[21]);
541+
let new_state = set(new_state, 28, state[25]);
542+
let new_state = set(new_state, 29, state[30]);
543+
let new_state = set(new_state, 30, state[31]);
544+
let new_state = set(new_state, 31, state[24]);
545+
let state = new_state;
444546

445547
-- Round 3
446548
let state = blake3_compress_inner_j(state);
447-
let state = blake3_compress_inner_perm(state);
549+
let new_state = set(state, 16, state[18]);
550+
let new_state = set(new_state, 17, state[22]);
551+
let new_state = set(new_state, 18, state[19]);
552+
let new_state = set(new_state, 19, state[26]);
553+
let new_state = set(new_state, 20, state[23]);
554+
let new_state = set(new_state, 21, state[16]);
555+
let new_state = set(new_state, 22, state[20]);
556+
let new_state = set(new_state, 23, state[29]);
557+
let new_state = set(new_state, 24, state[17]);
558+
let new_state = set(new_state, 25, state[27]);
559+
let new_state = set(new_state, 26, state[28]);
560+
let new_state = set(new_state, 27, state[21]);
561+
let new_state = set(new_state, 28, state[25]);
562+
let new_state = set(new_state, 29, state[30]);
563+
let new_state = set(new_state, 30, state[31]);
564+
let new_state = set(new_state, 31, state[24]);
565+
let state = new_state;
448566

449567
-- Round 4
450568
let state = blake3_compress_inner_j(state);
451-
let state = blake3_compress_inner_perm(state);
569+
let new_state = set(state, 16, state[18]);
570+
let new_state = set(new_state, 17, state[22]);
571+
let new_state = set(new_state, 18, state[19]);
572+
let new_state = set(new_state, 19, state[26]);
573+
let new_state = set(new_state, 20, state[23]);
574+
let new_state = set(new_state, 21, state[16]);
575+
let new_state = set(new_state, 22, state[20]);
576+
let new_state = set(new_state, 23, state[29]);
577+
let new_state = set(new_state, 24, state[17]);
578+
let new_state = set(new_state, 25, state[27]);
579+
let new_state = set(new_state, 26, state[28]);
580+
let new_state = set(new_state, 27, state[21]);
581+
let new_state = set(new_state, 28, state[25]);
582+
let new_state = set(new_state, 29, state[30]);
583+
let new_state = set(new_state, 30, state[31]);
584+
let new_state = set(new_state, 31, state[24]);
585+
let state = new_state;
452586

453587
-- Round 5
454588
let state = blake3_compress_inner_j(state);
455-
let state = blake3_compress_inner_perm(state);
589+
let new_state = set(state, 16, state[18]);
590+
let new_state = set(new_state, 17, state[22]);
591+
let new_state = set(new_state, 18, state[19]);
592+
let new_state = set(new_state, 19, state[26]);
593+
let new_state = set(new_state, 20, state[23]);
594+
let new_state = set(new_state, 21, state[16]);
595+
let new_state = set(new_state, 22, state[20]);
596+
let new_state = set(new_state, 23, state[29]);
597+
let new_state = set(new_state, 24, state[17]);
598+
let new_state = set(new_state, 25, state[27]);
599+
let new_state = set(new_state, 26, state[28]);
600+
let new_state = set(new_state, 27, state[21]);
601+
let new_state = set(new_state, 28, state[25]);
602+
let new_state = set(new_state, 29, state[30]);
603+
let new_state = set(new_state, 30, state[31]);
604+
let new_state = set(new_state, 31, state[24]);
605+
let state = new_state;
456606

457607
-- Round 6
458608
let state = blake3_compress_inner_j(state);
459609

460610
[
461-
u32_xor(state[0], state[8]),
462-
u32_xor(state[1], state[9]),
463-
u32_xor(state[2], state[10]),
464-
u32_xor(state[3], state[11]),
465-
u32_xor(state[4], state[12]),
466-
u32_xor(state[5], state[13]),
467-
u32_xor(state[6], state[14]),
468-
u32_xor(state[7], state[15])
611+
[u8_xor(state[0][0], state[8][0]), u8_xor(state[0][1], state[8][1]), u8_xor(state[0][2], state[8][2]), u8_xor(state[0][3], state[8][3])],
612+
[u8_xor(state[1][0], state[9][0]), u8_xor(state[1][1], state[9][1]), u8_xor(state[1][2], state[9][2]), u8_xor(state[1][3], state[9][3])],
613+
[u8_xor(state[2][0], state[10][0]), u8_xor(state[2][1], state[10][1]), u8_xor(state[2][2], state[10][2]), u8_xor(state[2][3], state[10][3])],
614+
[u8_xor(state[3][0], state[11][0]), u8_xor(state[3][1], state[11][1]), u8_xor(state[3][2], state[11][2]), u8_xor(state[3][3], state[11][3])],
615+
[u8_xor(state[4][0], state[12][0]), u8_xor(state[4][1], state[12][1]), u8_xor(state[4][2], state[12][2]), u8_xor(state[4][3], state[12][3])],
616+
[u8_xor(state[5][0], state[13][0]), u8_xor(state[5][1], state[13][1]), u8_xor(state[5][2], state[13][2]), u8_xor(state[5][3], state[13][3])],
617+
[u8_xor(state[6][0], state[14][0]), u8_xor(state[6][1], state[14][1]), u8_xor(state[6][2], state[14][2]), u8_xor(state[6][3], state[14][3])],
618+
[u8_xor(state[7][0], state[15][0]), u8_xor(state[7][1], state[15][1]), u8_xor(state[7][2], state[15][2]), u8_xor(state[7][3], state[15][3])]
469619
]
470620
}
471621

0 commit comments

Comments
 (0)