Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2e11936
feat(laurel): Support constrained types as composite fields
olivier-aws Jun 9, 2026
9570616
test(laurel): Add error case to constrained composite field test
olivier-aws Jun 11, 2026
ec879d7
refactor(laurel): Unify constrained type resolution algorithm
olivier-aws Jun 10, 2026
5655454
test(laurel): Document read-side completeness gap for constrained fields
olivier-aws Jun 10, 2026
7af30c6
modify pipeline
olivier-aws Jun 11, 2026
0e86b7e
fix(laurel): Make ConstrainedTypeElim sound when run before lowering
olivier-aws Jun 11, 2026
182e589
refactor(laurel): Drop dead constrained-type resolution in HeapParame…
olivier-aws Jun 11, 2026
3069d9d
refactor(laurel): Remove shared resolveConstrainedTypeWith helper
olivier-aws Jun 11, 2026
9dcdcb6
Merge branch 'main2' into pr-constrained-composite-fields-main2
olivier-aws Jun 12, 2026
713a11e
Update test to new Laurel test framework
olivier-aws Jun 12, 2026
c647619
fix(laurel): Assert field-write constraint on read-back, not RHS
olivier-aws Jun 15, 2026
95ac317
test(laurel): Add regression test for field-write RHS double-eval
olivier-aws Jun 15, 2026
2f1c0f9
Merge branch 'reviewed-kbd-will-merge-to-main' into pr-constrained-co…
olivier-aws Jun 15, 2026
a38dea5
Merge branch 'reviewed-kbd-will-merge-to-main' into pr-constrained-co…
olivier-aws Jun 18, 2026
fc354e2
refactor(laurel): Unify field constraint check via constraintCallForExpr
olivier-aws Jun 18, 2026
a6e0fbc
refactor(laurel): Unify assignment-target constraint checks, drop StateM
olivier-aws Jun 18, 2026
e18d117
Merge branch 'reviewed-kbd-will-merge-to-main' into pr-constrained-co…
olivier-aws Jun 19, 2026
0f2fd73
Simplify constrainedTargetReadback
keyboardDrummer Jun 20, 2026
74afd2f
Refactoring of ConstraintTypeElim pass
keyboardDrummer Jun 22, 2026
ead0e2f
Remove unused
keyboardDrummer Jun 22, 2026
79398f2
Further refactoring
keyboardDrummer Jun 22, 2026
d7bb6b0
Merge pull request #1400 from keyboardDrummer/constrained-composite-f…
olivier-aws Jun 23, 2026
45c7714
Merge remote-tracking branch 'origin/reviewed-kbd-will-merge-to-main'…
keyboardDrummer Jun 23, 2026
921b86f
Merge branch 'reviewed-kbd-will-merge-to-main' into pr-constrained-co…
keyboardDrummer-bot Jun 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 74 additions & 80 deletions Strata/Languages/Laurel/ConstrainedTypeElim.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ namespace Strata.Laurel
open Strata

abbrev ConstrainedTypeMap := Std.HashMap String ConstrainedType
/-- Map from variable name to its constrained HighType (e.g. UserDefined "nat") -/
abbrev PredVarMap := Std.HashMap String HighType

def buildConstrainedTypeMap (types : List TypeDefinition) : ConstrainedTypeMap :=
types.foldl (init := {}) fun m td =>
Expand All @@ -52,15 +50,27 @@ def resolveType (ptMap : ConstrainedTypeMap) (ty : HighTypeMd) : HighTypeMd :=
def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool :=
match ty with | .UserDefined name => ptMap.contains name.text | _ => false

/-- Build a call to the constraint function for a constrained type, or `none` if not constrained -/
def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType)
(varName : Identifier) (src : Option FileRange := none) : Option StmtExprMd :=
/-- Build a call to the constraint function for a constrained type, asserting
the constraint on the read-back expression `ref`. Returns `none` if `ty` is
not a constrained type.

`ref` is the expression whose value is checked (e.g. a local read
`x` or a field read `c#count`), allowing this to serve every assignment
target kind uniformly. -/
def constraintCallForExpr (ptMap : ConstrainedTypeMap) (ty : HighType)
(ref : StmtExprMd) (src : Option FileRange := none) : Option StmtExprMd :=
match ty with
| .UserDefined name => if ptMap.contains name.text then
some ⟨.StaticCall (mkId s!"{name.text}$constraint") [⟨.Var (.Local varName), src⟩], src⟩
some ⟨.StaticCall (mkId s!"{name.text}$constraint") [ref], src⟩
else none
| _ => none

