diff --git a/Compiler/ABI/Frame.lean b/Compiler/ABI/Frame.lean new file mode 100644 index 000000000..5e4fecd55 --- /dev/null +++ b/Compiler/ABI/Frame.lean @@ -0,0 +1,165 @@ +import Compiler.CompilationModel.AbiTypeLayout +import Compiler.Yul.Ast + +namespace Compiler.ABI.Frame + +open Compiler.CompilationModel +open Compiler.Yul + +inductive FrameSource + | calldata + | memory + | code + | storage + deriving Repr, BEq + +inductive FramePassMode + | inlineWords + | pointer + deriving Repr, BEq + +structure FrameField where + name : String + ty : ParamType + source : FrameSource + sourceBase : String := "" + tailBytes : Nat := 0 + deriving Repr, BEq + +structure FrameLayout where + fields : List FrameField + headWords : Nat + hasDynamic : Bool + mode : FramePassMode + deriving Repr, BEq + +def spillThresholdWords : Nat := 4 + +def fieldHeadWords (field : FrameField) : Nat := + paramParentHeadWords field.ty + +def fieldsHeadWords (fields : List FrameField) : Nat := + fields.foldl (fun acc field => acc + fieldHeadWords field) 0 + +def fieldsHaveDynamic (fields : List FrameField) : Bool := + fields.any (fun field => isDynamicParamType field.ty) + +def shouldPassByPointer (fields : List FrameField) : Bool := + fieldsHaveDynamic fields || spillThresholdWords < fieldsHeadWords fields + +def layout (fields : List FrameField) : FrameLayout := + let headWords := fieldsHeadWords fields + let hasDynamic := fieldsHaveDynamic fields + { fields + headWords + hasDynamic + mode := if hasDynamic || spillThresholdWords < headWords then .pointer else .inlineWords } + +def sourceCanMaterializeEarly : FrameSource → Bool + | .calldata | .memory | .code | .storage => true + +def layoutSourcesSupported (l : FrameLayout) : Bool := + l.fields.all (fun field => sourceCanMaterializeEarly field.source) + +def frameSizeBytes (l : FrameLayout) : Nat := + l.fields.foldl (fun acc field => acc + fieldHeadWords field * 32 + + (if isDynamicParamType field.ty then field.tailBytes else 0)) 0 + +def fieldPayloadWords (field : FrameField) : Nat := + fieldHeadWords field + if isDynamicParamType field.ty then (field.tailBytes + 31) / 32 else 0 + +def frameAllocBytes (l : FrameLayout) : Nat := + l.fields.foldl (fun acc field => acc + fieldPayloadWords field * 32) 0 + +def ptrName (base : String) : String := + "__abi_frame_" ++ base + +def fieldWordName (base : String) (field : FrameField) (idx : Nat) : String := + ptrName base ++ "_" ++ field.name ++ "_" ++ toString idx + +def allocateFrame (base : String) (l : FrameLayout) : List YulStmt := + [ YulStmt.let_ (ptrName base) (YulExpr.call "mload" [YulExpr.lit 64]) + , YulStmt.expr (YulExpr.call "mstore" + [ YulExpr.lit 64 + , YulExpr.call "add" [YulExpr.ident (ptrName base), YulExpr.lit (frameAllocBytes l)] ])] + +def sourceBaseName (field : FrameField) : String := + if field.sourceBase.isEmpty then field.name else field.sourceBase + +private def sourceByteOffset (field : FrameField) (idx : Nat) : YulExpr := + if idx == 0 then + YulExpr.ident (sourceBaseName field) + else + YulExpr.call "add" [YulExpr.ident (sourceBaseName field), YulExpr.lit (idx * 32)] + +private def sourceStorageSlot (field : FrameField) (idx : Nat) : YulExpr := + if idx == 0 then + YulExpr.ident (sourceBaseName field) + else + YulExpr.call "add" [YulExpr.ident (sourceBaseName field), YulExpr.lit idx] + +private def materializeSourceWord (field : FrameField) (idx : Nat) : YulExpr := + match field.source with + | .calldata => YulExpr.call "calldataload" [sourceByteOffset field idx] + | .memory => YulExpr.call "mload" [sourceByteOffset field idx] + | .code => YulExpr.call "mload" [sourceByteOffset field idx] + | .storage => YulExpr.call "sload" [sourceStorageSlot field idx] + +partial def spillField (base : String) (offsetWords : Nat) (field : FrameField) : List YulStmt := + (List.range (fieldPayloadWords field)).map fun idx => + YulStmt.expr (YulExpr.call "mstore" + [ YulExpr.call "add" [YulExpr.ident (ptrName base), YulExpr.lit ((offsetWords + idx) * 32)] + , materializeSourceWord field idx ]) + +partial def spillFields (base : String) (offsetWords : Nat) : List FrameField → List YulStmt + | [] => [] + | field :: rest => spillField base offsetWords field ++ spillFields base (offsetWords + fieldPayloadWords field) rest + +/-- Materialize a typed ABI frame into memory before lowering calls/logs/returns. + Large or dynamic payloads are then passed as `(ptr, size)` instead of as a + long list of Yul values. -/ +def spillPayloadToMemory (base : String) (l : FrameLayout) : List YulStmt := + allocateFrame base l ++ spillFields base 0 l.fields + +def pointerArgs (base : String) (l : FrameLayout) : List YulExpr := + [YulExpr.ident (ptrName base), YulExpr.lit (frameSizeBytes l)] + +def inlinePayloadToScratch (words : List YulExpr) : List YulStmt := + words.zipIdx.map fun (word, idx) => + YulStmt.expr (YulExpr.call "mstore" [YulExpr.lit (idx * 32), word]) + +private partial def inlineArgsFrom : List FrameField → List YulExpr + | [] => [] + | field :: rest => + (List.range (fieldHeadWords field)).map (fun wordIdx => + materializeSourceWord field wordIdx) ++ + inlineArgsFrom rest + +def inlineArgs (l : FrameLayout) : List YulExpr := + inlineArgsFrom l.fields + +def loweredArgs (base : String) (l : FrameLayout) : List YulExpr := + match l.mode with + | .pointer => pointerArgs base l + | .inlineWords => inlineArgs l + +def materializePayloadToMemory (base : String) (l : FrameLayout) : List YulStmt × List YulExpr := + match l.mode with + | .pointer => + (spillPayloadToMemory base l, pointerArgs base l) + | .inlineWords => + (inlinePayloadToScratch (inlineArgs l), [YulExpr.lit 0, YulExpr.lit (frameSizeBytes l)]) + +def containsDynamicArrayOrBytes (l : FrameLayout) : Bool := + l.fields.any fun field => + match field.ty with + | .array _ | .bytes | .string => true + | _ => isDynamicParamType field.ty + +def supportsNestedStructs (l : FrameLayout) : Bool := + l.fields.any fun field => + match field.ty with + | .tuple elems => elems.any (fun ty => match ty with | .tuple _ => true | _ => false) + | _ => false + +end Compiler.ABI.Frame diff --git a/Compiler/ABI/FrameTest.lean b/Compiler/ABI/FrameTest.lean new file mode 100644 index 000000000..e86850935 --- /dev/null +++ b/Compiler/ABI/FrameTest.lean @@ -0,0 +1,49 @@ +import Compiler.ABI.Frame + +namespace Compiler.ABI.FrameTest + +open Compiler.ABI.Frame +open Compiler.CompilationModel +open Compiler.Yul + +private def assert (label : String) (ok : Bool) : IO Unit := do + if !ok then + throw (IO.userError s!"frame test failed: {label}") + IO.println s!"ok: {label}" + +private def takeFields : List FrameField := + [ { name := "offer", ty := .tuple [.address, .uint256, .tuple [.bytes32, .uint256]], source := .calldata } + , { name := "units", ty := .uint256, source := .calldata } + , { name := "ratifierData", ty := .bytes, source := .calldata, tailBytes := 96 } ] + +private def sourceFields : List FrameField := + [ { name := "c", ty := .uint256, source := .calldata } + , { name := "m", ty := .bytes, source := .memory, tailBytes := 64 } + , { name := "x", ty := .bytes32, source := .code } + , { name := "s", ty := .uint256, source := .storage } ] + +private def inlineFields : List FrameField := + [ { name := "pair", ty := .tuple [.uint256, .bytes32], source := .calldata } + , { name := "amount", ty := .uint256, source := .calldata } ] + +private def calldataLoadName? : YulExpr → Option String + | .call "calldataload" [.ident name] => some name + | .call "calldataload" [.call "add" [.ident name, _]] => some name + | _ => none + +#eval! do + let takeLayout := layout takeFields + assert "nested struct supported" (supportsNestedStructs takeLayout) + assert "dynamic bytes/arrays force pointer mode" (takeLayout.mode == FramePassMode.pointer) + assert "Take frame passes pointer pair" ((loweredArgs "take" takeLayout).length == 2) + assert "Take spills early to memory" ((spillPayloadToMemory "take" takeLayout).length > 2) + assert "dynamic tail contributes to pointer payload size" (frameSizeBytes takeLayout == 288) + assert "dynamic tail contributes to allocated words" (frameAllocBytes takeLayout == 288) + let srcLayout := layout sourceFields + assert "calldata/memory/code/storage sources supported" (layoutSourcesSupported srcLayout) + assert "dynamic source frame is pointer mode" (srcLayout.mode == FramePassMode.pointer) + let inlineNames := (inlineArgs (layout inlineFields)).filterMap calldataLoadName? + assert "inline source words are indexed per field" + (inlineNames == ["pair", "pair", "amount"]) + +end Compiler.ABI.FrameTest diff --git a/Compiler/Lowering/StackSafeAbi.lean b/Compiler/Lowering/StackSafeAbi.lean new file mode 100644 index 000000000..852c9ef15 --- /dev/null +++ b/Compiler/Lowering/StackSafeAbi.lean @@ -0,0 +1,57 @@ +import Compiler.ABI.Frame + +namespace Compiler.Lowering.StackSafeAbi + +open Compiler.ABI.Frame +open Compiler.Yul + +structure LoweredFrame where + prologue : List YulStmt + args : List YulExpr + layout : FrameLayout + deriving Repr + +def eventNameTopicWord (eventName : String) : Nat := + UInt64.toNat (hash eventName) + +def lowerFrameSpilled (base : String) (fields : List FrameField) : Except String LoweredFrame := do + let l := layout fields + if !layoutSourcesSupported l then + throw s!"ABI frame '{base}' uses an unsupported source" + let prologue := + match l.mode with + | .pointer => spillPayloadToMemory base l + | .inlineWords => [] + pure { prologue, args := loweredArgs base l, layout := l } + +def lowerFrameAsMemoryPayload (base : String) (fields : List FrameField) : Except String (List YulStmt × List YulExpr × FrameLayout) := do + let lowered ← lowerFrameSpilled base fields + let (prologue, args) := materializePayloadToMemory base lowered.layout + pure (prologue, args, lowered.layout) + +def lowerEventWithTopic (base : String) (topic0 : YulExpr) (fields : List FrameField) : Except String (List YulStmt) := do + let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload base fields + pure (prologue ++ + [YulStmt.expr (YulExpr.call "log1" (payloadArgs ++ [topic0]))]) + +def lowerEvent (eventName : String) (fields : List FrameField) : Except String (List YulStmt) := do + lowerEventWithTopic eventName (YulExpr.lit (eventNameTopicWord eventName)) fields + +def lowerExternalCall (callName : String) (target value : YulExpr) (fields : List FrameField) : Except String (List YulStmt) := do + let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload callName fields + let callArgs := [YulExpr.call "gas" [], target, value] ++ payloadArgs ++ [YulExpr.lit 0, YulExpr.lit 0] + pure (prologue ++ [YulStmt.let_ ("__" ++ callName ++ "_ok") (YulExpr.call "call" callArgs)]) + +def lowerDynamicReturn (returnName : String) (fields : List FrameField) : Except String (List YulStmt) := do + let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload returnName fields + pure (prologue ++ [YulStmt.expr (YulExpr.call "return" payloadArgs)]) + +def usesPointerAbi (stmts : List YulStmt) : Bool := + stmts.any fun stmt => + match stmt with + | .expr (.call "return" [_ptr, _size]) => true + | .expr (.call "log1" [_ptr, _size, _topic]) => true + | .let_ _ (.call "call" [_gas, _target, _value, _ptr, _size, _out, _outSize]) => true + | _ => false + +end Compiler.Lowering.StackSafeAbi diff --git a/Compiler/Lowering/StackSafeAbiTest.lean b/Compiler/Lowering/StackSafeAbiTest.lean new file mode 100644 index 000000000..22ea333df --- /dev/null +++ b/Compiler/Lowering/StackSafeAbiTest.lean @@ -0,0 +1,77 @@ +import Compiler.Lowering.StackSafeAbi + +namespace Compiler.Lowering.StackSafeAbiTest + +open Compiler.ABI.Frame +open Compiler.CompilationModel +open Compiler.Lowering.StackSafeAbi +open Compiler.Yul + +private def assert (label : String) (ok : Bool) : IO Unit := do + if !ok then + throw (IO.userError s!"stack safe ABI test failed: {label}") + IO.println s!"ok: {label}" + +private def bigDynamicPayload : List FrameField := + [ { name := "toId", ty := .bytes32, source := .calldata } + , { name := "toMarket", ty := .tuple [.address, .uint256, .uint256, .address], source := .calldata } + , { name := "takes", ty := .array (.tuple [.address, .uint256, .bytes]), source := .calldata, tailBytes := 128 } ] + +private def smallStaticPayload : List FrameField := + [ { name := "id", ty := .bytes32, source := .calldata } + , { name := "amount", ty := .uint256, source := .calldata } ] + +private def countMstores : List YulStmt → Nat := + List.length ∘ List.filter (fun stmt => + match stmt with + | .expr (.call "mstore" _) => true + | _ => false) + +private def returnsBytes (bytes : Nat) : List YulStmt → Bool := + fun stmts => stmts.any fun stmt => + match stmt with + | .expr (.call "return" [.lit 0, .lit n]) => n == bytes + | _ => false + +private def callsWithInputBytes (bytes : Nat) : List YulStmt → Bool := + fun stmts => stmts.any fun stmt => + match stmt with + | .let_ _ (.call "call" [_gas, _target, _value, .lit 0, .lit n, _out, _outSize]) => n == bytes + | _ => false + +private def logsWithDataBytes (bytes : Nat) : List YulStmt → Bool := + fun stmts => stmts.any fun stmt => + match stmt with + | .expr (.call "log1" [.lit 0, .lit n, .lit topic]) => n == bytes && topic != 0 + | _ => false + +#eval! do + match lowerEvent "Take" bigDynamicPayload with + | .ok ev => assert "event lowering uses memory pointer" (usesPointerAbi ev) + | .error err => throw (IO.userError err) + match lowerExternalCall "callback" (YulExpr.ident "target") (YulExpr.lit 0) bigDynamicPayload with + | .ok call => assert "external-call lowering uses memory pointer" (usesPointerAbi call) + | .error err => throw (IO.userError err) + match lowerDynamicReturn "dynamicReturn" bigDynamicPayload with + | .ok ret => assert "dynamic return lowering uses memory pointer" (usesPointerAbi ret) + | .error err => throw (IO.userError err) + match lowerFrameSpilled "toMarket" bigDynamicPayload with + | .ok lowered => assert "frame-spilled lowering allocates memory early" (!lowered.prologue.isEmpty && lowered.args.length == 2) + | .error err => throw (IO.userError err) + match lowerEvent "SmallStatic" smallStaticPayload with + | .ok ev => + assert "inline event lowering stores payload" (countMstores ev == 2) + assert "inline event lowering logs payload bytes" (logsWithDataBytes 64 ev) + | .error err => throw (IO.userError err) + match lowerExternalCall "smallCallback" (YulExpr.ident "target") (YulExpr.lit 0) smallStaticPayload with + | .ok call => + assert "inline external-call lowering stores payload" (countMstores call == 2) + assert "inline external-call lowering passes payload bytes" (callsWithInputBytes 64 call) + | .error err => throw (IO.userError err) + match lowerDynamicReturn "smallReturn" smallStaticPayload with + | .ok ret => + assert "inline dynamic return lowering stores payload" (countMstores ret == 2) + assert "inline dynamic return lowering returns payload bytes" (returnsBytes 64 ret) + | .error err => throw (IO.userError err) + +end Compiler.Lowering.StackSafeAbiTest diff --git a/Compiler/Modules.lean b/Compiler/Modules.lean index 62910313b..7fe52e6c3 100644 --- a/Compiler/Modules.lean +++ b/Compiler/Modules.lean @@ -5,3 +5,5 @@ import Compiler.Modules.ERC20 import Compiler.Modules.Hashing import Compiler.Modules.Oracle import Compiler.Modules.Precompiles +import Compiler.Modules.Create2SSTORE2 +import Compiler.Modules.CodeData diff --git a/Compiler/Modules/CodeData.lean b/Compiler/Modules/CodeData.lean new file mode 100644 index 000000000..e44f2201b --- /dev/null +++ b/Compiler/Modules/CodeData.lean @@ -0,0 +1,67 @@ +import Compiler.ABI.Frame +import Compiler.Modules.Create2SSTORE2 + +namespace Compiler.Modules.CodeData + +open Compiler.ABI.Frame +open Compiler.ECM +open Compiler.Yul + +structure CodeDataWrite where + salt : YulExpr + value : YulExpr := YulExpr.lit 0 + payload : FrameLayout + deriving Repr + +structure CodeDataRead where + pointer : YulExpr + destOffset : YulExpr + codeOffset : YulExpr + size : YulExpr + payload : FrameLayout + deriving Repr + +def trustSurface : List String := + [ "CREATE2 address derivation is trusted at the EVM boundary" + , "SSTORE2 pointer code layout is trusted as code-as-data" + , "extcodecopy reads immutable deployed code bytes into caller-owned memory" + , "ABI frame layout is typed by Compiler.ABI.Frame before write/read lowering" ] + +def writeTyped (resultVar base : String) (write : CodeDataWrite) : Except String (List YulStmt) := do + if !layoutSourcesSupported write.payload then + throw "CodeData write payload has unsupported frame source" + let (prelude, payloadArgs) := materializePayloadToMemory base write.payload + let initcodeOffset ← + match payloadArgs with + | [offset, _size] => pure offset + | _ => throw "CodeData write expected a memory payload pointer and size" + let initcodeSize ← + match payloadArgs with + | [_offset, size] => pure size + | _ => throw "CodeData write expected a memory payload pointer and size" + let deploy ← (Compiler.Modules.Create2SSTORE2.deployModule resultVar).compile {} + [write.value, initcodeOffset, initcodeSize, write.salt] + pure (prelude ++ deploy) + +def readTyped (read : CodeDataRead) : Except String (List YulStmt) := do + if !layoutSourcesSupported read.payload then + throw "CodeData read payload has unsupported frame source" + (Compiler.Modules.Create2SSTORE2.readCodeModule).compile {} + [read.pointer, read.destOffset, read.codeOffset, read.size] + +def roundtripShape (resultVar base : String) (write : CodeDataWrite) (read : CodeDataRead) : + Except String (List YulStmt) := do + let w ← writeTyped resultVar base write + let r ← readTyped read + pure (w ++ r) + +def hasCreate2AndExtcodecopy (stmts : List YulStmt) : Bool := + let hasCreate2 := stmts.any fun + | .let_ _ (.call "create2" _) => true + | _ => false + let hasExtcodecopy := stmts.any fun + | .expr (.call "extcodecopy" _) => true + | _ => false + hasCreate2 && hasExtcodecopy + +end Compiler.Modules.CodeData diff --git a/Compiler/Modules/CodeDataTest.lean b/Compiler/Modules/CodeDataTest.lean new file mode 100644 index 000000000..f205ecfce --- /dev/null +++ b/Compiler/Modules/CodeDataTest.lean @@ -0,0 +1,43 @@ +import Compiler.Modules.CodeData + +namespace Compiler.Modules.CodeDataTest + +open Compiler.ABI.Frame +open Compiler.CompilationModel +open Compiler.Modules.CodeData +open Compiler.Yul + +private def assert (label : String) (ok : Bool) : IO Unit := do + if !ok then + throw (IO.userError s!"CodeData test failed: {label}") + IO.println s!"ok: {label}" + +private def payload := layout + [ { name := "blob", ty := .bytes, source := .memory, tailBytes := 96 } + , { name := "meta", ty := .tuple [.bytes32, .uint256], source := .calldata } ] + +private def deployUsesPayloadBuffer : List YulStmt → Bool := + fun stmts => stmts.any fun stmt => + match stmt with + | .let_ _ (.call "create2" [_value, .ident "__abi_frame_sstore2", .lit 192, _salt]) => true + | _ => false + +#eval! do + let write : CodeDataWrite := + { salt := YulExpr.ident "salt" + payload } + let read : CodeDataRead := + { pointer := YulExpr.ident "ptr" + destOffset := YulExpr.ident "dest" + codeOffset := YulExpr.lit 1 + size := YulExpr.ident "size" + payload } + let roundtrip ← + match roundtripShape "storedPtr" "sstore2" write read with + | .ok stmts => pure stmts + | .error err => throw (IO.userError err) + assert "typed roundtrip has create2 and extcodecopy" (hasCreate2AndExtcodecopy roundtrip) + assert "typed write deploys the materialized payload buffer" (deployUsesPayloadBuffer roundtrip) + assert "trust surface is explicit" (trustSurface.length == 4) + +end Compiler.Modules.CodeDataTest diff --git a/Compiler/Modules/README.md b/Compiler/Modules/README.md index f2bcb7be3..c7487af67 100644 --- a/Compiler/Modules/README.md +++ b/Compiler/Modules/README.md @@ -16,6 +16,7 @@ structure that the compiler can plug in without modification. | `Callbacks.lean` | `callback` | `Stmt.callback` | | `Calls.lean` | `withReturn`, `callWithValue`, `callWithValueBytes`, `bubblingValueCall`, `bubblingValueCallNoOutput` | `Stmt.externalCallWithReturn`; generic `call{value:v}` adapter calls; handwritten low-level `call{value: ...}` wrappers | | `Create2SSTORE2.lean` | `create2Deploy`, `sstore2ReadCode` | handwritten CREATE2 deployment and SSTORE2-style `extcodecopy` code-as-data reads | +| `CodeData.lean` | `writeTyped`, `readTyped`, `roundtripShape` | typed CREATE2/SSTORE2 code-as-data facade using `Compiler.ABI.Frame` layouts and an explicit trust surface | ## Usage diff --git a/Compiler/Proofs/GeneratedTransition.lean b/Compiler/Proofs/GeneratedTransition.lean new file mode 100644 index 000000000..52bc72fd7 --- /dev/null +++ b/Compiler/Proofs/GeneratedTransition.lean @@ -0,0 +1,149 @@ +import Compiler.CompilationModel + +namespace Compiler.Proofs.GeneratedTransition + +open Compiler.CompilationModel + +structure TransitionSummary where + reads : List String := [] + writes : List String := [] + guards : List String := [] + events : List String := [] + deriving Repr, BEq, Inhabited + +private def dedup (xs : List String) : List String := + xs.foldl (fun acc x => if acc.contains x then acc else acc ++ [x]) [] + +private def merge (a b : TransitionSummary) : TransitionSummary := + { reads := dedup (a.reads ++ b.reads) + writes := dedup (a.writes ++ b.writes) + guards := dedup (a.guards ++ b.guards) + events := dedup (a.events ++ b.events) } + +private partial def exprReads : Expr → List String + | .storage field => [field] + | .storageAddr field => [field] + | .mapping field key => field :: exprReads key + | .mapping2 field key1 key2 => field :: exprReads key1 ++ exprReads key2 + | .mappingUint field key => field :: exprReads key + | .mappingChain field keys => field :: keys.flatMap exprReads + | .structMember field key _ => field :: exprReads key + | .structMember2 field key1 key2 _ => field :: exprReads key1 ++ exprReads key2 + | .mload a | .tload a | .calldataload a | .extcodesize a | .returndataOptionalBoolAt a + | .storageArrayElement _ a | .memoryArrayElement _ a => + exprReads a + | .keccak256 a b | .add a b | .sub a b | .mul a b | .div a b | .sdiv a b + | .mod a b | .smod a b | .eq a b | .ge a b | .gt a b | .sgt a b | .lt a b + | .slt a b | .le a b | .logicalAnd a b | .logicalOr a b | .bitAnd a b + | .bitOr a b | .bitXor a b | .shl a b | .shr a b | .sar a b | .byte a b + | .signextend a b | .ceilDiv a b | .wMulDown a b | .wDivUp a b + | .min a b | .max a b => + exprReads a ++ exprReads b + | .call gas target value inOffset inSize outOffset outSize => + [gas, target, value, inOffset, inSize, outOffset, outSize].flatMap exprReads + | .staticcall gas target inOffset inSize outOffset outSize + | .delegatecall gas target inOffset inSize outOffset outSize => + [gas, target, inOffset, inSize, outOffset, outSize].flatMap exprReads + | .bitNot a | .logicalNot a => exprReads a + | .externalCall _ args | .internalCall _ args => args.flatMap exprReads + | .intrinsic _ _ _ args => args.flatMap exprReads + | .forkIfAtLeast _ t e => exprReads t ++ exprReads e + | .arrayLength name => [name ++ ".length"] + | .arrayElement name idx => name :: exprReads idx + | .memoryArrayLength name => [name ++ ".length"] + | .arrayElementWord name idx _ _ => name :: exprReads idx + | .arrayElementDynamicWord name idx _ => name :: exprReads idx + | .arrayElementDynamicDataOffset name idx => name :: exprReads idx + | .arrayElementDynamicMemberLength name idx _ => name :: exprReads idx + | .arrayElementDynamicMemberDataOffset name idx _ => name :: exprReads idx + | .arrayElementDynamicMemberElement name idx _ memberIdx => name :: (exprReads idx ++ exprReads memberIdx) + | .paramDynamicHeadWord name _ => [name] + | .paramDynamicStaticComposite name _ => [name] + | .paramDynamicMemberLength name _ => [name] + | .paramDynamicMemberDataOffset name _ => [name] + | .paramDynamicMemberElement name _ _ => [name] + | .storageArrayLength field => [field ++ ".length"] + | .dynamicBytesEq lhs rhs => [lhs, rhs] + | .ite c t e => exprReads c ++ exprReads t ++ exprReads e + | .adtConstruct _ _ args => args.flatMap exprReads + | .adtTag _ field => [field] + | .adtField _ _ _ _ field => [field] + | .mulDivDown a b c | .mulDivUp a b c | .mulDiv512Down a b c | .mulDiv512Up a b c => + exprReads a ++ exprReads b ++ exprReads c + | _ => [] + +mutual +private partial def stmtSummary : Stmt → TransitionSummary + | .setStorage field value | .setStorageAddr field value => + { reads := dedup (exprReads value), writes := [field] } + | .setStorageWord field offset value => + { reads := dedup (exprReads value), writes := [field ++ "+" ++ toString offset] } + | .storageArrayPush field value => + { reads := dedup ((field ++ ".length") :: exprReads value), writes := [field] } + | .storageArrayPop field => + { reads := [field ++ ".length"], writes := [field] } + | .setStorageArrayElement field idx value => + { reads := dedup (exprReads idx ++ exprReads value), writes := [field ++ "[]"] } + | .setMapping field key value | .setMappingUint field key value => + { reads := dedup (exprReads key ++ exprReads value), writes := [field] } + | .setMappingWord field key offset value => + { reads := dedup (exprReads key ++ exprReads value), writes := [field ++ "+" ++ toString offset] } + | .setMappingPackedWord field key offset _ value => + { reads := dedup (field :: (exprReads key ++ exprReads value)), writes := [field ++ "+" ++ toString offset] } + | .setMapping2 field key1 key2 value => + { reads := dedup (exprReads key1 ++ exprReads key2 ++ exprReads value), writes := [field] } + | .setMapping2Word field key1 key2 offset value => + { reads := dedup (exprReads key1 ++ exprReads key2 ++ exprReads value), writes := [field ++ "+" ++ toString offset] } + | .setMappingChain field keys value => + { reads := dedup (keys.flatMap exprReads ++ exprReads value), writes := [field] } + | .setStructMember field key member value => + { reads := dedup (exprReads key ++ exprReads value), writes := [field ++ "." ++ member] } + | .setStructMember2 field key1 key2 member value => + { reads := dedup (exprReads key1 ++ exprReads key2 ++ exprReads value), writes := [field ++ "." ++ member] } + | .require cond label | .requireError cond label _ => + { reads := dedup (exprReads cond), guards := [label] } + | .revertError label args => + { reads := dedup (args.flatMap exprReads), guards := [label] } + | .emit eventName args => + { reads := dedup (args.flatMap exprReads), events := [eventName] } + | .rawLog topics dataOffset dataSize => + { reads := dedup (topics.flatMap exprReads ++ exprReads dataOffset ++ exprReads dataSize), events := ["rawLog"] } + | .externalCallBind results externalName args => + { reads := dedup (args.flatMap exprReads), writes := results.map ("local:" ++ ·), events := ["external:" ++ externalName] } + | .tryExternalCallBind success results externalName args => + { reads := dedup (args.flatMap exprReads), writes := ("local:" ++ success) :: results.map ("local:" ++ ·), events := ["external:" ++ externalName] } + | .ecm mod args => + { reads := if mod.readsState then ["ecm:" ++ mod.name] else dedup (args.flatMap exprReads) + writes := if mod.writesState then ["ecm:" ++ mod.name] else [] + events := ["ecm:" ++ mod.name] } + | .ite cond t e => + merge { reads := dedup (exprReads cond), guards := ["branch"] } (merge (stmtsSummary t) (stmtsSummary e)) + | .forEach name count body => + merge { reads := dedup (exprReads count), guards := ["loop:" ++ name] } (stmtsSummary body) + | .letVar _ value | .assignVar _ value | .return value | .mstore _ value | .tstore _ value => + { reads := dedup (exprReads value) } + | .returnValues values => { reads := dedup (values.flatMap exprReads) } + | .returnArray name | .returnBytes name | .returnStorageWords name => { reads := [name] } + | .calldatacopy dest src size | .returndataCopy dest src size => + { reads := dedup (exprReads dest ++ exprReads src ++ exprReads size) } + | .internalCall _ args | .internalCallAssign _ _ args => + { reads := dedup (args.flatMap exprReads) } + | .unsafeBlock reason body => + merge { guards := ["unsafe:" ++ reason] } (stmtsSummary body) + | .unsafeYul _ => { events := ["unsafeYul"] } + | .matchAdt adtName scrutinee branches => + merge { reads := dedup (exprReads scrutinee), guards := ["match:" ++ adtName] } + (branches.foldl (fun acc (_, _, body) => merge acc (stmtsSummary body)) {}) + | .stop | .revertReturndata => {} + +private partial def stmtsSummary (stmts : List Stmt) : TransitionSummary := + stmts.foldl (fun acc stmt => merge acc (stmtSummary stmt)) {} +end + +def extract (stmts : List Stmt) : TransitionSummary := + stmts.foldl (fun acc stmt => merge acc (stmtSummary stmt)) {} + +def enoughForMidnightRcfTotalUnits (summary : TransitionSummary) : Bool := + !summary.reads.isEmpty || !summary.writes.isEmpty || !summary.guards.isEmpty || !summary.events.isEmpty + +end Compiler.Proofs.GeneratedTransition diff --git a/Compiler/Proofs/GeneratedTransitionTest.lean b/Compiler/Proofs/GeneratedTransitionTest.lean new file mode 100644 index 000000000..e08bcdf73 --- /dev/null +++ b/Compiler/Proofs/GeneratedTransitionTest.lean @@ -0,0 +1,24 @@ +import Compiler.Proofs.GeneratedTransition + +namespace Compiler.Proofs.GeneratedTransitionTest + +open Compiler.CompilationModel +open Compiler.Proofs.GeneratedTransition + +private def assert (label : String) (ok : Bool) : IO Unit := do + if !ok then + throw (IO.userError s!"GeneratedTransition test failed: {label}") + IO.println s!"ok: {label}" + +#eval! do + let summary := extract + [ Stmt.require (Expr.gt (Expr.storage "totalUnits") (Expr.literal 0)) "nonzero" + , Stmt.setMapping "position" (Expr.param "borrower") (Expr.storage "totalUnits") + , Stmt.emit "Take" [Expr.param "borrower", Expr.storage "totalUnits"] ] + assert "extracts reads" (summary.reads.contains "totalUnits") + assert "extracts writes" (summary.writes.contains "position") + assert "extracts guards" (summary.guards.contains "nonzero") + assert "extracts events" (summary.events.contains "Take") + assert "non-empty summary can feed later Midnight RCF/totalUnits" (enoughForMidnightRcfTotalUnits summary) + +end Compiler.Proofs.GeneratedTransitionTest