Skip to content

Commit 5caabd8

Browse files
author
Jules
committed
Laurel: avoid quantified modifies frames when using array theory
Merge reviewed-kbd-will-merge-to-main into PR1374-kiro. Resolution kept the reviewed branch's generic LaurelPass/PassMeta pass framework, but threads the LaurelTranslateOptions argument as the first parameter of each pass's run (options-first), per review preference. Preserved the PR's enumeratedModifiesClauses option and its threading through modifiesClausesTransformPass. Dropped the PR's now-superseded translateInvokeOnAxiom (invokeOn-axiom generation is handled by ContractPass on the reviewed branch). Also documents that enumeratedModifiesClauses has no effect when the procedure's modifies clause contains sets.
2 parents 622f60c + 49c3e1a commit 5caabd8

92 files changed

Lines changed: 2933 additions & 1091 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Strata/Languages/Core/ProgramEval.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ def eval (E : Env) : Except Strata.DiagnosticModel (List Env × Statistics) :=
5454

5555
| .proc proc _md =>
5656
let (E, procStats) := Procedure.eval declsE proc
57+
-- Reset path conditions to the pre-procedure state so a procedure's
58+
-- assumptions don't leak into later ones: a structured `exit` bypasses
59+
-- `Env.merge` and leaves its frames unpopped, which would otherwise be
60+
-- threaded into the next procedure (strata-org/Strata#1390). Deferred
61+
-- obligations and fresh names carry forward.
62+
let E := { E with pathConditions := declsE.pathConditions }
5763
go rest E (stats.merge procStats)
5864

5965
| .func func _ => do

Strata/Languages/Laurel/ConstrainedTypeElim.lean

Lines changed: 77 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ namespace Strata.Laurel
3131
open Strata
3232

3333
abbrev ConstrainedTypeMap := Std.HashMap String ConstrainedType
34-
/-- Map from variable name to its constrained HighType (e.g. UserDefined "nat") -/
35-
abbrev PredVarMap := Std.HashMap String HighType
3634

3735
def buildConstrainedTypeMap (types : List TypeDefinition) : ConstrainedTypeMap :=
3836
types.foldl (init := {}) fun m td =>
@@ -52,20 +50,32 @@ def resolveType (ptMap : ConstrainedTypeMap) (ty : HighTypeMd) : HighTypeMd :=
5250
def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool :=
5351
match ty with | .UserDefined name => ptMap.contains name.text | _ => false
5452

55-
/-- Build a call to the constraint function for a constrained type, or `none` if not constrained -/
56-
def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType)
57-
(varName : Identifier) (src : Option FileRange := none) : Option StmtExprMd :=
53+
/-- Build a call to the constraint function for a constrained type, asserting
54+
the constraint on the read-back expression `ref`. Returns `none` if `ty` is
55+
not a constrained type.
56+
57+
`ref` is the expression whose value is checked (e.g. a local read
58+
`x` or a field read `c#count`), allowing this to serve every assignment
59+
target kind uniformly. -/
60+
def constraintCallForExpr (ptMap : ConstrainedTypeMap) (ty : HighType)
61+
(ref : StmtExprMd) (src : Option FileRange := none) : Option StmtExprMd :=
5862
match ty with
5963
| .UserDefined name => if ptMap.contains name.text then
60-
some ⟨.StaticCall (mkId s!"{name.text}$constraint") [⟨.Var (.Local varName), src⟩], src⟩
64+
some ⟨.StaticCall (mkId s!"{name.text}$constraint") [ref], src⟩
6165
else none
6266
| _ => none
6367

68+
/-- Build a call to the constraint function for a constrained type, checking a
69+
local variable read, or `none` if not constrained. -/
70+
def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType)
71+
(varName : Identifier) (src : Option FileRange := none) : Option StmtExprMd :=
72+
constraintCallForExpr ptMap ty ⟨.Var (.Local varName), src⟩ src
73+
6474
/-- Generate a constraint function for a constrained type.
6575
For nested types, the function calls the parent's constraint function. -/
6676
def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure :=
6777
let baseType := resolveType ptMap ct.base
68-
let bodyExpr := match ct.base.val with
78+
let bodyExpr: StmtExprMd := match ct.base.val with
6979
| .UserDefined parent =>
7080
if ptMap.contains parent.text then
7181
let paramId := { ct.valueName with uniqueId := none }
@@ -79,15 +89,11 @@ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Proce
7989
{ name := mkId s!"{ct.name.text}$constraint"
8090
inputs := [{ name := ct.valueName, type := baseType }]
8191
outputs := [{ name := mkId "result", type := { val := .TBool, source := none } }]
82-
body := .Transparent { val := .Block [bodyExpr] none, source := none }
92+
body := .Transparent { val := .Return bodyExpr, source := none }
8393
isFunctional := true
8494
decreases := none
8595
preconditions := [] }
8696