/-- Build a call to the constraint function for a constrained type, checking a
local variable read, or `none` if not constrained. -/
def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType)
(varName : Identifier) (src : Option FileRange := none) : Option StmtExprMd :=
constraintCallForExpr ptMap ty ⟨.Var (.Local varName), src⟩ src

/-- Generate a constraint function for a constrained type.
For nested types, the function calls the parent's constraint function. -/
def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure :=
Expand All @@ -84,10 +94,6 @@ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Proce
decreases := none
preconditions := [] }

private def wrap (stmts : List StmtExprMd) (src : Option FileRange)
: StmtExprMd :=
match stmts with | [s] => s | ss => ⟨.Block ss none, src⟩

def resolveVariable (ptMap : ConstrainedTypeMap) (v : VariableMd) : VariableMd :=
match v.val with
| .Declare param => ⟨.Declare { param with type := resolveType ptMap param.type }, v.source⟩
Expand Down Expand Up @@ -118,88 +124,60 @@ def resolveExprNode (ptMap : ConstrainedTypeMap) (expr : StmtExprMd) : StmtExprM
| .IsType t ty => ⟨.IsType t (resolveType ptMap ty), source⟩
| _ => expr

abbrev ElimM := StateM PredVarMap

private def inScope (action : ElimM α) : ElimM α := do
let saved ← get
let result ← action
set saved
return result

def elimStmt (ptMap : ConstrainedTypeMap)
(stmt : StmtExprMd) : ElimM (List StmtExprMd) := do
let source := stmt.source

match _h : stmt.val with
/-- Per-node constrained-type elimination, applied bottom-up (with flattening)
by `mapStmtExprFlattenM`. `resultUsed` is `true` when the node occupies a
value position.

- Uninitialized constrained declaration `var x: T;` → assume its constraint.
- Assignment to constrained target(s) → emit the assignment followed by an
`assert T$constraint(<read-back>)` per constrained target. The constraint
is checked on a *read-back* of the target rather than on the RHS, so the
RHS is evaluated exactly once. In value position the read-back is also
appended as the final statement, so the resulting value-block evaluates to
the assigned value (this covers expression-position assignments such as
`y := (x := -1) + 1`); in statement position it is omitted.
- All other nodes are returned unchanged; the traversal handles recursion. -/
def elimNode (ptMap : ConstrainedTypeMap) (model : SemanticModel)
(resultUsed : Bool) (node : StmtExprMd) : List StmtExprMd :=
let source := node.source
match node.val with
| .Var (.Declare param) =>
let callOpt := constraintCallFor ptMap param.type.val param.name (src := source)
if callOpt.isSome then modify fun pv => pv.insert param.name.text param.type.val
let check := match callOpt with
| some c => [⟨.Assume c, source⟩]
| none => []
pure ([stmt] ++ check)

let check := (constraintCallFor ptMap param.type.val param.name (src := source)).toList.map
fun c => ⟨.Assume c, source⟩
[node] ++ check
| .Assign targets _value =>
-- Handle Declare targets for constrained type elimination
let declareChecks ← targets.foldlM (init := ([] : List StmtExprMd)) fun acc target =>
match target.val with
| .Declare param => do
let callOpt := constraintCallFor ptMap param.type.val param.name (src := source)
if callOpt.isSome then modify fun pv => pv.insert param.name.text param.type.val
pure (acc ++ callOpt.toList.map fun c => ⟨.Assert { condition := c }, source⟩)
| .Local name => do
match (← get).get? name.text with
| some ty =>
let assert := (constraintCallFor ptMap ty name (src := source)).toList.map
fun c => ⟨.Assert { condition := c }, source⟩
pure (acc ++ assert)
| none => pure acc
| _ => pure acc
pure ([stmt] ++ declareChecks)

| .Block stmts sep =>
let stmtss ← inScope (stmts.mapM (elimStmt ptMap))
pure [⟨.Block stmtss.flatten sep, source⟩]

| .IfThenElse cond thenBr (some elseBr) =>
let thenSs ← inScope (elimStmt ptMap thenBr)
let elseSs ← inScope (elimStmt ptMap elseBr)
pure [⟨.IfThenElse cond (wrap thenSs source) (some (wrap elseSs source)), source⟩]
| .IfThenElse cond thenBr none =>
let thenSs ← inScope (elimStmt ptMap thenBr)
pure [⟨.IfThenElse cond (wrap thenSs source) none, source⟩]

