Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 35 additions & 127 deletions Strata/Languages/Laurel/ConstrainedTypeElim.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,6 @@ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Proce
decreases := none
preconditions := [] }

private def wrap (stmts : List StmtExprMd) (src : Option FileRange)
: StmtExprMd :=
match stmts with | [s] => s | ss => ⟨.Block ss none, src⟩

def resolveVariable (ptMap : ConstrainedTypeMap) (v : VariableMd) : VariableMd :=
match v.val with
| .Declare param => ⟨.Declare { param with type := resolveType ptMap param.type }, v.source⟩
Expand Down Expand Up @@ -128,130 +124,43 @@ def resolveExprNode (ptMap : ConstrainedTypeMap) (expr : StmtExprMd) : StmtExprM
| .IsType t ty => ⟨.IsType t (resolveType ptMap ty), source⟩
| _ => expr

/-- If `target` is an assignment target of a constrained type, return that
constrained type together with an expression that reads the target back.
The declared type is taken from the `Declare` parameter or, for
`Local`/`Field` targets, from the semantic model.

Both `Local` and `Field` targets are resolved (they carry a `uniqueId`
after the resolution pass, which `constrainedTypeElimPass` requires), so the
model lookup reliably returns their declared type. -/
def constrainedTargetReadback (ptMap : ConstrainedTypeMap) (model : SemanticModel)
(target : VariableMd) : Option (HighType × StmtExprMd) :=
let src := target.source
let check (ty : HighType) (ref : StmtExprMd) : Option (HighType × StmtExprMd) :=
if isConstrainedType ptMap ty then some (ty, ref) else none
match target.val with
| .Local name => check (model.get name).getType.val ⟨.Var (.Local name), src⟩
| .Declare param => check param.type.val ⟨.Var (.Local param.name), src⟩
| .Field tgt fieldName => check (model.get fieldName).getType.val ⟨.Var (.Field tgt fieldName), src⟩

/-- Build `assert T$constraint(<read-back of target>)` for an assignment
`target` of constrained type `T`, or `none` if the target's type is not
constrained. This is the single point that turns any assignment target
(`Local`, `Declare`, or `Field`) into its constraint check, used by both the
statement-position handler (`elimStmt`) and the expression-position handler
(`wrapAssignNode`).

`src` is the source range reported on the generated assertion (and its
constraint call); it defaults to the target's own source but callers pass
the enclosing assignment's source so a failed check points at the whole
assignment. -/
def constrainedTargetAssert (ptMap : ConstrainedTypeMap) (model : SemanticModel)
(target : VariableMd) (src : Option FileRange := target.source) : Option StmtExprMd :=
(constrainedTargetReadback ptMap model target).bind fun (ty, ref) =>
(constraintCallForExpr ptMap ty ref (src := src)).map fun c =>
⟨.Assert { condition := c }, src⟩

/-- Wrap an assignment that appears in *expression* position so that the
constraint of any constrained-typed target is checked.

For `x := v` where `x : T` is constrained, produces the block expression
`{ x := v; assert T$constraint(x); x }`, whose value is the assigned value.
The constraint is asserted on a read-back of the target (after the
assignment) rather than on the value `v`, so `v` is evaluated exactly once
and the check is semantics-preserving.

`elimStmt` already handles assignments that appear as statements; this covers
assignments nested inside expressions (e.g. `y := (x := -1) + 1`), which are
only hoisted to statement level by the later `LiftExpressionAssignments`
pass. That pass preserves the order of side-effecting statements within an
expression-position block, so the assertion stays after the assignment.
Non-`Assign` nodes are returned unchanged. -/
def wrapAssignNode (ptMap : ConstrainedTypeMap) (model : SemanticModel)
(node : StmtExprMd) : StmtExprMd :=
/-- Per-node constrained-type elimination, applied bottom-up (with flattening)
by `mapStmtExprFlattenM`. `resultUsed` is `true` when the node occupies a
value position.

- Uninitialized constrained declaration `var x: T;` → assume its constraint.
- Assignment to constrained target(s) → emit the assignment followed by an
`assert T$constraint(<read-back>)` per constrained target. The constraint
is checked on a *read-back* of the target rather than on the RHS, so the
RHS is evaluated exactly once. In value position the read-back is also
appended as the final statement, so the resulting value-block evaluates to
the assigned value (this covers expression-position assignments such as
`y := (x := -1) + 1`); in statement position it is omitted.
- All other nodes are returned unchanged; the traversal handles recursion. -/
def elimNode (ptMap : ConstrainedTypeMap) (model : SemanticModel)
(resultUsed : Bool) (node : StmtExprMd) : List StmtExprMd :=
let source := node.source
match node.val with
| .Assign targets _value =>
match targets.filterMap (constrainedTargetReadback ptMap model) with
| [] => node
| (_, resultRef) :: _ =>
let src := node.source
let asserts : List StmtExprMd := targets.filterMap (constrainedTargetAssert ptMap model · (src := src))
⟨.Block ([node] ++ asserts ++ [resultRef]) none, src⟩
| _ => node

/-- Insert constraint assertions for every assignment to a constrained-typed
target that appears within an expression. A no-op on expressions that
contain no such assignment. -/
def wrapExprAssigns (ptMap : ConstrainedTypeMap) (model : SemanticModel)
(expr : StmtExprMd) : StmtExprMd :=
mapStmtExpr (wrapAssignNode ptMap model) expr

def elimStmt (ptMap : ConstrainedTypeMap) (model : SemanticModel)
(stmt : StmtExprMd) : List StmtExprMd :=
let source := stmt.source