87-
private def wrap (stmts : List StmtExprMd) (src : Option FileRange)
88-
: StmtExprMd :=
89-
match stmts with | [s] => s | ss => ⟨.Block ss none, src⟩
90-
9197
def resolveVariable (ptMap : ConstrainedTypeMap) (v : VariableMd) : VariableMd :=
9298
match v.val with
9399
| .Declare param => ⟨.Declare { param with type := resolveType ptMap param.type }, v.source⟩
@@ -118,88 +124,60 @@ def resolveExprNode (ptMap : ConstrainedTypeMap) (expr : StmtExprMd) : StmtExprM
118124
| .IsType t ty => ⟨.IsType t (resolveType ptMap ty), source⟩
119125
| _ => expr
120126

121-
abbrev ElimM := StateM PredVarMap
122-
123-
private def inScope (action : ElimM α) : ElimM α := do
124-
let saved ← get
125-
let result ← action
126-
set saved
127-
return result
128-
129-
def elimStmt (ptMap : ConstrainedTypeMap)
130-
(stmt : StmtExprMd) : ElimM (List StmtExprMd) := do
131-
let source := stmt.source
132-
133-
match _h : stmt.val with
127+
/-- Per-node constrained-type elimination, applied bottom-up (with flattening)
128+
by `mapStmtExprFlattenM`. `resultUsed` is `true` when the node occupies a
129+
value position.
130+
131+
- Uninitialized constrained declaration `var x: T;` → assume its constraint.
132+
- Assignment to constrained target(s) → emit the assignment followed by an
133+
`assert T$constraint(<read-back>)` per constrained target. The constraint
134+
is checked on a *read-back* of the target rather than on the RHS, so the
135+
RHS is evaluated exactly once. In value position the read-back is also
136+
appended as the final statement, so the resulting value-block evaluates to
137+
the assigned value (this covers expression-position assignments such as
138+
`y := (x := -1) + 1`); in statement position it is omitted.
139+
- All other nodes are returned unchanged; the traversal handles recursion. -/
140+
def elimNode (ptMap : ConstrainedTypeMap) (model : SemanticModel)
141+
(resultUsed : Bool) (node : StmtExprMd) : List StmtExprMd :=
142+
let source := node.source
143+
match node.val with
134144
| .Var (.Declare param) =>
135-
let callOpt := constraintCallFor ptMap param.type.val param.name (src := source)
136-
if callOpt.isSome then modify fun pv => pv.insert param.name.text param.type.val
137-
let check := match callOpt with
138-
| some c => [⟨.Assume c, source⟩]
139-
| none => []
140-
pure ([stmt] ++ check)
141-
145+
let check := (constraintCallFor ptMap param.type.val param.name (src := source)).toList.map
146+
fun c => ⟨.Assume c, source⟩
147+
[node] ++ check
142148
| .Assign targets _value =>
143-
-- Handle Declare targets for constrained type elimination
144-
let declareChecks ← targets.foldlM (init := ([] : List StmtExprMd)) fun acc target =>
145-
match target.val with
146-
| .Declare param => do
147-
let callOpt := constraintCallFor ptMap param.type.val param.name (src := source)
148-
if callOpt.isSome then modify fun pv => pv.insert param.name.text param.type.val
149-
pure (acc ++ callOpt.toList.map fun c => ⟨.Assert { condition := c }, source⟩)
150-
| .Local name => do
151-
match (← get).get? name.text with
152-
| some ty =>
153-
let assert := (constraintCallFor ptMap ty name (src := source)).toList.map
154-
fun c => ⟨.Assert { condition := c }, source⟩
155-
pure (acc ++ assert)
156-
| none => pure acc
157-
| _ => pure acc
158-
pure ([stmt] ++ declareChecks)
159-
160-
| .Block stmts sep =>
161-
let stmtss ← inScope (stmts.mapM (elimStmt ptMap))
162-
pure [⟨.Block stmtss.flatten sep, source⟩]
163-
164-
| .IfThenElse cond thenBr (some elseBr) =>
165-
let thenSs ← inScope (elimStmt ptMap thenBr)
166-
let elseSs ← inScope (elimStmt ptMap elseBr)
167-
pure [⟨.IfThenElse cond (wrap thenSs source) (some (wrap elseSs source)), source⟩]
168-
| .IfThenElse cond thenBr none =>
169-
let thenSs ← inScope (elimStmt ptMap thenBr)
170-
pure [⟨.IfThenElse cond (wrap thenSs source) none, source⟩]
171-
172-
| .While cond inv dec body =>
173-
let bodySs ← inScope (elimStmt ptMap body)
174-
pure [⟨.While cond inv dec (wrap bodySs source), source⟩]
175-
176-
| _ => pure [stmt]
177-
termination_by sizeOf stmt
178-
decreasing_by
179-
all_goals simp_wf
180-
all_goals (try have := AstNode.sizeOf_val_lt stmt)
181-
all_goals (try term_by_mem)
182-
all_goals omega
183-
184-
def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure :=
149+
let asserts: List StmtExprMd := targets.filterMap (fun target =>
150+
let ref : StmtExprMd := VariableMd.toReadbackExpr target
151+
let ty : HighType := (computeExprType model ref).val
152+
(constraintCallForExpr ptMap ty ref (src := source)).map (⟨.Assert { condition := · }, source⟩))
153+
let suffix := match targets with
154+
| [single] => if resultUsed then [VariableMd.toReadbackExpr single] else []
155+
| _ => []
156+
[node] ++ asserts ++ suffix
157+
| _ => [node]
158+
159+
/-- Apply `elimNode` across a body via the flattening, `resultUsed`-aware
160+
traversal. A procedure body is a statement, so the top-level `resultUsed`
161+
is `false`. -/
162+
def elimStmts (ptMap : ConstrainedTypeMap) (model : SemanticModel) (body : StmtExprMd) : StmtExprMd :=
163+
mapStmtExprFlattenM (m := Id) (fun _ _ => none) (elimNode ptMap model) false body
164+
165+
def elimProc (ptMap : ConstrainedTypeMap) (model : SemanticModel) (proc : Procedure) : Procedure :=
185166
let inputRequires : List Condition := proc.inputs.filterMap fun p =>
186167
(constraintCallFor ptMap p.type.val p.name (src := p.type.source)).map
187168
fun c => { condition := c }
188169
let outputEnsures : List Condition := if proc.isFunctional then [] else proc.outputs.filterMap fun p =>
189170
(constraintCallFor ptMap p.type.val p.name (src := p.type.source)).map
190171
fun c => { condition := ⟨c.val, p.type.source⟩ }
191-
let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p =>
192-
if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s
193172
let body' := match proc.body with
194173
| .Transparent bodyExpr =>
195-
let (stmts, _) := (elimStmt ptMap bodyExpr).run initVars
196-
let body := wrap stmts bodyExpr.source
174+
let body := elimStmts ptMap model bodyExpr
197175
if outputEnsures.isEmpty then .Transparent body
198176
else
199177
let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.source⟩ else body
200178
.Opaque outputEnsures (some retBody) []
201179
| .Opaque postconds impl modif =>
202-
let impl' := impl.map fun b => wrap ((elimStmt ptMap b).run initVars).1 b.source
180+
let impl' := impl.map (elimStmts ptMap model)
203181
.Opaque (postconds ++ outputEnsures) impl' modif
204182
| .Abstract postconds => .Abstract (postconds ++ outputEnsures)
205183
| .External => .External
@@ -231,7 +209,20 @@ private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) :
231209
isFunctional := false
232210
decreases := none }
233211

