Skip to content

Commit 49c3e1a

Browse files
keyboardDrummerkeyboardDrummer-botfabiomadge
authored
Enable calling procedures in contracts (#1352)
## Functional changes 1. [Debugging] Improve the printing of Laurel if-then-else expressions 1. `EliminateReturnsInExpression` now runs for procedures as well, which enables more types of transparent bodies for procedures. To make it work for both functions and procedures, it was also necessary for the body of functions to be immediately wrapped in a return statement during parsing. 1. Allow calling procedures from contracts. Combined with the previous change this makes procedures strictly more powerful than functions 1. Let the transparency pass rewrite the bodies of assume statements so they don't assert anything. 1. Improve diagnostics related to contracts, using the correct verbiage "precondition" and "postcondition" instead of "assertion" 1. Generalized the `LaurelPass` concept so it works for all transformation between Laurel source and Core, not just the Laurel->Laurel transformation. This helps make the documentation more complete. ### Why let the transparency pass rewrite the bodies of assume statements so they don't assert anything? After the contract pass, a call will look like `assert <preconditions>; call(..); assume <postconditions>`, where the body of the callee looks like `assume <preconditions>; <body>; assert <postconditions>`. If we now do either concrete execution, or we do inlining, then any assertions that occur inside the pre or postconditions will be asserted twice, because they occur once in an assert and once in an assume. By ignoring the assertions inside the assume, we prevent the duplication. Whether you also want this behavior for assumptions that were created by users is something I'm not sure about. However, if we want we can let those behave differently. Right now I think we don't have enough data to decide what we want for user created assumptions, and they are AFAIK not yet used, so I think it's OK to change their behavior. ## Implementation Add these passes: - [New] EliminateReturnStatements: rewrite `return` to `exit` statements, needed for the next pass. - [New] ContractPass: translate away pre and postconditions entirely by introducing assertion and assumptions at call sites and at procedure starts and ends - [Updated] Lift assertions, assumptions and procedure calls when they occur in expressions. Note: the changes in this pass could have been extracted to a different PR to reduce the scope of this one, but I think that keeping them in this PR is most efficient from a developer time perspective. ## Follow-up work - Remove the now obsolete functions from Laurel - Create WF proofs for quantifier bodies - Lift assumptions in expressions to axioms. - In the transparency phase, if something has no asserts and only calls functions, only create a function and no procedure --------- Co-authored-by: keyboardDrummer-bot <keyboardDrummer-bot@users.noreply.github.com> Co-authored-by: Fabio Madge <fabio@madge.me>
1 parent 6a074b1 commit 49c3e1a

83 files changed

Lines changed: 1701 additions & 783 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Strata/Languages/Laurel/ConstrainedTypeElim.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType)
7575
For nested types, the function calls the parent's constraint function. -/
7676
def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure :=
7777
let baseType := resolveType ptMap ct.base
78-
let bodyExpr := match ct.base.val with
78+
let bodyExpr: StmtExprMd := match ct.base.val with
7979
| .UserDefined parent =>
8080
if ptMap.contains parent.text then
8181
let paramId := { ct.valueName with uniqueId := none }
@@ -89,7 +89,7 @@ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Proce
8989
{ name := mkId s!"{ct.name.text}$constraint"
9090
inputs := [{ name := ct.valueName, type := baseType }]
9191
outputs := [{ name := mkId "result", type := { val := .TBool, source := none } }]
92-
body := .Transparent { val := .Block [bodyExpr] none, source := none }
92+
body := .Transparent { val := .Return bodyExpr, source := none }
9393
isFunctional := true
9494
decreases := none
9595
preconditions := [] }
@@ -244,11 +244,11 @@ public def constrainedTypeElim (model : SemanticModel) (program : Program)
244244
funcDiags)
245245