match _h : stmt.val with
| .Var (.Declare param) =>
-- Uninitialized constrained-typed declaration (`var x: T;`): assume its
-- constraint, since the variable's value is otherwise unconstrained.
let check := (constraintCallFor ptMap param.type.val param.name (src := source)).toList.map
fun c => ⟨.Assume c, source⟩
[stmt] ++ check

| .Assign targets value =>
-- Wrap any assignments nested in the value expression (expression-position
-- assignments) so their constrained-type constraints are checked too.
let value := wrapExprAssigns ptMap model value
let stmt' : StmtExprMd := ⟨.Assign targets value, source⟩
-- Assert the constraint of every constrained-typed target, uniformly across
-- `Local`, `Declare`, and `Field` targets via `constrainedTargetAssert`.
--
-- The constraint is asserted on a *read-back* of the target (after the
-- write) rather than on `value`. `value` is the full RHS expression, already
-- emitted as the assignment statement above; re-using it would emit it a
-- second time, so any side effect in the RHS would run twice (e.g.
-- `c#count := (x := x + 1) + 1` would increment `x` twice). Reading the
-- target back evaluates the RHS exactly once.
let checks := targets.filterMap (constrainedTargetAssert ptMap model · (src := source))
[stmt'] ++ checks

| .Block stmts sep =>
let stmtss := stmts.map (elimStmt ptMap model)
[⟨.Block stmtss.flatten sep, source⟩]

| .IfThenElse cond thenBr (some elseBr) =>
let cond := wrapExprAssigns ptMap model cond
let thenSs := elimStmt ptMap model thenBr
let elseSs := elimStmt ptMap model elseBr
[⟨.IfThenElse cond (wrap thenSs source) (some (wrap elseSs source)), source⟩]
| .IfThenElse cond thenBr none =>
let cond := wrapExprAssigns ptMap model cond
let thenSs := elimStmt ptMap model thenBr
[⟨.IfThenElse cond (wrap thenSs source) none, source⟩]

| .While cond inv dec body =>
let cond := wrapExprAssigns ptMap model cond
let bodySs := elimStmt ptMap model body
[⟨.While cond inv dec (wrap bodySs source), source⟩]

| _ => [stmt]
termination_by sizeOf stmt
decreasing_by
all_goals simp_wf
all_goals (try have := AstNode.sizeOf_val_lt stmt)
all_goals (try term_by_mem)
all_goals omega
[node] ++ check
| .Assign targets _value =>
let asserts: List StmtExprMd := targets.filterMap (fun target =>
let ref : StmtExprMd := VariableMd.toReadbackExpr target
let ty : HighType := (computeExprType model ref).val
(constraintCallForExpr ptMap ty ref (src := source)).map (⟨.Assert { condition := · }, source⟩))
let suffix := match targets with
| [single] => if resultUsed then [VariableMd.toReadbackExpr single] else []
| _ => []
[node] ++ asserts ++ suffix
| _ => [node]

/-- Apply `elimNode` across a body via the flattening, `resultUsed`-aware
traversal. A procedure body is a statement, so the top-level `resultUsed`
is `false`. -/
def elimStmts (ptMap : ConstrainedTypeMap) (model : SemanticModel) (body : StmtExprMd) : StmtExprMd :=
mapStmtExprFlattenM (m := Id) (fun _ _ => none) (elimNode ptMap model) false body

def elimProc (ptMap : ConstrainedTypeMap) (model : SemanticModel) (proc : Procedure) : Procedure :=
let inputRequires : List Condition := proc.inputs.filterMap fun p =>
Expand All @@ -262,14 +171,13 @@ def elimProc (ptMap : ConstrainedTypeMap) (model : SemanticModel) (proc : Proced
fun c => { condition := ⟨c.val, p.type.source⟩ }
let body' := match proc.body with
| .Transparent bodyExpr =>
let stmts := elimStmt ptMap model bodyExpr
let body := wrap stmts bodyExpr.source
let body := elimStmts ptMap model bodyExpr
if outputEnsures.isEmpty then .Transparent body
else
let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.source⟩ else body
.Opaque outputEnsures (some retBody) []
| .Opaque postconds impl modif =>
let impl' := impl.map fun b => wrap (elimStmt ptMap model b) b.source
let impl' := impl.map (elimStmts ptMap model)
.Opaque (postconds ++ outputEnsures) impl' modif
| .Abstract postconds => .Abstract (postconds ++ outputEnsures)
| .External => .External
Expand Down
15 changes: 15 additions & 0 deletions Strata/Languages/Laurel/LaurelAST.lean
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,21 @@ def StmtExpr.constructorName (e : StmtExpr) : String :=
| .Hole .. => "Hole"
| .IncrDecr .. => "IncrDecr"

/-- Build an expression that reads back the value of a variable reference.

The result is always a `Var` expression that evaluates to the variable's
value. A `Declare` is read back as a `Local` reference to the declared name
(so a declaration target reads back the variable it introduces). -/
def Variable.toReadbackExpr : Variable → StmtExpr
| .Local name => .Var (.Local name)
| .Declare param => .Var (.Local param.name)
| .Field target fieldName => .Var (.Field target fieldName)

/-- Source-preserving read-back expression for a `VariableMd`
(see `Variable.toReadbackExpr`). -/
def VariableMd.toReadbackExpr (v : VariableMd) : StmtExprMd :=
⟨ v.val.toReadbackExpr, v.source ⟩

/-- Check whether a single modifies entry is the wildcard (`*`). -/
def StmtExprMd.isWildcard (m : StmtExprMd) : Bool := match m.val with | .All => true | _ => false

Expand Down
Loading
Loading