Skip to content

Commit 7bdeed9

Browse files
U8 type (#426)
* `U8` type. Type-safe `u8` operations
* `U8` pattern matching
* `U8` in the IxVM
* `Sha256` fix
1 parent 7e6a02d commit 7bdeed9

38 files changed

Lines changed: 1366 additions & 1059 deletions

Ix/Aiur/Compiler/Check.lean

Lines changed: 63 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ inductive CheckError
4848
| unconstrainedConstructor : Global → CheckError
4949
| infiniteType : Nat → Typ → CheckError
5050
| unresolvedMVar : Nat → CheckError
51+
| u8LitOutOfRange : Nat → CheckError
5152
deriving Repr
5253

5354
instance : ToString CheckError where
@@ -59,6 +60,7 @@ open Source
5960
def Typ.instantiate (subst : Global → Option Typ) : Typ → Typ
6061
| .unit => .unit
6162
| .field => .field
63+
| .u8 => .u8
6264
| .tuple ts => .tuple (ts.attach.map (fun ⟨t, _⟩ => Typ.instantiate subst t))
6365
| .array t n => .array (Typ.instantiate subst t) n
6466
| .pointer t => .pointer (Typ.instantiate subst t)
@@ -83,6 +85,7 @@ def expandTypeMBound : Nat → Std.HashSet Global → Array TypeAlias →
8385
Typ → StateT (Std.HashMap Global Typ) (Except CheckError) Typ
8486
| _, _, _, .unit => pure .unit
8587
| _, _, _, .field => pure .field
88+
| _, _, _, .u8 => pure .u8
8689
| bound, visited, tops, .pointer t => do
8790
pure $ .pointer (← expandTypeMBound bound visited tops t)
8891
| bound, visited, tops, .function inputs output => do
@@ -307,7 +310,7 @@ def unifyTypBound : Nat → Typ → Typ → CheckM Bool
307310
if a == b then pure true else do bindMVar a (.mvar b); pure true
308311
| .mvar a, _ => do bindMVar a t2; pure true
309312
| _, .mvar b => do bindMVar b t1; pure true
310-
| .unit, .unit | .field, .field => pure true
313+
| .unit, .unit | .field, .field | .u8, .u8 => pure true
311314
| .tuple ts1, .tuple ts2 =>
312315
if ts1.size != ts2.size then pure false else
313316
ts1.zip ts2 |>.allM fun (x, y) => unifyTypBound bound x y
@@ -330,6 +333,7 @@ available proxy for `sizeOf` in `unifyTyp`'s bound. -/
330333
def Typ.nodeCount : Typ → Nat
331334
| .unit => 1
332335
| .field => 1
336+
| .u8 => 1
333337
| .mvar _ => 1
334338
| .ref _ => 1
335339
| .pointer t => 1 + Typ.nodeCount t
@@ -445,8 +449,10 @@ def checkPatternAux (pat : Pattern) (typ : Typ) : CheckM (List (Local × Typ)) :
445449
| .var var => pure [(var, typ)]
446450
| .wildcard => pure []
447451
| .field _ =>
452+
-- A numeric-literal pattern matches both `field` and `u8` scrutinees: a
453+
-- `u8` is a field element, so comparing it against a byte constant is sound.
448454
match typ with
449-
| .field => pure []
455+
| .field | .u8 => pure []
450456
| _ => throw $ .incompatiblePattern pat typ
451457
| .tuple pats =>
452458
match typ with
@@ -711,54 +717,73 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with
711717
let b' ← checkNoEscape b .field
712718
pure (Typed.Term.mul .field false a' b')
713719
| .u8ShiftLeft a => do
714-
let a' ← checkNoEscape a .field
715-
pure (Typed.Term.u8ShiftLeft .field false a')
720+
let a' ← checkNoEscape a .u8
721+
pure (Typed.Term.u8ShiftLeft .u8 false a')
716722
| .u8ShiftRight a => do
717-
let a' ← checkNoEscape a .field
718-
pure (Typed.Term.u8ShiftRight .field false a')
723+
let a' ← checkNoEscape a .u8
724+
pure (Typed.Term.u8ShiftRight .u8 false a')
719725
| .u8BitDecomposition a => do
720-
let a' ← checkNoEscape a .field
726+
-- Bits are 0/1 values, kept as plain `field`.
727+
let a' ← checkNoEscape a .u8
721728
pure (Typed.Term.u8BitDecomposition (.array .field 8) false a')
722729
| .u8Xor a b => do
723-
let a' ← checkNoEscape a .field
724-
let b' ← checkNoEscape b .field
725-
pure (Typed.Term.u8Xor .field false a' b')
730+
let a' ← checkNoEscape a .u8
731+
let b' ← checkNoEscape b .u8
732+
pure (Typed.Term.u8Xor .u8 false a' b')
726733
| .u8And a b => do
727-
let a' ← checkNoEscape a .field
728-
let b' ← checkNoEscape b .field
729-
pure (Typed.Term.u8And .field false a' b')
734+
let a' ← checkNoEscape a .u8
735+
let b' ← checkNoEscape b .u8
736+
pure (Typed.Term.u8And .u8 false a' b')
730737
| .u8Or a b => do
731-
let a' ← checkNoEscape a .field
732-
let b' ← checkNoEscape b .field
733-
pure (Typed.Term.u8Or .field false a' b')
738+
let a' ← checkNoEscape a .u8
739+
let b' ← checkNoEscape b .u8
740+
pure (Typed.Term.u8Or .u8 false a' b')
734741
| .u8Add a b => do
735-
let a' ← checkNoEscape a .field
736-
let b' ← checkNoEscape b .field
737-
pure (Typed.Term.u8Add (.tuple #[.field, .field]) false a' b')
742+
-- Low byte and the 0/1 carry are both `u8` (the carry is provably in range:
743+
-- the add lookup forces the inputs to be bytes, so `carry ∈ {0, 1}`).
744+
let a' ← checkNoEscape a .u8
745+
let b' ← checkNoEscape b .u8
746+
pure (Typed.Term.u8Add (.tuple #[.u8, .u8]) false a' b')
738747
| .u8Mul a b => do
739-
let a' ← checkNoEscape a .field
740-
let b' ← checkNoEscape b .field
741-
pure (Typed.Term.u8Mul (.tuple #[.field, .field]) false a' b')
748+
-- Both low and high bytes are `u8`.
749+
let a' ← checkNoEscape a .u8
750+
let b' ← checkNoEscape b .u8
751+
pure (Typed.Term.u8Mul (.tuple #[.u8, .u8]) false a' b')
742752
| .u8ChainRotr7 a b => do
743-
let a' ← checkNoEscape a .field
744-
let b' ← checkNoEscape b .field
745-
pure (Typed.Term.u8ChainRotr7 (.tuple #[.field, .field, .field]) false a' b')
753+
let a' ← checkNoEscape a .u8
754+
let b' ← checkNoEscape b .u8
755+
pure (Typed.Term.u8ChainRotr7 (.tuple #[.u8, .u8, .u8]) false a' b')
746756
| .u8ChainRotr4 a b => do
747-
let a' ← checkNoEscape a .field
748-
let b' ← checkNoEscape b .field
749-
pure (Typed.Term.u8ChainRotr4 (.tuple #[.field, .field, .field]) false a' b')
757+
let a' ← checkNoEscape a .u8
758+
let b' ← checkNoEscape b .u8
759+
pure (Typed.Term.u8ChainRotr4 (.tuple #[.u8, .u8, .u8]) false a' b')
750760
| .u8Sub a b => do
751-
let a' ← checkNoEscape a .field
752-
let b' ← checkNoEscape b .field
753-
pure (Typed.Term.u8Sub (.tuple #[.field, .field]) false a' b')
761+
-- Low byte and the 0/1 borrow are both `u8` (same range argument as add).
762+
let a' ← checkNoEscape a .u8
763+
let b' ← checkNoEscape b .u8
764+
pure (Typed.Term.u8Sub (.tuple #[.u8, .u8]) false a' b')
754765
| .u8LessThan a b => do
755-
let a' ← checkNoEscape a .field
756-
let b' ← checkNoEscape b .field
766+
-- Result is a 0/1 flag (`field`).
767+
let a' ← checkNoEscape a .u8
768+
let b' ← checkNoEscape b .u8
757769
pure (Typed.Term.u8LessThan .field false a' b')
758770
| .u32LessThan a b => do
759771
let a' ← checkNoEscape a .field
760772
let b' ← checkNoEscape b .field
761773
pure (Typed.Term.u32LessThan .field false a' b')
774+
| .u8Lit n => do
775+
if n ≥ 256 then throw (.u8LitOutOfRange n)
776+
pure (Typed.Term.field .u8 false (G.ofNat n))
777+
| .u8RangeCheck a b => do
778+
let a' ← checkNoEscape a .field
779+
let b' ← checkNoEscape b .field
780+
pure (Typed.Term.u8RangeCheck (.tuple #[.u8, .u8]) false a' b')
781+
| .toField a => do
782+
let a' ← checkNoEscape a .u8
783+
pure (Typed.Term.toField .field false a')
784+
| .u8FromFieldUnsafe a => do
785+
let a' ← checkNoEscape a .field
786+
pure (Typed.Term.u8FromFieldUnsafe .u8 false a')
762787
| .ioGetInfo key => do
763788
let key' ← inferNoEscape key
764789
match ← walkTyp key'.typ with
@@ -922,6 +947,11 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with
922947
pure (.u8LessThan (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
923948
| .u32LessThan τ e a b => do
924949
pure (.u32LessThan (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
950+
| .u8RangeCheck τ e a b => do
951+
pure (.u8RangeCheck (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
952+
| .toField τ e a => do pure (.toField (← zonkTyp τ) e (← zonkTypedTerm a))
953+
| .u8FromFieldUnsafe τ e a => do
954+
pure (.u8FromFieldUnsafe (← zonkTyp τ) e (← zonkTypedTerm a))
925955
| .debug τ e label t r => do
926956
let t' ← match t with
927957
| none => pure none

Ix/Aiur/Compiler/Concretize.lean

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,9 @@ def typToConcrete (mono : Std.HashMap (Global × Array Typ) Global) :
5050
Typ → Except ConcretizeError Concrete.Typ
5151
| .unit => pure .unit
5252
| .field => pure .field
53+
-- `u8` is erased here: same representation as `field`, distinction no longer
54+
-- needed past type-checking.
55+
| .u8 => pure .field
5356
| .tuple ts => do
5457
pure (.tuple (← ts.attach.mapM fun ⟨t, _⟩ => typToConcrete mono t))
5558
| .array t n => do pure (.array (← typToConcrete mono t) n)
@@ -77,6 +80,7 @@ termination_by t => sizeOf t
7780

7881
def Typ.toFlatName : Typ → String
7982
| .field => "G"
83+
| .u8 => "U8"
8084
| .unit => "Unit"
8185
| Typ.ref g => g.toName.toString (escape := false)
8286
| .pointer t => "Ptr_" ++ t.toFlatName
@@ -93,6 +97,7 @@ decreasing_by all_goals first | decreasing_tactic | grind
9397

9498
def Typ.appendNameLimbs (g : Global) : Typ → Global
9599
| .field => g.pushNamespace "G"
100+
| .u8 => g.pushNamespace "U8"
96101
| .unit => g.pushNamespace "Unit"
97102
| Typ.ref g' =>
98103
let rec pushAll (g : Global) : Lean.Name → Global
@@ -349,6 +354,12 @@ def termToConcrete
349354
| .u32LessThan τ e a b => do
350355
pure (.u32LessThan (← typToConcrete mono τ) e
351356
(← termToConcrete mono a) (← termToConcrete mono b))
357+
| .u8RangeCheck τ e a b => do
358+
pure (.u8RangeCheck (← typToConcrete mono τ) e
359+
(← termToConcrete mono a) (← termToConcrete mono b))
360+
-- `toField` / `u8FromFieldUnsafe` are erased coercions: `u8` and `field`
361+
-- share a representation, so we drop the wrapper and keep the inner term.
362+
| .toField _ _ a | .u8FromFieldUnsafe _ _ a => termToConcrete mono a
352363
| .debug τ e l t r => do
353364
-- Inline the Option.mapM case-split so termination sees the sub-Term.
354365
let t' ← match t with
@@ -546,6 +557,12 @@ def rewriteTypedTerm (decls : Typed.Decls)
546557
(rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b)
547558
| .u32LessThan τ e a b => .u32LessThan (rewriteTyp subst mono τ) e
548559
(rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b)
560+
| .u8RangeCheck τ e a b => .u8RangeCheck (rewriteTyp subst mono τ) e
561+
(rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b)
562+
| .toField τ e a => .toField (rewriteTyp subst mono τ) e
563+
(rewriteTypedTerm decls subst mono a)
564+
| .u8FromFieldUnsafe τ e a => .u8FromFieldUnsafe (rewriteTyp subst mono τ) e
565+
(rewriteTypedTerm decls subst mono a)
549566
| .debug τ e l t r =>
550567
let t' := match t with
551568
| none => none
@@ -613,11 +630,12 @@ def collectInTypedTerm (seen : Std.HashSet (Global × Array Typ)) :
613630
args.attach.foldl (fun s ⟨a, _⟩ => collectInTypedTerm s a) seen
614631
| .add τ _ a b | .sub τ _ a b | .mul τ _ a b
615632
| .u8Xor τ _ a b | .u8Add τ _ a b | .u8Mul τ _ a b | .u8Sub τ _ a b
616-
| .u8ChainRotr7 τ _ a b | .u8ChainRotr4 τ _ a b
633+
| .u8ChainRotr7 τ _ a b | .u8ChainRotr4 τ _ a b | .u8RangeCheck τ _ a b
617634
| .u8And τ _ a b | .u8Or τ _ a b
618635
| .u8LessThan τ _ a b | .u32LessThan τ _ a b =>
619636
collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) b
620-
| .eqZero τ _ a | .store τ _ a | .load τ _ a | .ptrVal τ _ a
637+
| .eqZero τ _ a | .store τ _ a | .load τ _ a | .ptrVal τ _ a | .toField τ _ a
638+
| .u8FromFieldUnsafe τ _ a
621639
| .u8BitDecomposition τ _ a | .u8ShiftLeft τ _ a | .u8ShiftRight τ _ a
622640
| .ioGetInfo τ _ a => collectInTypedTerm (collectInTyp seen τ) a
623641
| .proj τ _ a _ | .get τ _ a _ | .slice τ _ a _ _ =>
@@ -676,11 +694,12 @@ def collectCalls (decls : Typed.Decls)
676694
bs.attach.foldl (fun s ⟨(_, b), _⟩ => collectCalls decls s b) seen
677695
| .add _ _ a b | .sub _ _ a b | .mul _ _ a b
678696
| .u8Xor _ _ a b | .u8Add _ _ a b | .u8Mul _ _ a b | .u8Sub _ _ a b
679-
| .u8ChainRotr7 _ _ a b | .u8ChainRotr4 _ _ a b
697+
| .u8ChainRotr7 _ _ a b | .u8ChainRotr4 _ _ a b | .u8RangeCheck _ _ a b
680698
| .u8And _ _ a b | .u8Or _ _ a b
681699
| .u8LessThan _ _ a b | .u32LessThan _ _ a b =>
682700
collectCalls decls (collectCalls decls seen a) b
683-
| .eqZero _ _ a | .store _ _ a | .load _ _ a | .ptrVal _ _ a
701+
| .eqZero _ _ a | .store _ _ a | .load _ _ a | .ptrVal _ _ a | .toField _ _ a
702+
| .u8FromFieldUnsafe _ _ a
684703
| .u8BitDecomposition _ _ a | .u8ShiftLeft _ _ a | .u8ShiftRight _ _ a
685704
| .ioGetInfo _ _ a => collectCalls decls seen a
686705
| .proj _ _ a _ | .get _ _ a _ | .slice _ _ a _ _ => collectCalls decls seen a
@@ -782,6 +801,11 @@ def substInTypedTerm (subst : Global → Option Typ) : Typed.Term → Typed.Term
782801
(substInTypedTerm subst a) (substInTypedTerm subst b)
783802
| .u32LessThan τ e a b => .u32LessThan (Typ.instantiate subst τ) e
784803
(substInTypedTerm subst a) (substInTypedTerm subst b)
804+
| .u8RangeCheck τ e a b => .u8RangeCheck (Typ.instantiate subst τ) e
805+
(substInTypedTerm subst a) (substInTypedTerm subst b)
806+
| .toField τ e a => .toField (Typ.instantiate subst τ) e (substInTypedTerm subst a)
807+
| .u8FromFieldUnsafe τ e a =>
808+
.u8FromFieldUnsafe (Typ.instantiate subst τ) e (substInTypedTerm subst a)
785809
| .debug τ e l t r =>
786810
let t' := match t with
787811
| none => none

Ix/Aiur/Compiler/Layout.lean

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,8 @@ def opLayout : Bytecode.Op → LayoutM Unit
198198
| .u8ChainRotr7 .. | .u8ChainRotr4 .. => do pushDegrees #[1, 1, 1]; bumpAuxiliaries 3; bumpLookups
199199
| .u8LessThan .. => do pushDegree 1; bumpAuxiliaries; bumpLookups
200200
| .u32LessThan .. => do pushDegree 1; bumpAuxiliaries 12; bumpLookups 6
201+
-- Pure range-check lookup: no output columns/degrees, just one lookup.
202+
| .u8RangeCheck .. => bumpLookups
201203
| .debug .. => pure ()
202204

203205
/-- Termination helper for blockLayout's Block/Ctrl traversal. -/

Ix/Aiur/Compiler/Lower.lean

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -299,6 +299,13 @@ def toIndex
299299
| .u8Or _ _ i j => do let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j; pushOp (.u8Or i j)
300300
| .u8LessThan _ _ i j => do let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j; pushOp (.u8LessThan i j)
301301
| .u32LessThan _ _ i j => do let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j; pushOp (.u32LessThan i j)
302+
| .u8RangeCheck _ _ i j => do
303+
-- Side-effecting lookup; the two `u8` outputs alias the inputs, so no new
304+
-- value slots are allocated (cf. `.assertEq`).
305+
let i ← expectIdx layoutMap bindings i
306+
let j ← expectIdx layoutMap bindings j
307+
modify fun stt => { stt with ops := stt.ops.push (.u8RangeCheck i j) }
308+
pure #[i, j]
302309
| .debug _ _ label term ret => do
303310
let term ← match term with
304311
| none => pure none

Ix/Aiur/Compiler/Match.lean

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -388,6 +388,9 @@ def typedToSimple : Term → Simple.Term
388388
| .u8Or τ e a b => .u8Or τ e (typedToSimple a) (typedToSimple b)
389389
| .u8LessThan τ e a b => .u8LessThan τ e (typedToSimple a) (typedToSimple b)
390390
| .u32LessThan τ e a b => .u32LessThan τ e (typedToSimple a) (typedToSimple b)
391+
| .u8RangeCheck τ e a b => .u8RangeCheck τ e (typedToSimple a) (typedToSimple b)
392+
| .toField τ e a => .toField τ e (typedToSimple a)
393+
| .u8FromFieldUnsafe τ e a => .u8FromFieldUnsafe τ e (typedToSimple a)
391394
| .debug τ e l t r =>
392395
let t' := match t with | none => none | some sub => some (typedToSimple sub)
393396
.debug τ e l t' (typedToSimple r)

Ix/Aiur/Compiler/Simple.lean

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,16 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term
108108
let a' ← simplifyTypedTerm decls a
109109
let b' ← simplifyTypedTerm decls b
110110
pure (.u32LessThan τ e a' b')
111+
| .u8RangeCheck τ e a b => do
112+
let a' ← simplifyTypedTerm decls a
113+
let b' ← simplifyTypedTerm decls b
114+
pure (.u8RangeCheck τ e a' b')
115+
| .toField τ e a => do
116+
let a' ← simplifyTypedTerm decls a
117+
pure (.toField τ e a')
118+
| .u8FromFieldUnsafe τ e a => do
119+
let a' ← simplifyTypedTerm decls a
120+
pure (.u8FromFieldUnsafe τ e a')
111121
| t => pure t
112122
termination_by t => sizeOf t
113123
decreasing_by

Ix/Aiur/Interpret.lean

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -400,6 +400,15 @@ partial def interp (decls : Decls) (bindings : Bindings) : Term → InterpM Valu
400400
| .field a, .field b =>
401401
return .field (if a.val.toUInt32 < b.val.toUInt32 then 1 else 0)
402402
| _, _ => throwErr "u32LessThan: expected field values"
403+
| .u8RangeCheck t1 t2 => do
404+
match (← interp decls bindings t1), (← interp decls bindings t2) with
405+
| .field a, .field b =>
406+
if a.val < 256 && b.val < 256 then return .tuple #[.field a, .field b]
407+
else throwErr "u8RangeCheck: value out of range [0, 256)"
408+
| _, _ => throwErr "u8RangeCheck: expected field values"
409+
-- `toField` / `u8FromFieldUnsafe` are erased coercions: value unchanged.
410+
| .toField t | .u8FromFieldUnsafe t => interp decls bindings t
411+
| .u8Lit n => return .field (G.ofNat n)
403412
| .debug label optT cont => do
404413
match optT with
405414
| none => dbg_trace s!"{label}"

0 commit comments

Comments
 (0)