Skip to content

Commit d7bb6b0

Browse files
authored
Merge pull request #1400 from keyboardDrummer/constrained-composite-fields-refac
Constrained composite fields refac
2 parents e18d117 + 79398f2 commit d7bb6b0

3 files changed

Lines changed: 216 additions & 160 deletions

File tree

Strata/Languages/Laurel/ConstrainedTypeElim.lean

Lines changed: 35 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,6 @@ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Proce
9494
decreases := none
9595
preconditions := [] }
9696

97-
private def wrap (stmts : List StmtExprMd) (src : Option FileRange)
98-
: StmtExprMd :=
99-
match stmts with | [s] => s | ss => ⟨.Block ss none, src⟩
100-
10197
def resolveVariable (ptMap : ConstrainedTypeMap) (v : VariableMd) : VariableMd :=
10298
match v.val with
10399
| .Declare param => ⟨.Declare { param with type := resolveType ptMap param.type }, v.source⟩
@@ -128,130 +124,43 @@ def resolveExprNode (ptMap : ConstrainedTypeMap) (expr : StmtExprMd) : StmtExprM
128124
| .IsType t ty => ⟨.IsType t (resolveType ptMap ty), source⟩
129125
| _ => expr
130126

131-
/-- If `target` is an assignment target of a constrained type, return that
132-
constrained type together with an expression that reads the target back.
133-
The declared type is taken from the `Declare` parameter or, for
134-
`Local`/`Field` targets, from the semantic model.
135-
136-
Both `Local` and `Field` targets are resolved (they carry a `uniqueId`
137-
after the resolution pass, which `constrainedTypeElimPass` requires), so the
138-
model lookup reliably returns their declared type. -/
139-
def constrainedTargetReadback (ptMap : ConstrainedTypeMap) (model : SemanticModel)
140-
(target : VariableMd) : Option (HighType × StmtExprMd) :=
141-
let src := target.source
142-
let check (ty : HighType) (ref : StmtExprMd) : Option (HighType × StmtExprMd) :=
143-
if isConstrainedType ptMap ty then some (ty, ref) else none
144-
match target.val with
145-
| .Local name => check (model.get name).getType.val ⟨.Var (.Local name), src⟩
146-
| .Declare param => check param.type.val ⟨.Var (.Local param.name), src⟩
147-
| .Field tgt fieldName => check (model.get fieldName).getType.val ⟨.Var (.Field tgt fieldName), src⟩
148-
149-
/-- Build `assert T$constraint(<read-back of target>)` for an assignment
150-
`target` of constrained type `T`, or `none` if the target's type is not
151-
constrained. This is the single point that turns any assignment target
152-
(`Local`, `Declare`, or `Field`) into its constraint check, used by both the
153-
statement-position handler (`elimStmt`) and the expression-position handler
154-
(`wrapAssignNode`).
155-
156-
`src` is the source range reported on the generated assertion (and its
157-
constraint call); it defaults to the target's own source but callers pass
158-
the enclosing assignment's source so a failed check points at the whole
159-
assignment. -/
160-
def constrainedTargetAssert (ptMap : ConstrainedTypeMap) (model : SemanticModel)
161-
(target : VariableMd) (src : Option FileRange := target.source) : Option StmtExprMd :=
162-
(constrainedTargetReadback ptMap model target).bind fun (ty, ref) =>
163-
(constraintCallForExpr ptMap ty ref (src := src)).map fun c =>
164-
⟨.Assert { condition := c }, src⟩
165-
166-
/-- Wrap an assignment that appears in *expression* position so that the
167-
constraint of any constrained-typed target is checked.
168-
169-
For `x := v` where `x : T` is constrained, produces the block expression
170-
`{ x := v; assert T$constraint(x); x }`, whose value is the assigned value.
171-
The constraint is asserted on a read-back of the target (after the
172-
assignment) rather than on the value `v`, so `v` is evaluated exactly once
173-
and the check is semantics-preserving.
174-
175-
`elimStmt` already handles assignments that appear as statements; this covers
176-
assignments nested inside expressions (e.g. `y := (x := -1) + 1`), which are
177-
only hoisted to statement level by the later `LiftExpressionAssignments`
178-
pass. That pass preserves the order of side-effecting statements within an
179-
expression-position block, so the assertion stays after the assignment.
180-
Non-`Assign` nodes are returned unchanged. -/
181-
def wrapAssignNode (ptMap : ConstrainedTypeMap) (model : SemanticModel)
182-
(node : StmtExprMd) : StmtExprMd :=
127+
/-- Per-node constrained-type elimination, applied bottom-up (with flattening)
128+
by `mapStmtExprFlattenM`. `resultUsed` is `true` when the node occupies a
129+
value position.
130+
131+
- Uninitialized constrained declaration `var x: T;` → assume its constraint.
132+
- Assignment to constrained target(s) → emit the assignment followed by an
133+
`assert T$constraint(<read-back>)` per constrained target. The constraint
134+
is checked on a *read-back* of the target rather than on the RHS, so the
135+
RHS is evaluated exactly once. In value position the read-back is also
136+
appended as the final statement, so the resulting value-block evaluates to
137+
the assigned value (this covers expression-position assignments such as
138+
`y := (x := -1) + 1`); in statement position it is omitted.
139+
- All other nodes are returned unchanged; the traversal handles recursion. -/
140+
def elimNode (ptMap : ConstrainedTypeMap) (model : SemanticModel)
141+
(resultUsed : Bool) (node : StmtExprMd) : List StmtExprMd :=
142+
let source := node.source
183143
match node.val with
184-
| .Assign targets _value =>
185-
match targets.filterMap (constrainedTargetReadback ptMap model) with
186-
| [] => node
187-
| (_, resultRef) :: _ =>
188-
let src := node.source
189-
let asserts : List StmtExprMd := targets.filterMap (constrainedTargetAssert ptMap model · (src := src))
190-
⟨.Block ([node] ++ asserts ++ [resultRef]) none, src⟩
191-
| _ => node
192-
193-
/-- Insert constraint assertions for every assignment to a constrained-typed
194-
target that appears within an expression. A no-op on expressions that
195-
contain no such assignment. -/
196-
def wrapExprAssigns (ptMap : ConstrainedTypeMap) (model : SemanticModel)
197-
(expr : StmtExprMd) : StmtExprMd :=
198-
mapStmtExpr (wrapAssignNode ptMap model) expr
199-
200-
def elimStmt (ptMap : ConstrainedTypeMap) (model : SemanticModel)
201-
(stmt : StmtExprMd) : List StmtExprMd :=
202-
let source := stmt.source
203-
204-
match _h : stmt.val with
205144
| .Var (.Declare param) =>
206-
-- Uninitialized constrained-typed declaration (`var x: T;`): assume its
207-
-- constraint, since the variable's value is otherwise unconstrained.
208145
let check := (constraintCallFor ptMap param.type.val param.name (src := source)).toList.map
209146
fun c => ⟨.Assume c, source⟩
210-
[stmt] ++ check
211-
212-
| .Assign targets value =>
213-
-- Wrap any assignments nested in the value expression (expression-position
214-
-- assignments) so their constrained-type constraints are checked too.
215-
let value := wrapExprAssigns ptMap model value
216-
let stmt' : StmtExprMd := ⟨.Assign targets value, source⟩
217-
-- Assert the constraint of every constrained-typed target, uniformly across
218-
-- `Local`, `Declare`, and `Field` targets via `constrainedTargetAssert`.
219-
--
220-
-- The constraint is asserted on a *read-back* of the target (after the
221-
-- write) rather than on `value`. `value` is the full RHS expression, already
222-
-- emitted as the assignment statement above; re-using it would emit it a
223-
-- second time, so any side effect in the RHS would run twice (e.g.
224-
-- `c#count := (x := x + 1) + 1` would increment `x` twice). Reading the
225-
-- target back evaluates the RHS exactly once.
226-
let checks := targets.filterMap (constrainedTargetAssert ptMap model · (src := source))
227-
[stmt'] ++ checks
228-
229-
| .Block stmts sep =>
230-
let stmtss := stmts.map (elimStmt ptMap model)
231-
[⟨.Block stmtss.flatten sep, source⟩]
232-
233-
| .IfThenElse cond thenBr (some elseBr) =>
234-
let cond := wrapExprAssigns ptMap model cond
235-
let thenSs := elimStmt ptMap model thenBr
236-
let elseSs := elimStmt ptMap model elseBr
237-
[⟨.IfThenElse cond (wrap thenSs source) (some (wrap elseSs source)), source⟩]
238-
| .IfThenElse cond thenBr none =>
239-
let cond := wrapExprAssigns ptMap model cond
240-
let thenSs := elimStmt ptMap model thenBr
241-
[⟨.IfThenElse cond (wrap thenSs source) none, source⟩]
242-
243-
| .While cond inv dec body =>
244-
let cond := wrapExprAssigns ptMap model cond
245-
let bodySs := elimStmt ptMap model body
246-
[⟨.While cond inv dec (wrap bodySs source), source⟩]
247-
248-
| _ => [stmt]
249-
termination_by sizeOf stmt
250-
decreasing_by
251-
all_goals simp_wf
252-
all_goals (try have := AstNode.sizeOf_val_lt stmt)
253-
all_goals (try term_by_mem)
254-
all_goals omega
147+
[node] ++ check
148+
| .Assign targets _value =>
149+
let asserts: List StmtExprMd := targets.filterMap (fun target =>
150+
let ref : StmtExprMd := VariableMd.toReadbackExpr target
151+
let ty : HighType := (computeExprType model ref).val
152+
(constraintCallForExpr ptMap ty ref (src := source)).map (⟨.Assert { condition := · }, source⟩))
153+
let suffix := match targets with
154+
| [single] => if resultUsed then [VariableMd.toReadbackExpr single] else []
155+
| _ => []
156+
[node] ++ asserts ++ suffix
157+
| _ => [node]
158+
159+
/-- Apply `elimNode` across a body via the flattening, `resultUsed`-aware
160+
traversal. A procedure body is a statement, so the top-level `resultUsed`
161+
is `false`. -/
162+
def elimStmts (ptMap : ConstrainedTypeMap) (model : SemanticModel) (body : StmtExprMd) : StmtExprMd :=
163+
mapStmtExprFlattenM (m := Id) (fun _ _ => none) (elimNode ptMap model) false body
255164

256165
def elimProc (ptMap : ConstrainedTypeMap) (model : SemanticModel) (proc : Procedure) : Procedure :=
257166
let inputRequires : List Condition := proc.inputs.filterMap fun p =>
@@ -262,14 +171,13 @@ def elimProc (ptMap : ConstrainedTypeMap) (model : SemanticModel) (proc : Proced
262171
fun c => { condition := ⟨c.val, p.type.source⟩ }
263172
let body' := match proc.body with
264173
| .Transparent bodyExpr =>
265-
let stmts := elimStmt ptMap model bodyExpr
266-
let body := wrap stmts bodyExpr.source
174+
let body := elimStmts ptMap model bodyExpr
267175
if outputEnsures.isEmpty then .Transparent body
268176
else
269177
let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.source⟩ else body
270178
.Opaque outputEnsures (some retBody) []
271179
| .Opaque postconds impl modif =>
272-
let impl' := impl.map fun b => wrap (elimStmt ptMap model b) b.source
180+
let impl' := impl.map (elimStmts ptMap model)
273181
.Opaque (postconds ++ outputEnsures) impl' modif
274182
| .Abstract postconds => .Abstract (postconds ++ outputEnsures)
275183
| .External => .External

Strata/Languages/Laurel/LaurelAST.lean

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,21 @@ def StmtExpr.constructorName (e : StmtExpr) : String :=
748748
| .Hole .. => "Hole"
749749
| .IncrDecr .. => "IncrDecr"
750750

751+
/-- Build an expression that reads back the value of a variable reference.
752+
753+
The result is always a `Var` expression that evaluates to the variable's
754+
value. A `Declare` is read back as a `Local` reference to the declared name
755+
(so a declaration target reads back the variable it introduces). -/
756+
def Variable.toReadbackExpr : Variable → StmtExpr
757+
| .Local name => .Var (.Local name)
758+
| .Declare param => .Var (.Local param.name)
759+
| .Field target fieldName => .Var (.Field target fieldName)
760+
761+
/-- Source-preserving read-back expression for a `VariableMd`
762+
(see `Variable.toReadbackExpr`). -/
763+
def VariableMd.toReadbackExpr (v : VariableMd) : StmtExprMd :=
764+
⟨ v.val.toReadbackExpr, v.source ⟩
765+
751766
/-- Check whether a single modifies entry is the wildcard (`*`). -/
752767
def StmtExprMd.isWildcard (m : StmtExprMd) : Bool := match m.val with | .All => true | _ => false
753768

0 commit comments

Comments
 (0)