246246
/-- Pipeline pass: constrained type elimination. -/
247-
public def constrainedTypeElimPass : LaurelPass where
247+
public def constrainedTypeElimPass : LoweringPass where
248248
name := "ConstrainedTypeElim"
249249
documentation := "Eliminates constrained types by replacing them with their base types and generating constraint-checking functions and witness procedures. Type tests against constrained types are rewritten to call the generated constraint function."
250250
needsResolves := true
251-
run := fun p m =>
251+
run := fun p m _ =>
252252
let (p', diags) := constrainedTypeElim m p
253253
(p', diags, {})
254254

Strata/Languages/Laurel/ContractPass.lean

Lines changed: 461 additions & 0 deletions
Large diffs are not rendered by default.

Strata/Languages/Laurel/CoreDefinitionsForLaurel.lean

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,21 @@ program Laurel;
2727

2828
datatype LaurelResolutionErrorPlaceholder {}
2929
datatype Float64IsNotSupportedYet {}
30+
datatype LaurelUnit { MkLaurelUnit() }
3031

3132
// The types for these Map functions are incorrect.
3233
// We'll fix them when Laurel supports polymorphism
33-
function select(map: int, key: int) : int
34+
// And then we can remove the datatype Box as well
35+
// And remove the hacky filter in HeapParameterization
36+
datatype Box { MkBox() }
37+
38+
function select(map: int, key: int) : Box
3439
external;
3540

36-
function update(map: int, key: int, value: int) : int
41+
function update(map: int, key: int, value: int) : Box
3742
external;
3843

39-
function const(value: int) : int
44+
function const(value: int) : Box
4045
external;
4146

4247
#end

Strata/Languages/Laurel/CoreGroupingAndOrdering.lean

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
-/
66

77
module
8-
public import Strata.Languages.Laurel.TransparencyPass
8+
public import Strata.Languages.Laurel.LaurelAST
9+
public import Strata.Languages.Laurel.UnorderedCore
10+
public import Strata.Languages.Laurel.LaurelPass
911
import Strata.DL.Lambda.LExpr
1012
import StrataDDM.Util.Graph.Tarjan
1113
import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator
@@ -27,6 +29,7 @@ declarations before they are emitted as Strata Core declarations.
2729
namespace Strata.Laurel
2830

2931
open Lambda (LMonoTy LExpr)
32+
open Std (Format ToFormat)
3033

3134
/-- Collect all `UserDefined` type names referenced in a `HighType`, including nested ones. -/
3235
def collectTypeRefs : HighTypeMd → List String
@@ -54,7 +57,7 @@ def collectStaticCallNames (expr : StmtExprMd) : List String :=
5457
match val with
5558
| .StaticCall callee args =>
5659
callee.text :: args.flatMap (fun a => collectStaticCallNames a)
57-
| .PrimitiveOp _ args => args.flatMap (fun a => collectStaticCallNames a)
60+
| .PrimitiveOp _ args _ => args.flatMap (fun a => collectStaticCallNames a)
5861
| .IfThenElse cond t e =>
5962
collectStaticCallNames cond ++
6063
collectStaticCallNames t ++
@@ -113,18 +116,18 @@ Build the procedure call graph, run Tarjan's SCC algorithm, and return each SCC
113116
as a list of procedures paired with a flag indicating whether the SCC is recursive.
114117
Results are in reverse topological order: dependencies before dependents.
115118
116-
Procedures with `invokeOn` are placed as early as possible — before
119+
Procedures with axioms are placed as early as possible — before
117120
unrelated procedures without them — by stably partitioning them first before building
118121
the graph. Tarjan then naturally assigns them lower indices, causing them to appear
119122
earlier in the output.
120123
-/
121124
public def computeSccDecls (program : UnorderedCoreWithLaurelTypes) : List (List Procedure × Bool) :=
122-
-- Stable partition: procedures with invokeOn come first, preserving relative
125+
-- Stable partition: procedures with axioms come first, preserving relative
123126
-- order within each group. Tarjan then places them earlier in the topological output.
124127
let allProcs := program.functions ++ program.coreProcedures
125-
let (withInvokeOn, withoutInvokeOn) :=
126-
allProcs.partition (fun p => p.invokeOn.isSome)
127-
let orderedProcs : List Procedure := withInvokeOn ++ withoutInvokeOn
128+
let (withAxioms, withoutAxioms) :=
129+
allProcs.partition (fun p => !p.axioms.isEmpty)
130+
let orderedProcs : List Procedure := withAxioms ++ withoutAxioms
128131

129132
-- Build a call-graph over all procedures.
130133
-- An edge proc → callee means proc's body/contracts contain a StaticCall to callee.
@@ -142,7 +145,8 @@ public def computeSccDecls (program : UnorderedCoreWithLaurelTypes) : List (List
142145
| _ => []
143146
let contractExprs : List StmtExprMd :=
144147
proc.preconditions.map (·.condition) ++
145-
proc.invokeOn.toList
148+
proc.invokeOn.toList ++
149+
proc.axioms
146150
(bodyExprs ++ contractExprs).flatMap collectStaticCallNames
147151

148152
-- Build the OutGraph for Tarjan.
@@ -225,7 +229,7 @@ Functions are grouped into SCCs (for mutual recursion). Proofs are emitted
225229
as individual `procedure` decls. Both participate in the topological ordering
226230
so that axioms are available to functions that need them.
227231
-/
228-
public def orderFunctionsAndProcedures (program : UnorderedCoreWithLaurelTypes) : CoreWithLaurelTypes :=
232+
def orderFunctionsAndProcedures (program : UnorderedCoreWithLaurelTypes) : CoreWithLaurelTypes :=
229233
let datatypeDecls := (groupDatatypesByScc' program).map OrderedDecl.datatypes
230234
let constantDecls := program.constants.map OrderedDecl.constant
231235
let funcNames : Std.HashSet String :=
@@ -254,4 +258,16 @@ where
254258
let members := comp.toList.filterMap fun idx => dtsArr[idx]?
255259
if members.isEmpty then none else some members
256260

261+
public def orderingPass : LaurelPass UnorderedCoreWithLaurelTypes CoreWithLaurelTypes where
262+
name := "OrderingPass"
263+
comesBefore := []
264+
documentation := "Produce a `CoreWithLaurelTypes` from a `UnorderedCoreWithLaurelTypes` by
265+
computing a combined ordering of functions and proofs using the call graph,
266+
then collecting datatypes and constants.
267+
Functions are grouped into SCCs (for mutual recursion). Proofs are emitted
268+
as individual `procedure` decls. Both participate in the topological ordering
269+
so that axioms are available to functions that need them."
270+
run := fun p _ _ =>
271+
(orderFunctionsAndProcedures p, [], {})
272+
257273
end Strata.Laurel

Strata/Languages/Laurel/DesugarShortCircuit.lean

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,42 +25,43 @@ namespace Strata.Laurel
2525

2626
public section
2727

28-
private def bare (v : StmtExpr) : StmtExprMd := ⟨v, none⟩
2928

3029
/-- Local rewrite of a single short-circuit node. Recursion is handled by `mapStmtExpr`. -/
31-
private def desugarShortCircuitNode (model : SemanticModel) (expr : StmtExprMd) : StmtExprMd :=
30+
private def desugarShortCircuitNode (imperativeCallees : List String) (expr : StmtExprMd) : StmtExprMd :=
3231
let source := expr.source
32+
let wrap (v : StmtExpr) : StmtExprMd := ⟨v, source⟩
3333
match expr.val with
3434
| .PrimitiveOp op args _ =>
3535
match op, args with
3636
-- With bottom-up traversal, `a` and `b` are already desugared (nested
3737
-- short-circuits converted to IfThenElse). The check still works because
3838
-- `containsAssignmentOrImperativeCall` recurses into IfThenElse.
3939
| .AndThen, [a, b] | .Implies, [a, b] =>
40-
if containsAssignmentOrImperativeCall model b then
40+
if containsAssignmentOrImperativeCall imperativeCallees b then
4141
let elseVal := match op with | .AndThen => false | _ => true
42-
⟨.IfThenElse a b (some (bare (.LiteralBool elseVal))), source⟩
42+
⟨.IfThenElse a b (some (wrap (.LiteralBool elseVal))), source⟩
4343
else expr
4444
| .OrElse, [a, b] =>
45-
if containsAssignmentOrImperativeCall model b then
46-
⟨.IfThenElse a (bare (.LiteralBool true)) (some b), source⟩
45+
if containsAssignmentOrImperativeCall imperativeCallees b then
46+
⟨.IfThenElse a (wrap (.LiteralBool true)) (some b), source⟩
4747
else expr
4848
| _, _ => expr
4949
| _ => expr
5050

5151
/-- Desugar short-circuit operators in a program. -/
52-
def desugarShortCircuit (model : SemanticModel) (program : Program) : Program :=
53-
mapProgram (mapStmtExpr (desugarShortCircuitNode model)) program
52+
def desugarShortCircuit (program : Program) : Program :=
53+
let imperativeCallees := (program.staticProcedures.filter (!·.isFunctional)).map (·.name.text)
54+
mapProgram (mapStmtExpr (desugarShortCircuitNode imperativeCallees)) program
5455

5556
end -- public section
5657

5758
/-- Pipeline pass: desugar short-circuit operators. -/
58-
public def desugarShortCircuitPass : LaurelPass where
59+
public def desugarShortCircuitPass : LoweringPass where
5960
name := "DesugarShortCircuit"
6061
documentation := "Rewrites short-circuit boolean operators (`&&` and `||`) into equivalent conditional expressions. This simplifies subsequent passes and the final translation to Core, which does not have short-circuit semantics built in."
61-
run := fun p m =>
62-
(desugarShortCircuit m p, [], {})
62+
run := fun p _ _ =>
63+
(desugarShortCircuit p, [], {})
6364
comesBefore := [
64-
liftExpressionAssignmentsPass, "The desugar short circuit pass introduces if-then-else expressions whose control-flow must be taken into account by the lifting pass."⟩]
65+
liftImperativeExpressionsPass.meta, "The desugar short circuit pass introduces if-then-else expressions whose control-flow must be taken into account by the lifting pass."⟩]
6566

6667
end Strata.Laurel

Strata/Languages/Laurel/EliminateDeterministicHoles.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ def eliminateDeterministicHoles (program : Program) : Program × Statistics :=
9090
end -- public section
9191

9292
/-- Pipeline pass: eliminate deterministic holes. -/
93-
public def eliminateDeterministicHolesPass : LaurelPass where
93+
public def eliminateDeterministicHolesPass : LoweringPass where
9494
name := "EliminateDeterministicHoles"
9595
documentation := "Replaces every deterministic hole with a call to a freshly generated uninterpreted function. After this pass the program contains only non-deterministic holes. Assumes `InferHoleTypes` has already annotated holes with types."
96-
run := fun p _m =>
96+
run := fun p _m _ =>
9797
let (p', stats) := eliminateDeterministicHoles p
9898
(p', [], stats)
9999

Strata/Languages/Laurel/EliminateDoWhile.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def eliminateDoWhile (program : Program) : Program :=
7373
(mapProgramProceduresM rewrite program |>.run {}).fst
7474

7575
/-- Pipeline pass: eliminate post-test (`do … while`) loops. -/
76-
public def eliminateDoWhilePass : LaurelPass where
76+
public def eliminateDoWhilePass : LoweringPass where
7777
name := "EliminateDoWhile"
7878
documentation := "Lowers post-test `While` loops (the `do … while` form) into the pre-test loop `{ while(true) invariant I { BODY; if (!COND) exit L } } L`, with a fresh `$`-prefixed exit label `L`. Runs early so no later pass observes a post-test loop; the invariant is checked at the loop head, matching `while`."
79-
run := fun p _m => (eliminateDoWhile p, [], {})
79+
run := fun p _m _ => (eliminateDoWhile p, [], {})
8080

8181
end -- public section
8282
end Strata.Laurel

Strata/Languages/Laurel/EliminateIncrDecr.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ def eliminateIncrDecr (program : Program) : Program :=
100100
mapProgramProcedures lowerProcedure program
101101

102102
/-- Pipeline pass: eliminate increment/decrement operators. -/
103-
public def eliminateIncrDecrPass : LaurelPass where
103+
public def eliminateIncrDecrPass : LoweringPass where
104104
name := "EliminateIncrDecr"
105105
documentation := "Lowers Java-style increment/decrement operators (`++x`, `x++`, `--x`, `x--`) into existing Laurel assignment and arithmetic constructs. Prefix forms yield the new value; postfix forms yield the old value. Runs early so that no later pass observes an `.IncrDecr` node."
106-
run := fun p _m => (eliminateIncrDecr p, [], {})
106+
run := fun p _m _ => (eliminateIncrDecr p, [], {})
107107

108108
end -- public section
109109
end Strata.Laurel
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/-
2+
Copyright Strata Contributors
3+
4+
SPDX-License-Identifier: Apache-2.0 OR MIT
5+
-/
6+
module
7+
8+
public import Strata.Languages.Laurel.MapStmtExpr
9+
public import Strata.Languages.Laurel.LaurelPass
10+
11+
/-!
12+
# Eliminate Return Statements
13+
14+
Replaces `return` statements in imperative procedure bodies with assignments
15+
to the output parameters followed by an `exit` to a labelled block that wraps
16+
the entire body. This ensures that code placed after the body block (e.g.,
17+
postcondition assertions inserted by the contract pass) is always reached.
18+
19+
This pass should run after `EliminateReturnsInExpression` (which handles
20+
functional/expression-position returns) and before the contract pass.
21+
-/
22+
23+
namespace Strata.Laurel
24+
25+
public section
26+
27+
private def returnLabel : String := "$return"
28+
29+
30+
31+
32+
/-- Transform a single procedure: wrap body in a labelled block and replace returns. -/
33+
private def eliminateReturnStmts (proc : Procedure) : Procedure :=
34+
match proc.body with
35+
| .Opaque postconds (some impl) mods =>
36+
let impl' := replaceReturn proc.outputs impl
37+
let wrapped := match impl'.val with
38+
| .Block stmts none => ⟨.Block stmts (some returnLabel), impl'.source⟩
39+
| _ => ⟨ .Block [impl'] (some returnLabel), proc.name.source ⟩
40+
{ proc with body := .Opaque postconds (some wrapped) mods }
41+
| .Transparent body =>
42+
let body' := replaceReturn proc.outputs body
43+
let wrapped := match body'.val with
44+
| .Block stmts none => ⟨.Block stmts (some returnLabel), body'.source⟩
45+
| _ => ⟨ .Block [body'] (some returnLabel), proc.name.source ⟩
46+
{ proc with body := .Transparent wrapped }
47+
| _ => proc
48+
where
49+
50+
/-- Replace `Return val` with `output := val; exit "$return"` (or just `exit`
51+
for valueless returns). Uses `mapStmtExpr` for bottom-up traversal. -/
52+
replaceReturn (outputs : List Parameter) (expr : StmtExprMd) : StmtExprMd :=
53+
mapStmtExpr (fun e =>
54+
match e.val with
55+
| .Return (some val) =>
56+
/- Handling valued return is required because the heap param pass introduces valued return in
57+
Strata/Languages/Laurel/HeapParameterizationConstants.lean
58+
We should change that so we can remove this case.
59+
-/
60+
match outputs with
61+
| [out] =>
62+
let assign := ⟨ .Assign [⟨ .Local out.name, expr.source ⟩] val, expr.source ⟩
63+
let exit := ⟨ .Exit returnLabel, expr.source ⟩
64+
⟨.Block [assign, exit] none, e.source⟩
65+
| _ => ⟨ .Exit returnLabel, expr.source ⟩
66+
| .Return none => ⟨ .Exit returnLabel, expr.source ⟩
67+
| _ => e) expr
68+
69+
/-- Transform a program by eliminating return statements in all procedure bodies. -/
70+
def eliminateReturnStatements (program : Program) : Program :=
71+
{ program with staticProcedures := program.staticProcedures.map eliminateReturnStmts }
72+
73+
public def eliminateReturnStatementsPass : LoweringPass where
74+
name := "EliminateReturnStatements"
75+
documentation := "Lower return statements to exit statements. Wrap each procedure body with a 'return' block"
76+
run := fun p _m _ =>
77+
let p' := eliminateReturnStatements p
78+
(p', [], {})
79+
-- comesBefore := [contractPass]
80+
81+
end -- public section
82+
83+
end Strata.Laurel

0 commit comments

Comments
 (0)