234-
public def constrainedTypeElim (_model : SemanticModel) (program : Program)
212+
/-- Eliminate constrained types within a composite type definition: resolve
213+
constrained field types to their base types and run constrained type
214+
elimination on the composite's instance procedures.
215+
216+
This is necessary because `constrainedTypeElim` removes the constrained type
217+
definitions from the program. Any reference to a constrained type left inside
218+
a composite (e.g. a `count: nat` field) would otherwise dangle and fail to
219+
resolve in later passes and the final Core translation. -/
220+
def elimCompositeType (ptMap : ConstrainedTypeMap) (model : SemanticModel) (ct : CompositeType) : CompositeType :=
221+
{ ct with
222+
fields := ct.fields.map fun f => { f with type := resolveType ptMap f.type }
223+
instanceProcedures := ct.instanceProcedures.map (elimProc ptMap model) }
224+
225+
public def constrainedTypeElim (model : SemanticModel) (program : Program)
235226
: Program × List DiagnosticModel :=
236227
let ptMap := buildConstrainedTypeMap program.types
237228
if ptMap.isEmpty then (program, []) else
@@ -244,13 +235,16 @@ public def constrainedTypeElim (_model : SemanticModel) (program : Program)
244235
acc.cons (diagnosticFromSource proc.name.source "constrained return types on functions are not yet supported")
245236
else acc
246237
({ program with
247-
staticProcedures := constraintFuncs ++ program.staticProcedures.map (elimProc ptMap)
238+
staticProcedures := constraintFuncs ++ program.staticProcedures.map (elimProc ptMap model)
248239
++ witnessProcedures
249-
types := program.types.filter fun | .Constrained _ => false | _ => true },
240+
types := program.types.filterMap fun
241+
| .Constrained _ => none
242+
| .Composite ct => some (.Composite (elimCompositeType ptMap model ct))
243+
| other => some other },
250244
funcDiags)
251245

252246
/-- Pipeline pass: constrained type elimination. -/
253-
public def constrainedTypeElimPass : LaurelPass where
247+
public def constrainedTypeElimPass : LoweringPass where
254248
name := "ConstrainedTypeElim"
255249
documentation := "Eliminates constrained types by replacing them with their base types and generating constraint-checking functions and witness procedures. Type tests against constrained types are rewritten to call the generated constraint function."
256250
needsResolves := true

0 commit comments

Comments
 (0)