| .While cond inv dec body postTest =>
let bodySs ← inScope (elimStmt ptMap body)
pure [⟨.While cond inv dec (wrap bodySs source) postTest, source⟩]

| _ => pure [stmt]
termination_by sizeOf stmt
decreasing_by
all_goals simp_wf
all_goals (try have := AstNode.sizeOf_val_lt stmt)
all_goals (try term_by_mem)
all_goals omega

def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure :=
let asserts: List StmtExprMd := targets.filterMap (fun target =>
let ref : StmtExprMd := VariableMd.toReadbackExpr target
let ty : HighType := (computeExprType model ref).val
(constraintCallForExpr ptMap ty ref (src := source)).map (⟨.Assert { condition := · }, source⟩))
let suffix := match targets with
| [single] => if resultUsed then [VariableMd.toReadbackExpr single] else []
| _ => []
[node] ++ asserts ++ suffix
| _ => [node]

/-- Apply `elimNode` across a body via the flattening, `resultUsed`-aware
traversal. A procedure body is a statement, so the top-level `resultUsed`
is `false`. -/
def elimStmts (ptMap : ConstrainedTypeMap) (model : SemanticModel) (body : StmtExprMd) : StmtExprMd :=
mapStmtExprFlattenM (m := Id) (fun _ _ => none) (elimNode ptMap model) false body

def elimProc (ptMap : ConstrainedTypeMap) (model : SemanticModel) (proc : Procedure) : Procedure :=
let inputRequires : List Condition := proc.inputs.filterMap fun p =>
(constraintCallFor ptMap p.type.val p.name (src := p.type.source)).map
fun c => { condition := c }
let outputEnsures : List Condition := if proc.isFunctional then [] else proc.outputs.filterMap fun p =>
(constraintCallFor ptMap p.type.val p.name (src := p.type.source)).map
fun c => { condition := ⟨c.val, p.type.source⟩ }
let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p =>
if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s
let body' := match proc.body with
| .Transparent bodyExpr =>
let (stmts, _) := (elimStmt ptMap bodyExpr).run initVars
let body := wrap stmts bodyExpr.source
let body := elimStmts ptMap model bodyExpr
if outputEnsures.isEmpty then .Transparent body
else
let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.source⟩ else body
.Opaque outputEnsures (some retBody) []
| .Opaque postconds impl modif =>
let impl' := impl.map fun b => wrap ((elimStmt ptMap b).run initVars).1 b.source
let impl' := impl.map (elimStmts ptMap model)
.Opaque (postconds ++ outputEnsures) impl' modif
| .Abstract postconds => .Abstract (postconds ++ outputEnsures)
| .External => .External
Expand Down Expand Up @@ -231,7 +209,20 @@ private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) :
isFunctional := false
decreases := none }

public def constrainedTypeElim (_model : SemanticModel) (program : Program)
/-- Eliminate constrained types within a composite type definition: resolve
constrained field types to their base types and run constrained type
elimination on the composite's instance procedures.

This is necessary because `constrainedTypeElim` removes the constrained type
definitions from the program. Any reference to a constrained type left inside
a composite (e.g. a `count: nat` field) would otherwise dangle and fail to
resolve in later passes and the final Core translation. -/
def elimCompositeType (ptMap : ConstrainedTypeMap) (model : SemanticModel) (ct : CompositeType) : CompositeType :=
{ ct with
fields := ct.fields.map fun f => { f with type := resolveType ptMap f.type }
instanceProcedures := ct.instanceProcedures.map (elimProc ptMap model) }

public def constrainedTypeElim (model : SemanticModel) (program : Program)
: Program × List DiagnosticModel :=
let ptMap := buildConstrainedTypeMap program.types
if ptMap.isEmpty then (program, []) else
Expand All @@ -244,9 +235,12 @@ public def constrainedTypeElim (_model : SemanticModel) (program : Program)
acc.cons (diagnosticFromSource proc.name.source "constrained return types on functions are not yet supported")
else acc
({ program with
staticProcedures := constraintFuncs ++ program.staticProcedures.map (elimProc ptMap)
staticProcedures := constraintFuncs ++ program.staticProcedures.map (elimProc ptMap model)
++ witnessProcedures
types := program.types.filter fun | .Constrained _ => false | _ => true },
types := program.types.filterMap fun
| .Constrained _ => none
| .Composite ct => some (.Composite (elimCompositeType ptMap model ct))
| other => some other },
funcDiags)

/-- Pipeline pass: constrained type elimination. -/
Expand Down
7 changes: 6 additions & 1 deletion Strata/Languages/Laurel/HeapParameterization.lean
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,12 @@ private def isDatatype (model : SemanticModel) (name : Identifier) : Bool :=

