Skip to content

Commit 362fd38

Browse files
Bytes chip optimization and new operations (#421)
* `u8_chain_rotr7` and `u8_chain_rotr4` added * removed `bit_decomposition` from `blake3_g_function` * `u8_add` now outputs a single column the carry can be derived from the expression `(x + y - z)/256` * `u8_sub` now outputs single column * cross test
1 parent b762c15 commit 362fd38

25 files changed

Lines changed: 518 additions & 93 deletions

Ix/Aiur/Compiler/Check.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,14 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with
739739
let a' ← checkNoEscape a .field
740740
let b' ← checkNoEscape b .field
741741
pure (Typed.Term.u8Mul (.tuple #[.field, .field]) false a' b')
742+
| .u8ChainRotr7 a b => do
743+
let a' ← checkNoEscape a .field
744+
let b' ← checkNoEscape b .field
745+
pure (Typed.Term.u8ChainRotr7 (.tuple #[.field, .field, .field]) false a' b')
746+
| .u8ChainRotr4 a b => do
747+
let a' ← checkNoEscape a .field
748+
let b' ← checkNoEscape b .field
749+
pure (Typed.Term.u8ChainRotr4 (.tuple #[.field, .field, .field]) false a' b')
742750
| .u8Sub a b => do
743751
let a' ← checkNoEscape a .field
744752
let b' ← checkNoEscape b .field
@@ -905,6 +913,8 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with
905913
| .u8Xor τ e a b => do pure (.u8Xor (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
906914
| .u8Add τ e a b => do pure (.u8Add (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
907915
| .u8Mul τ e a b => do pure (.u8Mul (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
916+
| .u8ChainRotr7 τ e a b => do pure (.u8ChainRotr7 (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
917+
| .u8ChainRotr4 τ e a b => do pure (.u8ChainRotr4 (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
908918
| .u8Sub τ e a b => do pure (.u8Sub (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
909919
| .u8And τ e a b => do pure (.u8And (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
910920
| .u8Or τ e a b => do pure (.u8Or (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))

Ix/Aiur/Compiler/Concretize.lean

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,12 @@ def termToConcrete
328328
| .u8Mul τ e a b => do
329329
pure (.u8Mul (← typToConcrete mono τ) e
330330
(← termToConcrete mono a) (← termToConcrete mono b))
331+
| .u8ChainRotr7 τ e a b => do
332+
pure (.u8ChainRotr7 (← typToConcrete mono τ) e
333+
(← termToConcrete mono a) (← termToConcrete mono b))
334+
| .u8ChainRotr4 τ e a b => do
335+
pure (.u8ChainRotr4 (← typToConcrete mono τ) e
336+
(← termToConcrete mono a) (← termToConcrete mono b))
331337
| .u8Sub τ e a b => do
332338
pure (.u8Sub (← typToConcrete mono τ) e
333339
(← termToConcrete mono a) (← termToConcrete mono b))
@@ -526,6 +532,10 @@ def rewriteTypedTerm (decls : Typed.Decls)
526532
(rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b)
527533
| .u8Mul τ e a b => .u8Mul (rewriteTyp subst mono τ) e
528534
(rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b)
535+
| .u8ChainRotr7 τ e a b => .u8ChainRotr7 (rewriteTyp subst mono τ) e
536+
(rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b)
537+
| .u8ChainRotr4 τ e a b => .u8ChainRotr4 (rewriteTyp subst mono τ) e
538+
(rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b)
529539
| .u8Sub τ e a b => .u8Sub (rewriteTyp subst mono τ) e
530540
(rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b)
531541
| .u8And τ e a b => .u8And (rewriteTyp subst mono τ) e
@@ -603,6 +613,7 @@ def collectInTypedTerm (seen : Std.HashSet (Global × Array Typ)) :
603613
args.attach.foldl (fun s ⟨a, _⟩ => collectInTypedTerm s a) seen
604614
| .add τ _ a b | .sub τ _ a b | .mul τ _ a b
605615
| .u8Xor τ _ a b | .u8Add τ _ a b | .u8Mul τ _ a b | .u8Sub τ _ a b
616+
| .u8ChainRotr7 τ _ a b | .u8ChainRotr4 τ _ a b
606617
| .u8And τ _ a b | .u8Or τ _ a b
607618
| .u8LessThan τ _ a b | .u32LessThan τ _ a b =>
608619
collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) b
@@ -665,6 +676,7 @@ def collectCalls (decls : Typed.Decls)
665676
bs.attach.foldl (fun s ⟨(_, b), _⟩ => collectCalls decls s b) seen
666677
| .add _ _ a b | .sub _ _ a b | .mul _ _ a b
667678
| .u8Xor _ _ a b | .u8Add _ _ a b | .u8Mul _ _ a b | .u8Sub _ _ a b
679+
| .u8ChainRotr7 _ _ a b | .u8ChainRotr4 _ _ a b
668680
| .u8And _ _ a b | .u8Or _ _ a b
669681
| .u8LessThan _ _ a b | .u32LessThan _ _ a b =>
670682
collectCalls decls (collectCalls decls seen a) b
@@ -756,6 +768,10 @@ def substInTypedTerm (subst : Global → Option Typ) : Typed.Term → Typed.Term
756768
(substInTypedTerm subst a) (substInTypedTerm subst b)
757769
| .u8Mul τ e a b => .u8Mul (Typ.instantiate subst τ) e
758770
(substInTypedTerm subst a) (substInTypedTerm subst b)
771+
| .u8ChainRotr7 τ e a b => .u8ChainRotr7 (Typ.instantiate subst τ) e
772+
(substInTypedTerm subst a) (substInTypedTerm subst b)
773+
| .u8ChainRotr4 τ e a b => .u8ChainRotr4 (Typ.instantiate subst τ) e
774+
(substInTypedTerm subst a) (substInTypedTerm subst b)
759775
| .u8Sub τ e a b => .u8Sub (Typ.instantiate subst τ) e
760776
(substInTypedTerm subst a) (substInTypedTerm subst b)
761777
| .u8And τ e a b => .u8And (Typ.instantiate subst τ) e

Ix/Aiur/Compiler/Layout.lean

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,17 @@ def opLayout : Bytecode.Op → LayoutM Unit
185185
| .u8BitDecomposition _ => do pushDegrees $ .replicate 8 1; bumpAuxiliaries 8; bumpLookups
186186
| .u8ShiftLeft _ | .u8ShiftRight _ | .u8Xor .. | .u8And .. | .u8Or .. => do
187187
pushDegree 1; bumpAuxiliaries; bumpLookups
188-
| .u8Add .. | .u8Mul .. | .u8Sub .. => do pushDegrees #[1, 1]; bumpAuxiliaries 2; bumpLookups
188+
| .u8Add a b | .u8Sub a b => do
189+
-- Low byte `z` is the only auxiliary; the carry/borrow is a compound
190+
-- expression of degree `max(deg a, deg b, 1)` (add: `(a+b-z)/256`,
191+
-- sub: `(z+b-a)/256`).
192+
let aDegree ← getDegree a
193+
let bDegree ← getDegree b
194+
pushDegree 1
195+
pushDegree ((aDegree.max bDegree).max 1)
196+
bumpAuxiliaries; bumpLookups
197+
| .u8Mul .. => do pushDegrees #[1, 1]; bumpAuxiliaries 2; bumpLookups
198+
| .u8ChainRotr7 .. | .u8ChainRotr4 .. => do pushDegrees #[1, 1, 1]; bumpAuxiliaries 3; bumpLookups
189199
| .u8LessThan .. => do pushDegree 1; bumpAuxiliaries; bumpLookups
190200
| .u32LessThan .. => do pushDegree 1; bumpAuxiliaries 12; bumpLookups 6
191201
| .debug .. => pure ()

Ix/Aiur/Compiler/Lower.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,12 @@ def toIndex
286286
| .u8Mul _ _ i j => do
287287
let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j
288288
pushOp (.u8Mul i j) 2
289+
| .u8ChainRotr7 _ _ i j => do
290+
let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j
291+
pushOp (.u8ChainRotr7 i j) 3
292+
| .u8ChainRotr4 _ _ i j => do
293+
let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j
294+
pushOp (.u8ChainRotr4 i j) 3
289295
| .u8Sub _ _ i j => do
290296
let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j
291297
pushOp (.u8Sub i j) 2

Ix/Aiur/Compiler/Match.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,8 @@ def typedToSimple : Term → Simple.Term
381381
| .u8Xor τ e a b => .u8Xor τ e (typedToSimple a) (typedToSimple b)
382382
| .u8Add τ e a b => .u8Add τ e (typedToSimple a) (typedToSimple b)
383383
| .u8Mul τ e a b => .u8Mul τ e (typedToSimple a) (typedToSimple b)
384+
| .u8ChainRotr7 τ e a b => .u8ChainRotr7 τ e (typedToSimple a) (typedToSimple b)
385+
| .u8ChainRotr4 τ e a b => .u8ChainRotr4 τ e (typedToSimple a) (typedToSimple b)
384386
| .u8Sub τ e a b => .u8Sub τ e (typedToSimple a) (typedToSimple b)
385387
| .u8And τ e a b => .u8And τ e (typedToSimple a) (typedToSimple b)
386388
| .u8Or τ e a b => .u8Or τ e (typedToSimple a) (typedToSimple b)

Ix/Aiur/Goldilocks.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ def G.u8Mul (a b : G) : G × G :=
6464
def G.u8Sub (a b : G) : G × G :=
6565
(G.ofNat ((a.n + 256 - b.n) % 256), if a.n < b.n then 1 else 0)
6666

67+
/-- Chainable partial for a right-rotation by 7 bits over little-endian bytes.
68+
Returns `(a>>7 + b<<1, b>>7, a<<1)` (shifts mod 256). -/
69+
def G.u8ChainRotr7 (a b : G) : G × G × G :=
70+
(G.ofNat (a.n / 128 + (b.n * 2) % 256), G.ofNat (b.n / 128), G.ofNat ((a.n * 2) % 256))
71+
72+
/-- Chainable partial for a right-rotation by 4 bits over little-endian bytes.
73+
Returns `(a>>4 + b<<4, b>>4, a<<4)` (shifts mod 256). -/
74+
def G.u8ChainRotr4 (a b : G) : G × G × G :=
75+
(G.ofNat (a.n / 16 + (b.n * 16) % 256), G.ofNat (b.n / 16), G.ofNat ((a.n * 16) % 256))
76+
6777
def G.u8ShiftLeft (a : G) : G := G.ofNat ((a.n * 2) % 256)
6878
def G.u8ShiftRight (a : G) : G := G.ofNat (a.n / 2)
6979

Ix/Aiur/Interpret.lean

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,24 @@ partial def interp (decls : Decls) (bindings : Bindings) : Term → InterpM Valu
358358
let hi : Value := .field (G.ofUInt8 (x / 256).toUInt8)
359359
return .tuple #[lo, hi]
360360
| _, _ => throwErr "u8Mul: expected field values"
361+
| .u8ChainRotr7 t1 t2 => do
362+
match (← interp decls bindings t1), (← interp decls bindings t2) with
363+
| .field a, .field b =>
364+
let i := a.val.toUInt8
365+
let j := b.val.toUInt8
366+
return .tuple #[.field (G.ofUInt8 ((i >>> 7) + (j <<< 1))),
367+
.field (G.ofUInt8 (j >>> 7)),
368+
.field (G.ofUInt8 (i <<< 1))]
369+
| _, _ => throwErr "u8ChainRotr7: expected field values"
370+
| .u8ChainRotr4 t1 t2 => do
371+
match (← interp decls bindings t1), (← interp decls bindings t2) with
372+
| .field a, .field b =>
373+
let i := a.val.toUInt8
374+
let j := b.val.toUInt8
375+
return .tuple #[.field (G.ofUInt8 ((i >>> 4) + (j <<< 4))),
376+
.field (G.ofUInt8 (j >>> 4)),
377+
.field (G.ofUInt8 (i <<< 4))]
378+
| _, _ => throwErr "u8ChainRotr4: expected field values"
361379
| .u8Sub t1 t2 => do
362380
match (← interp decls bindings t1), (← interp decls bindings t2) with
363381
| .field a, .field b =>

Ix/Aiur/Meta.lean

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ syntax "u8_shift_right" "(" aiur_trm ")" : a
174174
syntax "u8_xor" "(" aiur_trm ", " aiur_trm ")" : aiur_trm
175175
syntax "u8_add" "(" aiur_trm ", " aiur_trm ")" : aiur_trm
176176
syntax "u8_mul" "(" aiur_trm ", " aiur_trm ")" : aiur_trm
177+
syntax "u8_chain_rotr7" "(" aiur_trm ", " aiur_trm ")" : aiur_trm
178+
syntax "u8_chain_rotr4" "(" aiur_trm ", " aiur_trm ")" : aiur_trm
177179
syntax "u8_sub" "(" aiur_trm ", " aiur_trm ")" : aiur_trm
178180
syntax "u8_and" "(" aiur_trm ", " aiur_trm ")" : aiur_trm
179181
syntax "u8_or" "(" aiur_trm ", " aiur_trm ")" : aiur_trm
@@ -295,6 +297,10 @@ partial def elabTrm : ElabStxCat `aiur_trm
295297
mkAppM ``Source.Term.u8Add #[← elabTrm i, ← elabTrm j]
296298
| `(aiur_trm| u8_mul($i:aiur_trm, $j:aiur_trm)) => do
297299
mkAppM ``Source.Term.u8Mul #[← elabTrm i, ← elabTrm j]
300+
| `(aiur_trm| u8_chain_rotr7($i:aiur_trm, $j:aiur_trm)) => do
301+
mkAppM ``Source.Term.u8ChainRotr7 #[← elabTrm i, ← elabTrm j]
302+
| `(aiur_trm| u8_chain_rotr4($i:aiur_trm, $j:aiur_trm)) => do
303+
mkAppM ``Source.Term.u8ChainRotr4 #[← elabTrm i, ← elabTrm j]
298304
| `(aiur_trm| u8_sub($i:aiur_trm, $j:aiur_trm)) => do
299305
mkAppM ``Source.Term.u8Sub #[← elabTrm i, ← elabTrm j]
300306
| `(aiur_trm| u8_and($i:aiur_trm, $j:aiur_trm)) => do
@@ -487,6 +493,14 @@ where
487493
let i ← replaceToken old new i
488494
let j ← replaceToken old new j
489495
`(aiur_trm| u8_mul($i, $j))
496+
| `(aiur_trm| u8_chain_rotr7($i:aiur_trm, $j:aiur_trm)) => do
497+
let i ← replaceToken old new i
498+
let j ← replaceToken old new j
499+
`(aiur_trm| u8_chain_rotr7($i, $j))
500+
| `(aiur_trm| u8_chain_rotr4($i:aiur_trm, $j:aiur_trm)) => do
501+
let i ← replaceToken old new i
502+
let j ← replaceToken old new j
503+
`(aiur_trm| u8_chain_rotr4($i, $j))
490504
| `(aiur_trm| u8_sub($i:aiur_trm, $j:aiur_trm)) => do
491505
let i ← replaceToken old new i
492506
let j ← replaceToken old new j

Ix/Aiur/Semantics/BytecodeEval.lean

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,18 @@ def evalOp (t : Bytecode.Toplevel) (fuel : Nat) (op : Op) (st : EvalState) :
252252
let prod := x.val.toUInt8.toNat * y.val.toUInt8.toNat
253253
let st1 := pushMap st (G.ofUInt8 prod.toUInt8)
254254
pure (pushMap st1 (G.ofUInt8 (prod / 256).toUInt8))
255+
| .u8ChainRotr7 a b => do
256+
let x ← readIdx st a; let y ← readIdx st b
257+
let i := x.val.toUInt8; let j := y.val.toUInt8
258+
let st1 := pushMap st (G.ofUInt8 ((i >>> 7) + (j <<< 1)))
259+
let st2 := pushMap st1 (G.ofUInt8 (j >>> 7))
260+
pure (pushMap st2 (G.ofUInt8 (i <<< 1)))
261+
| .u8ChainRotr4 a b => do
262+
let x ← readIdx st a; let y ← readIdx st b
263+
let i := x.val.toUInt8; let j := y.val.toUInt8
264+
let st1 := pushMap st (G.ofUInt8 ((i >>> 4) + (j <<< 4)))
265+
let st2 := pushMap st1 (G.ofUInt8 (j >>> 4))
266+
pure (pushMap st2 (G.ofUInt8 (i <<< 4)))
255267
| .u8Sub a b => do
256268
let x ← readIdx st a; let y ← readIdx st b
257269
let i := x.val.toUInt8; let j := y.val.toUInt8

Ix/Aiur/Semantics/SourceEval.lean

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,24 @@ def interp (decls : Decls) (fuel : Nat) (bindings : Bindings)
390390
.field (G.ofUInt8 (x / 256).toUInt8)])
391391
(interp decls fuel bindings t1 st)
392392
(fun st1 => interp decls fuel bindings t2 st1)
393+
| .u8ChainRotr7 t1 t2 =>
394+
combineFieldsResult
395+
(fun a b =>
396+
let i := a.val.toUInt8; let j := b.val.toUInt8
397+
.tuple #[.field (G.ofUInt8 ((i >>> 7) + (j <<< 1))),
398+
.field (G.ofUInt8 (j >>> 7)),
399+
.field (G.ofUInt8 (i <<< 1))])
400+
(interp decls fuel bindings t1 st)
401+
(fun st1 => interp decls fuel bindings t2 st1)
402+
| .u8ChainRotr4 t1 t2 =>
403+
combineFieldsResult
404+
(fun a b =>
405+
let i := a.val.toUInt8; let j := b.val.toUInt8
406+
.tuple #[.field (G.ofUInt8 ((i >>> 4) + (j <<< 4))),
407+
.field (G.ofUInt8 (j >>> 4)),
408+
.field (G.ofUInt8 (i <<< 4))])
409+
(interp decls fuel bindings t1 st)
410+
(fun st1 => interp decls fuel bindings t2 st1)
393411
| .u8Sub t1 t2 =>
394412
combineFieldsResult
395413
(fun a b =>

0 commit comments

Comments
 (0)