@@ -31,8 +31,6 @@ namespace Strata.Laurel
3131open Strata
3232
3333abbrev ConstrainedTypeMap := Std.HashMap String ConstrainedType
34- /-- Map from variable name to its constrained HighType (e.g. UserDefined "nat") -/
35- abbrev PredVarMap := Std.HashMap String HighType
3634
3735def buildConstrainedTypeMap (types : List TypeDefinition) : ConstrainedTypeMap :=
3836 types.foldl (init := {}) fun m td =>
@@ -52,20 +50,32 @@ def resolveType (ptMap : ConstrainedTypeMap) (ty : HighTypeMd) : HighTypeMd :=
5250def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool :=
5351 match ty with | .UserDefined name => ptMap.contains name.text | _ => false
5452
55- /-- Build a call to the constraint function for a constrained type, or `none` if not constrained -/
56- def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType)
57- (varName : Identifier) (src : Option FileRange := none) : Option StmtExprMd :=
53+ /-- Build a call to the constraint function for a constrained type, asserting
54+ the constraint on the read-back expression `ref`. Returns `none` if `ty` is
55+ not a constrained type.
56+
57+ `ref` is the expression whose value is checked (e.g. a local read
58+ `x` or a field read `c#count`), allowing this to serve every assignment
59+ target kind uniformly. -/
60+ def constraintCallForExpr (ptMap : ConstrainedTypeMap) (ty : HighType)
61+ (ref : StmtExprMd) (src : Option FileRange := none) : Option StmtExprMd :=
5862 match ty with
5963 | .UserDefined name => if ptMap.contains name.text then
60- some ⟨.StaticCall (mkId s! "{ name.text} $constraint" ) [⟨.Var (.Local varName), src⟩ ], src⟩
64+ some ⟨.StaticCall (mkId s! "{ name.text} $constraint" ) [ref ], src⟩
6165 else none
6266 | _ => none
6367
68+ /-- Build a call to the constraint function for a constrained type, checking a
69+ local variable read, or `none` if not constrained. -/
70+ def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType)
71+ (varName : Identifier) (src : Option FileRange := none) : Option StmtExprMd :=
72+ constraintCallForExpr ptMap ty ⟨.Var (.Local varName), src⟩ src
73+
6474/-- Generate a constraint function for a constrained type.
6575 For nested types, the function calls the parent's constraint function. -/
6676def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure :=
6777 let baseType := resolveType ptMap ct.base
68- let bodyExpr := match ct.base.val with
78+ let bodyExpr: StmtExprMd := match ct.base.val with
6979 | .UserDefined parent =>
7080 if ptMap.contains parent.text then
7181 let paramId := { ct.valueName with uniqueId := none }
@@ -79,15 +89,11 @@ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Proce
7989 { name := mkId s! "{ ct.name.text} $constraint"
8090 inputs := [{ name := ct.valueName, type := baseType }]
8191 outputs := [{ name := mkId "result" , type := { val := .TBool, source := none } }]
82- body := .Transparent { val := .Block [ bodyExpr] none , source := none }
92+ body := .Transparent { val := .Return bodyExpr, source := none }
8393 isFunctional := true
8494 decreases := none
8595 preconditions := [] }
8696
87- private def wrap (stmts : List StmtExprMd) (src : Option FileRange)
88- : StmtExprMd :=
89- match stmts with | [s] => s | ss => ⟨.Block ss none, src⟩
90-
9197def resolveVariable (ptMap : ConstrainedTypeMap) (v : VariableMd) : VariableMd :=
9298 match v.val with
9399 | .Declare param => ⟨.Declare { param with type := resolveType ptMap param.type }, v.source⟩
@@ -118,88 +124,60 @@ def resolveExprNode (ptMap : ConstrainedTypeMap) (expr : StmtExprMd) : StmtExprM
118124 | .IsType t ty => ⟨.IsType t (resolveType ptMap ty), source⟩
119125 | _ => expr
120126
121- abbrev ElimM := StateM PredVarMap
122-
123- private def inScope (action : ElimM α) : ElimM α := do
124- let saved ← get
125- let result ← action
126- set saved
127- return result
128-
129- def elimStmt (ptMap : ConstrainedTypeMap)
130- (stmt : StmtExprMd) : ElimM (List StmtExprMd) := do
131- let source := stmt.source
132-
133- match _h : stmt.val with
127+ /-- Per-node constrained-type elimination, applied bottom-up (with flattening)
128+ by `mapStmtExprFlattenM`. `resultUsed` is `true` when the node occupies a
129+ value position.
130+
131+ - Uninitialized constrained declaration `var x: T;` → assume its constraint.
132+ - Assignment to constrained target(s) → emit the assignment followed by an
133+ `assert T$constraint(<read-back>)` per constrained target. The constraint
134+ is checked on a *read-back* of the target rather than on the RHS, so the
135+ RHS is evaluated exactly once. In value position the read-back is also
136+ appended as the final statement, so the resulting value-block evaluates to
137+ the assigned value (this covers expression-position assignments such as
138+ `y := (x := -1) + 1`); in statement position it is omitted.
139+ - All other nodes are returned unchanged; the traversal handles recursion. -/
140+ def elimNode (ptMap : ConstrainedTypeMap) (model : SemanticModel)
141+ (resultUsed : Bool) (node : StmtExprMd) : List StmtExprMd :=
142+ let source := node.source
143+ match node.val with
134144 | .Var (.Declare param) =>
135- let callOpt := constraintCallFor ptMap param.type.val param.name (src := source)
136- if callOpt.isSome then modify fun pv => pv.insert param.name.text param.type.val
137- let check := match callOpt with
138- | some c => [⟨.Assume c, source⟩]
139- | none => []
140- pure ([stmt] ++ check)
141-
145+ let check := (constraintCallFor ptMap param.type.val param.name (src := source)).toList.map
146+ fun c => ⟨.Assume c, source⟩
147+ [node] ++ check
142148 | .Assign targets _value =>
143- -- Handle Declare targets for constrained type elimination
144- let declareChecks ← targets.foldlM (init := ([] : List StmtExprMd)) fun acc target =>
145- match target.val with
146- | .Declare param => do
147- let callOpt := constraintCallFor ptMap param.type.val param.name (src := source)
148- if callOpt.isSome then modify fun pv => pv.insert param.name.text param.type.val
149- pure (acc ++ callOpt.toList.map fun c => ⟨.Assert { condition := c }, source⟩)
150- | .Local name => do
151- match (← get).get? name.text with
152- | some ty =>
153- let assert := (constraintCallFor ptMap ty name (src := source)).toList.map
154- fun c => ⟨.Assert { condition := c }, source⟩
155- pure (acc ++ assert)
156- | none => pure acc
157- | _ => pure acc
158- pure ([stmt] ++ declareChecks)
159-
160- | .Block stmts sep =>
161- let stmtss ← inScope (stmts.mapM (elimStmt ptMap))
162- pure [⟨.Block stmtss.flatten sep, source⟩]
163-
164- | .IfThenElse cond thenBr (some elseBr) =>
165- let thenSs ← inScope (elimStmt ptMap thenBr)
166- let elseSs ← inScope (elimStmt ptMap elseBr)
167- pure [⟨.IfThenElse cond (wrap thenSs source) (some (wrap elseSs source)), source⟩]
168- | .IfThenElse cond thenBr none =>
169- let thenSs ← inScope (elimStmt ptMap thenBr)
170- pure [⟨.IfThenElse cond (wrap thenSs source) none, source⟩]
171-
172- | .While cond inv dec body =>
173- let bodySs ← inScope (elimStmt ptMap body)
174- pure [⟨.While cond inv dec (wrap bodySs source), source⟩]
175-
176- | _ => pure [stmt]
177- termination_by sizeOf stmt
178- decreasing_by
179- all_goals simp_wf
180- all_goals (try have := AstNode.sizeOf_val_lt stmt)
181- all_goals (try term_by_mem)
182- all_goals omega
183-
184- def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure :=
149+ let asserts: List StmtExprMd := targets.filterMap (fun target =>
150+ let ref : StmtExprMd := VariableMd.toReadbackExpr target
151+ let ty : HighType := (computeExprType model ref).val
152+ (constraintCallForExpr ptMap ty ref (src := source)).map (⟨.Assert { condition := · }, source⟩))
153+ let suffix := match targets with
154+ | [single] => if resultUsed then [VariableMd.toReadbackExpr single] else []
155+ | _ => []
156+ [node] ++ asserts ++ suffix
157+ | _ => [node]
158+
159+ /-- Apply `elimNode` across a body via the flattening, `resultUsed`-aware
160+ traversal. A procedure body is a statement, so the top-level `resultUsed`
161+ is `false`. -/
162+ def elimStmts (ptMap : ConstrainedTypeMap) (model : SemanticModel) (body : StmtExprMd) : StmtExprMd :=
163+ mapStmtExprFlattenM (m := Id) (fun _ _ => none) (elimNode ptMap model) false body
164+
165+ def elimProc (ptMap : ConstrainedTypeMap) (model : SemanticModel) (proc : Procedure) : Procedure :=
185166 let inputRequires : List Condition := proc.inputs.filterMap fun p =>
186167 (constraintCallFor ptMap p.type.val p.name (src := p.type.source)).map
187168 fun c => { condition := c }
188169 let outputEnsures : List Condition := if proc.isFunctional then [] else proc.outputs.filterMap fun p =>
189170 (constraintCallFor ptMap p.type.val p.name (src := p.type.source)).map
190171 fun c => { condition := ⟨c.val, p.type.source⟩ }
191- let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p =>
192- if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s
193172 let body' := match proc.body with
194173 | .Transparent bodyExpr =>
195- let (stmts, _) := (elimStmt ptMap bodyExpr).run initVars
196- let body := wrap stmts bodyExpr.source
174+ let body := elimStmts ptMap model bodyExpr
197175 if outputEnsures.isEmpty then .Transparent body
198176 else
199177 let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.source⟩ else body
200178 .Opaque outputEnsures (some retBody) []
201179 | .Opaque postconds impl modif =>
202- let impl' := impl.map fun b => wrap ((elimStmt ptMap b).run initVars). 1 b.source
180+ let impl' := impl.map (elimStmts ptMap model)
203181 .Opaque (postconds ++ outputEnsures) impl' modif
204182 | .Abstract postconds => .Abstract (postconds ++ outputEnsures)
205183 | .External => .External
@@ -231,7 +209,20 @@ private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) :
231209 isFunctional := false
232210 decreases := none }
233211
234- public def constrainedTypeElim (_model : SemanticModel) (program : Program)
212+ /-- Eliminate constrained types within a composite type definition: resolve
213+ constrained field types to their base types and run constrained type
214+ elimination on the composite's instance procedures.
215+
216+ This is necessary because `constrainedTypeElim` removes the constrained type
217+ definitions from the program. Any reference to a constrained type left inside
218+ a composite (e.g. a `count: nat` field) would otherwise dangle and fail to
219+ resolve in later passes and the final Core translation. -/
220+ def elimCompositeType (ptMap : ConstrainedTypeMap) (model : SemanticModel) (ct : CompositeType) : CompositeType :=
221+ { ct with
222+ fields := ct.fields.map fun f => { f with type := resolveType ptMap f.type }
223+ instanceProcedures := ct.instanceProcedures.map (elimProc ptMap model) }
224+
225+ public def constrainedTypeElim (model : SemanticModel) (program : Program)
235226 : Program × List DiagnosticModel :=
236227 let ptMap := buildConstrainedTypeMap program.types
237228 if ptMap.isEmpty then (program, []) else
@@ -244,13 +235,16 @@ public def constrainedTypeElim (_model : SemanticModel) (program : Program)
244235 acc.cons (diagnosticFromSource proc.name.source "constrained return types on functions are not yet supported" )
245236 else acc
246237 ({ program with
247- staticProcedures := constraintFuncs ++ program.staticProcedures.map (elimProc ptMap)
238+ staticProcedures := constraintFuncs ++ program.staticProcedures.map (elimProc ptMap model )
248239 ++ witnessProcedures
249- types := program.types.filter fun | .Constrained _ => false | _ => true },
240+ types := program.types.filterMap fun
241+ | .Constrained _ => none
242+ | .Composite ct => some (.Composite (elimCompositeType ptMap model ct))
243+ | other => some other },
250244 funcDiags)
251245
252246/-- Pipeline pass: constrained type elimination. -/
253- public def constrainedTypeElimPass : LaurelPass where
247+ public def constrainedTypeElimPass : LoweringPass where
254248 name := "ConstrainedTypeElim"
255249 documentation := "Eliminates constrained types by replacing them with their base types and generating constraint-checking functions and witness procedures. Type tests against constrained types are rewritten to call the generated constraint function."
256250 needsResolves := true
0 commit comments