/-- Get the Box destructor name for a given Laurel HighType.
For UserDefined datatypes, uses "Box..<datatypeName>Val!";
for Composite types, uses "Box..compositeVal!". -/
for Composite types, uses "Box..compositeVal!".

Constrained types do not need resolving here: `ConstrainedTypeElim` runs
before this pass and has already lowered every constrained type to its base
type (and removed the constrained type definitions), so `ty` is never a
constrained-type reference. -/
def boxDestructorName (model : SemanticModel) (ty : HighType) : Identifier :=
match ty with
| .TInt => "Box..intVal!"
Expand Down
15 changes: 15 additions & 0 deletions Strata/Languages/Laurel/LaurelAST.lean
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,21 @@ def StmtExpr.constructorName (e : StmtExpr) : String :=
| .Hole .. => "Hole"
| .IncrDecr .. => "IncrDecr"

/-- Build an expression that reads back the value of a variable reference.

The result is always a `Var` expression that evaluates to the variable's
value. A `Declare` is read back as a `Local` reference to the declared name
(so a declaration target reads back the variable it introduces). -/
def Variable.toReadbackExpr : Variable → StmtExpr
| .Local name => .Var (.Local name)
| .Declare param => .Var (.Local param.name)
| .Field target fieldName => .Var (.Field target fieldName)

/-- Source-preserving read-back expression for a `VariableMd`
(see `Variable.toReadbackExpr`). -/
def VariableMd.toReadbackExpr (v : VariableMd) : StmtExprMd :=
⟨ v.val.toReadbackExpr, v.source ⟩

/-- Check whether a single modifies entry is the wildcard (`*`). -/
def StmtExprMd.isWildcard (m : StmtExprMd) : Bool := match m.val with | .All => true | _ => false

Expand Down
4 changes: 2 additions & 2 deletions Strata/Languages/Laurel/LaurelCompilationPipeline.lean
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def laurelPipeline : Array LaurelPass := #[
eliminateDoWhilePass,
eliminateIncrDecrPass,
typeAliasElimPass,
constrainedTypeElimPass,
filterNonCompositeModifiesPass,
liftInstanceProceduresPass,
eliminateValueInReturnsPass,
Expand All @@ -112,8 +113,7 @@ def laurelPipeline : Array LaurelPass := #[
eliminateDeterministicHolesPass,
desugarShortCircuitPass,
liftExpressionAssignmentsPass,
mergeAndLiftReturnsPass,
constrainedTypeElimPass
mergeAndLiftReturnsPass
]

/-- Every `comesBefore` constraint is respected by the pipeline order.
Expand Down
2 changes: 1 addition & 1 deletion Strata/Languages/Laurel/LaurelToCoreTranslator.lean
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def translateType (ty : HighTypeMd) : TranslateM LMonoTy := do
| some (.datatypeDefinition dt) => return .tcons dt.name.text []
| some (.datatypeConstructor typeName _) => return .tcons typeName.text []
| _ => do -- resolution should have already emitted a diagnostic
emitCoreDiagnostic (diagnosticFromSource ty.source s!"UserDefined type could not be resolved to a composite or datatype" DiagnosticType.StrataBug)
emitCoreDiagnostic (diagnosticFromSource ty.source s!"UserDefined type {name} could not be resolved to a composite or datatype" DiagnosticType.StrataBug)
return .tcons "Composite" []
| .TCore s => return .tcons s []
| .TReal => return LMonoTy.real
Expand Down
16 changes: 16 additions & 0 deletions Strata/Languages/Laurel/LiftImperativeExpressions.lean
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,22 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
else
return expr

| .Assert _ =>
-- An assert in expression position (e.g. inside a block used as a value)
-- is lifted as a side effect. Prepend it *here*, during the right-to-left
-- traversal, so it keeps its position relative to assignments lifted from
-- the same block. (If it were left for `onlyKeepSideEffectStmtsAndLast` to
-- prepend afterwards, it would be moved ahead of those assignments.)
-- Core has no assert-expression, so the expression yields a dummy value
-- that the surrounding block discards as a non-final statement.
prepend expr
return ⟨.LiteralBool true, source⟩

| .Assume _ =>
-- See the `.Assert` case above: same side-effect lifting for assumes.
prepend expr
return ⟨.LiteralBool true, source⟩

| _ => return expr
termination_by (sizeOf expr, 0)
decreasing_by
Expand Down
Loading
Loading