@@ -48,6 +48,7 @@ inductive CheckError
4848 | unconstrainedConstructor : Global → CheckError
4949 | infiniteType : Nat → Typ → CheckError
5050 | unresolvedMVar : Nat → CheckError
51+ | u8LitOutOfRange : Nat → CheckError
5152 deriving Repr
5253
5354instance : ToString CheckError where
@@ -59,6 +60,7 @@ open Source
5960def Typ.instantiate (subst : Global → Option Typ) : Typ → Typ
6061 | .unit => .unit
6162 | .field => .field
63+ | .u8 => .u8
6264 | .tuple ts => .tuple (ts.attach.map (fun ⟨t, _⟩ => Typ.instantiate subst t))
6365 | .array t n => .array (Typ.instantiate subst t) n
6466 | .pointer t => .pointer (Typ.instantiate subst t)
@@ -83,6 +85,7 @@ def expandTypeMBound : Nat → Std.HashSet Global → Array TypeAlias →
8385 Typ → StateT (Std.HashMap Global Typ) (Except CheckError) Typ
8486 | _, _, _, .unit => pure .unit
8587 | _, _, _, .field => pure .field
88+ | _, _, _, .u8 => pure .u8
8689 | bound, visited, tops, .pointer t => do
8790 pure $ .pointer (← expandTypeMBound bound visited tops t)
8891 | bound, visited, tops, .function inputs output => do
@@ -307,7 +310,7 @@ def unifyTypBound : Nat → Typ → Typ → CheckM Bool
307310 if a == b then pure true else do bindMVar a (.mvar b); pure true
308311 | .mvar a, _ => do bindMVar a t2; pure true
309312 | _, .mvar b => do bindMVar b t1; pure true
310- | .unit, .unit | .field, .field => pure true
313+ | .unit, .unit | .field, .field | .u8, .u8 => pure true
311314 | .tuple ts1, .tuple ts2 =>
312315 if ts1.size != ts2.size then pure false else
313316 ts1.zip ts2 |>.allM fun (x, y) => unifyTypBound bound x y
@@ -330,6 +333,7 @@ available proxy for `sizeOf` in `unifyTyp`'s bound. -/
330333def Typ.nodeCount : Typ → Nat
331334 | .unit => 1
332335 | .field => 1
336+ | .u8 => 1
333337 | .mvar _ => 1
334338 | .ref _ => 1
335339 | .pointer t => 1 + Typ.nodeCount t
@@ -445,8 +449,10 @@ def checkPatternAux (pat : Pattern) (typ : Typ) : CheckM (List (Local × Typ)) :
445449 | .var var => pure [(var, typ)]
446450 | .wildcard => pure []
447451 | .field _ =>
452+ -- A numeric-literal pattern matches both `field` and `u8` scrutinees: a
453+ -- `u8` is a field element, so comparing it against a byte constant is sound.
448454 match typ with
449- | .field => pure []
455+ | .field | .u8 => pure []
450456 | _ => throw $ .incompatiblePattern pat typ
451457 | .tuple pats =>
452458 match typ with
@@ -711,54 +717,73 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with
711717 let b' ← checkNoEscape b .field
712718 pure (Typed.Term.mul .field false a' b')
713719 | .u8ShiftLeft a => do
714- let a' ← checkNoEscape a .field
715- pure (Typed.Term.u8ShiftLeft .field false a')
720+ let a' ← checkNoEscape a .u8
721+ pure (Typed.Term.u8ShiftLeft .u8 false a')
716722 | .u8ShiftRight a => do
717- let a' ← checkNoEscape a .field
718- pure (Typed.Term.u8ShiftRight .field false a')
723+ let a' ← checkNoEscape a .u8
724+ pure (Typed.Term.u8ShiftRight .u8 false a')
719725 | .u8BitDecomposition a => do
720- let a' ← checkNoEscape a .field
726+ -- Bits are 0/1 values, kept as plain `field`.
727+ let a' ← checkNoEscape a .u8
721728 pure (Typed.Term.u8BitDecomposition (.array .field 8 ) false a')
722729 | .u8Xor a b => do
723- let a' ← checkNoEscape a .field
724- let b' ← checkNoEscape b .field
725- pure (Typed.Term.u8Xor .field false a' b')
730+ let a' ← checkNoEscape a .u8
731+ let b' ← checkNoEscape b .u8
732+ pure (Typed.Term.u8Xor .u8 false a' b')
726733 | .u8And a b => do
727- let a' ← checkNoEscape a .field
728- let b' ← checkNoEscape b .field
729- pure (Typed.Term.u8And .field false a' b')
734+ let a' ← checkNoEscape a .u8
735+ let b' ← checkNoEscape b .u8
736+ pure (Typed.Term.u8And .u8 false a' b')
730737 | .u8Or a b => do
731- let a' ← checkNoEscape a .field
732- let b' ← checkNoEscape b .field
733- pure (Typed.Term.u8Or .field false a' b')
738+ let a' ← checkNoEscape a .u8
739+ let b' ← checkNoEscape b .u8
740+ pure (Typed.Term.u8Or .u8 false a' b')
734741 | .u8Add a b => do
735- let a' ← checkNoEscape a .field
736- let b' ← checkNoEscape b .field
737- pure (Typed.Term.u8Add (.tuple #[.field, .field]) false a' b')
742+ -- Low byte and the 0/1 carry are both `u8` (the carry is provably in range:
743+ -- the add lookup forces the inputs to be bytes, so `carry ∈ {0, 1}`).
744+ let a' ← checkNoEscape a .u8
745+ let b' ← checkNoEscape b .u8
746+ pure (Typed.Term.u8Add (.tuple #[.u8, .u8]) false a' b')
738747 | .u8Mul a b => do
739- let a' ← checkNoEscape a .field
740- let b' ← checkNoEscape b .field
741- pure (Typed.Term.u8Mul (.tuple #[.field, .field]) false a' b')
748+ -- Both low and high bytes are `u8`.
749+ let a' ← checkNoEscape a .u8
750+ let b' ← checkNoEscape b .u8
751+ pure (Typed.Term.u8Mul (.tuple #[.u8, .u8]) false a' b')
742752 | .u8ChainRotr7 a b => do
743- let a' ← checkNoEscape a .field
744- let b' ← checkNoEscape b .field
745- pure (Typed.Term.u8ChainRotr7 (.tuple #[.field , .field , .field ]) false a' b')
753+ let a' ← checkNoEscape a .u8
754+ let b' ← checkNoEscape b .u8
755+ pure (Typed.Term.u8ChainRotr7 (.tuple #[.u8 , .u8 , .u8 ]) false a' b')
746756 | .u8ChainRotr4 a b => do
747- let a' ← checkNoEscape a .field
748- let b' ← checkNoEscape b .field
749- pure (Typed.Term.u8ChainRotr4 (.tuple #[.field , .field , .field ]) false a' b')
757+ let a' ← checkNoEscape a .u8
758+ let b' ← checkNoEscape b .u8
759+ pure (Typed.Term.u8ChainRotr4 (.tuple #[.u8 , .u8 , .u8 ]) false a' b')
750760 | .u8Sub a b => do
751- let a' ← checkNoEscape a .field
752- let b' ← checkNoEscape b .field
753- pure (Typed.Term.u8Sub (.tuple #[.field, .field]) false a' b')
761+ -- Low byte and the 0/1 borrow are both `u8` (same range argument as add).
762+ let a' ← checkNoEscape a .u8
763+ let b' ← checkNoEscape b .u8
764+ pure (Typed.Term.u8Sub (.tuple #[.u8, .u8]) false a' b')
754765 | .u8LessThan a b => do
755- let a' ← checkNoEscape a .field
756- let b' ← checkNoEscape b .field
766+ -- Result is a 0/1 flag (`field`).
767+ let a' ← checkNoEscape a .u8
768+ let b' ← checkNoEscape b .u8
757769 pure (Typed.Term.u8LessThan .field false a' b')
758770 | .u32LessThan a b => do
759771 let a' ← checkNoEscape a .field
760772 let b' ← checkNoEscape b .field
761773 pure (Typed.Term.u32LessThan .field false a' b')
774+ | .u8Lit n => do
775+ if n ≥ 256 then throw (.u8LitOutOfRange n)
776+ pure (Typed.Term.field .u8 false (G.ofNat n))
777+ | .u8RangeCheck a b => do
778+ let a' ← checkNoEscape a .field
779+ let b' ← checkNoEscape b .field
780+ pure (Typed.Term.u8RangeCheck (.tuple #[.u8, .u8]) false a' b')
781+ | .toField a => do
782+ let a' ← checkNoEscape a .u8
783+ pure (Typed.Term.toField .field false a')
784+ | .u8FromFieldUnsafe a => do
785+ let a' ← checkNoEscape a .field
786+ pure (Typed.Term.u8FromFieldUnsafe .u8 false a')
762787 | .ioGetInfo key => do
763788 let key' ← inferNoEscape key
764789 match ← walkTyp key'.typ with
@@ -922,6 +947,11 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with
922947 pure (.u8LessThan (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
923948 | .u32LessThan τ e a b => do
924949 pure (.u32LessThan (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
950+ | .u8RangeCheck τ e a b => do
951+ pure (.u8RangeCheck (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b))
952+ | .toField τ e a => do pure (.toField (← zonkTyp τ) e (← zonkTypedTerm a))
953+ | .u8FromFieldUnsafe τ e a => do
954+ pure (.u8FromFieldUnsafe (← zonkTyp τ) e (← zonkTypedTerm a))
925955 | .debug τ e label t r => do
926956 let t' ← match t with
927957 | none => pure none
0 commit comments