From 487fbb8267b44e178902a8183c8f91e8339abb18 Mon Sep 17 00:00:00 2001 From: Thomas Marchand Date: Thu, 4 Jun 2026 15:25:51 +0200 Subject: [PATCH 1/3] Add ABI frame lowering layers --- Compiler/ABI/Frame.lean | 131 ++++++++++++++++ Compiler/ABI/FrameTest.lean | 34 +++++ Compiler/Lowering/StackSafeAbi.lean | 56 +++++++ Compiler/Lowering/StackSafeAbiTest.lean | 34 +++++ Compiler/Modules.lean | 2 + Compiler/Modules/CodeData.lean | 64 ++++++++ Compiler/Modules/CodeDataTest.lean | 38 +++++ Compiler/Modules/README.md | 1 + Compiler/Proofs/GeneratedTransition.lean | 149 +++++++++++++++++++ Compiler/Proofs/GeneratedTransitionTest.lean | 24 +++ 10 files changed, 533 insertions(+) create mode 100644 Compiler/ABI/Frame.lean create mode 100644 Compiler/ABI/FrameTest.lean create mode 100644 Compiler/Lowering/StackSafeAbi.lean create mode 100644 Compiler/Lowering/StackSafeAbiTest.lean create mode 100644 Compiler/Modules/CodeData.lean create mode 100644 Compiler/Modules/CodeDataTest.lean create mode 100644 Compiler/Proofs/GeneratedTransition.lean create mode 100644 Compiler/Proofs/GeneratedTransitionTest.lean diff --git a/Compiler/ABI/Frame.lean b/Compiler/ABI/Frame.lean new file mode 100644 index 000000000..4d78aab11 --- /dev/null +++ b/Compiler/ABI/Frame.lean @@ -0,0 +1,131 @@ +import Compiler.CompilationModel.AbiTypeLayout +import Compiler.Yul.Ast + +namespace Compiler.ABI.Frame + +open Compiler.CompilationModel +open Compiler.Yul + +inductive FrameSource + | calldata + | memory + | code + | storage + deriving Repr, BEq + +inductive FramePassMode + | inlineWords + | pointer + deriving Repr, BEq + +structure FrameField where + name : String + ty : ParamType + source : FrameSource + deriving Repr, BEq + +structure FrameLayout where + fields : List FrameField + headWords : Nat + hasDynamic : Bool + mode : FramePassMode + deriving Repr, BEq + +def spillThresholdWords : Nat := 4 + +def fieldHeadWords (field : FrameField) : Nat := + paramParentHeadWords field.ty + +def fieldsHeadWords (fields : List FrameField) : Nat := + fields.foldl (fun acc field => acc + fieldHeadWords field) 0 + +def fieldsHaveDynamic (fields : List FrameField) : Bool := + fields.any (fun field => isDynamicParamType field.ty) + +def shouldPassByPointer (fields : List FrameField) : Bool := + fieldsHaveDynamic fields || spillThresholdWords < fieldsHeadWords fields + +def layout (fields : List FrameField) : FrameLayout := + let headWords := fieldsHeadWords fields + let hasDynamic := fieldsHaveDynamic fields + { fields + headWords + hasDynamic + mode := if hasDynamic || spillThresholdWords < headWords then .pointer else .inlineWords } + +def sourceCanMaterializeEarly : FrameSource → Bool + | .calldata | .memory | .code | .storage => true + +def layoutSourcesSupported (l : FrameLayout) : Bool := + l.fields.all (fun field => sourceCanMaterializeEarly field.source) + +def frameSizeBytes (l : FrameLayout) : Nat := + l.headWords * 32 + +def ptrName (base : String) : String := + "__abi_frame_" ++ base + +def fieldWordName (base : String) (field : FrameField) (idx : Nat) : String := + ptrName base ++ "_" ++ field.name ++ "_" ++ toString idx + +def allocateFrame (base : String) (l : FrameLayout) : List YulStmt := + [ YulStmt.let_ (ptrName base) (YulExpr.call "mload" [YulExpr.lit 64]) + , YulStmt.expr (YulExpr.call "mstore" + [ YulExpr.lit 64 + , YulExpr.call "add" [YulExpr.ident (ptrName base), YulExpr.lit (frameSizeBytes l)] ])] + +private def materializeSourceWord (field : FrameField) (idx : Nat) : YulExpr := + let name := fieldWordName "src" field idx + match field.source with + | .calldata => YulExpr.call "calldataload" [YulExpr.ident name] + | .memory => YulExpr.call "mload" [YulExpr.ident name] + | .code => YulExpr.call "mload" [YulExpr.ident name] + | .storage => YulExpr.call "sload" [YulExpr.ident name] + +partial def spillField (base : String) (offsetWords : Nat) (field : FrameField) : List YulStmt := + (List.range (fieldHeadWords field)).map fun idx => + YulStmt.expr (YulExpr.call "mstore" + [ YulExpr.call "add" [YulExpr.ident (ptrName base), YulExpr.lit ((offsetWords + idx) * 32)] + , materializeSourceWord field idx ]) + +partial def spillFields (base : String) (offsetWords : Nat) : List FrameField → List YulStmt + | [] => [] + | field :: rest => spillField base offsetWords field ++ spillFields base (offsetWords + fieldHeadWords field) rest + +/-- Materialize a typed ABI frame into memory before lowering calls/logs/returns. + Large or dynamic payloads are then passed as `(ptr, size)` instead of as a + long list of Yul values. -/ +def spillPayloadToMemory (base : String) (l : FrameLayout) : List YulStmt := + allocateFrame base l ++ spillFields base 0 l.fields + +def pointerArgs (base : String) (l : FrameLayout) : List YulExpr := + [YulExpr.ident (ptrName base), YulExpr.lit (frameSizeBytes l)] + +private partial def inlineArgsFrom (idx : Nat) : List FrameField → List YulExpr + | [] => [] + | field :: rest => + (List.range (fieldHeadWords field)).map (fun wordIdx => + materializeSourceWord field (idx + wordIdx)) ++ + inlineArgsFrom (idx + 1) rest + +def inlineArgs (l : FrameLayout) : List YulExpr := + inlineArgsFrom 0 l.fields + +def loweredArgs (base : String) (l : FrameLayout) : List YulExpr := + match l.mode with + | .pointer => pointerArgs base l + | .inlineWords => inlineArgs l + +def containsDynamicArrayOrBytes (l : FrameLayout) : Bool := + l.fields.any fun field => + match field.ty with + | .array _ | .bytes | .string => true + | _ => isDynamicParamType field.ty + +def supportsNestedStructs (l : FrameLayout) : Bool := + l.fields.any fun field => + match field.ty with + | .tuple elems => elems.any (fun ty => match ty with | .tuple _ => true | _ => false) + | _ => false + +end Compiler.ABI.Frame diff --git a/Compiler/ABI/FrameTest.lean b/Compiler/ABI/FrameTest.lean new file mode 100644 index 000000000..485eeb59c --- /dev/null +++ b/Compiler/ABI/FrameTest.lean @@ -0,0 +1,34 @@ +import Compiler.ABI.Frame + +namespace Compiler.ABI.FrameTest + +open Compiler.ABI.Frame +open Compiler.CompilationModel + +private def assert (label : String) (ok : Bool) : IO Unit := do + if !ok then + throw (IO.userError s!"frame test failed: {label}") + IO.println s!"ok: {label}" + +private def takeFields : List FrameField := + [ { name := "offer", ty := .tuple [.address, .uint256, .tuple [.bytes32, .uint256]], source := .calldata } + , { name := "units", ty := .uint256, source := .calldata } + , { name := "ratifierData", ty := .bytes, source := .calldata } ] + +private def sourceFields : List FrameField := + [ { name := "c", ty := .uint256, source := .calldata } + , { name := "m", ty := .bytes, source := .memory } + , { name := "x", ty := .bytes32, source := .code } + , { name := "s", ty := .uint256, source := .storage } ] + +#eval! do + let takeLayout := layout takeFields + assert "nested struct supported" (supportsNestedStructs takeLayout) + assert "dynamic bytes/arrays force pointer mode" (takeLayout.mode == FramePassMode.pointer) + assert "Take frame passes pointer pair" ((loweredArgs "take" takeLayout).length == 2) + assert "Take spills early to memory" ((spillPayloadToMemory "take" takeLayout).length > 2) + let srcLayout := layout sourceFields + assert "calldata/memory/code/storage sources supported" (layoutSourcesSupported srcLayout) + assert "dynamic source frame is pointer mode" (srcLayout.mode == FramePassMode.pointer) + +end Compiler.ABI.FrameTest diff --git a/Compiler/Lowering/StackSafeAbi.lean b/Compiler/Lowering/StackSafeAbi.lean new file mode 100644 index 000000000..747c1b876 --- /dev/null +++ b/Compiler/Lowering/StackSafeAbi.lean @@ -0,0 +1,56 @@ +import Compiler.ABI.Frame + +namespace Compiler.Lowering.StackSafeAbi + +open Compiler.ABI.Frame +open Compiler.Yul + +structure LoweredFrame where + prologue : List YulStmt + args : List YulExpr + layout : FrameLayout + deriving Repr + +def lowerFrameSpilled (base : String) (fields : List FrameField) : Except String LoweredFrame := do + let l := layout fields + if !layoutSourcesSupported l then + throw s!"ABI frame '{base}' uses an unsupported source" + let prologue := + match l.mode with + | .pointer => spillPayloadToMemory base l + | .inlineWords => [] + pure { prologue, args := loweredArgs base l, layout := l } + +def lowerEvent (eventName : String) (fields : List FrameField) : Except String (List YulStmt) := do + let lowered ← lowerFrameSpilled eventName fields + match lowered.layout.mode with + | .pointer => + pure (lowered.prologue ++ + [YulStmt.expr (YulExpr.call "log1" (lowered.args ++ [YulExpr.call "keccak256" [YulExpr.str eventName, YulExpr.lit 0]]))]) + | .inlineWords => + pure (lowered.prologue ++ + [YulStmt.expr (YulExpr.call "log1" [YulExpr.lit 0, YulExpr.lit 0, YulExpr.call "keccak256" [YulExpr.str eventName, YulExpr.lit 0]])]) + +def lowerExternalCall (callName : String) (target value : YulExpr) (fields : List FrameField) : Except String (List YulStmt) := do + let lowered ← lowerFrameSpilled callName fields + let callArgs := + match lowered.layout.mode with + | .pointer => [YulExpr.call "gas" [], target, value] ++ lowered.args ++ [YulExpr.lit 0, YulExpr.lit 0] + | .inlineWords => [YulExpr.call "gas" [], target, value, YulExpr.lit 0, YulExpr.lit 0, YulExpr.lit 0, YulExpr.lit 0] + pure (lowered.prologue ++ [YulStmt.let_ ("__" ++ callName ++ "_ok") (YulExpr.call "call" callArgs)]) + +def lowerDynamicReturn (returnName : String) (fields : List FrameField) : Except String (List YulStmt) := do + let lowered ← lowerFrameSpilled returnName fields + match lowered.layout.mode with + | .pointer => pure (lowered.prologue ++ [YulStmt.expr (YulExpr.call "return" lowered.args)]) + | .inlineWords => pure (lowered.prologue ++ [YulStmt.expr (YulExpr.call "return" [YulExpr.lit 0, YulExpr.lit 32])]) + +def usesPointerAbi (stmts : List YulStmt) : Bool := + stmts.any fun stmt => + match stmt with + | .expr (.call "return" [_ptr, _size]) => true + | .expr (.call "log1" [_ptr, _size, _topic]) => true + | .let_ _ (.call "call" [_gas, _target, _value, _ptr, _size, _out, _outSize]) => true + | _ => false + +end Compiler.Lowering.StackSafeAbi diff --git a/Compiler/Lowering/StackSafeAbiTest.lean b/Compiler/Lowering/StackSafeAbiTest.lean new file mode 100644 index 000000000..a1ecba2e3 --- /dev/null +++ b/Compiler/Lowering/StackSafeAbiTest.lean @@ -0,0 +1,34 @@ +import Compiler.Lowering.StackSafeAbi + +namespace Compiler.Lowering.StackSafeAbiTest + +open Compiler.ABI.Frame +open Compiler.CompilationModel +open Compiler.Lowering.StackSafeAbi +open Compiler.Yul + +private def assert (label : String) (ok : Bool) : IO Unit := do + if !ok then + throw (IO.userError s!"stack safe ABI test failed: {label}") + IO.println s!"ok: {label}" + +private def bigDynamicPayload : List FrameField := + [ { name := "toId", ty := .bytes32, source := .calldata } + , { name := "toMarket", ty := .tuple [.address, .uint256, .uint256, .address], source := .calldata } + , { name := "takes", ty := .array (.tuple [.address, .uint256, .bytes]), source := .calldata } ] + +#eval! do + match lowerEvent "Take" bigDynamicPayload with + | .ok ev => assert "event lowering uses memory pointer" (usesPointerAbi ev) + | .error err => throw (IO.userError err) + match lowerExternalCall "callback" (YulExpr.ident "target") (YulExpr.lit 0) bigDynamicPayload with + | .ok call => assert "external-call lowering uses memory pointer" (usesPointerAbi call) + | .error err => throw (IO.userError err) + match lowerDynamicReturn "dynamicReturn" bigDynamicPayload with + | .ok ret => assert "dynamic return lowering uses memory pointer" (usesPointerAbi ret) + | .error err => throw (IO.userError err) + match lowerFrameSpilled "toMarket" bigDynamicPayload with + | .ok lowered => assert "frame-spilled lowering allocates memory early" (!lowered.prologue.isEmpty && lowered.args.length == 2) + | .error err => throw (IO.userError err) + +end Compiler.Lowering.StackSafeAbiTest diff --git a/Compiler/Modules.lean b/Compiler/Modules.lean index 62910313b..7fe52e6c3 100644 --- a/Compiler/Modules.lean +++ b/Compiler/Modules.lean @@ -5,3 +5,5 @@ import Compiler.Modules.ERC20 import Compiler.Modules.Hashing import Compiler.Modules.Oracle import Compiler.Modules.Precompiles +import Compiler.Modules.Create2SSTORE2 +import Compiler.Modules.CodeData diff --git a/Compiler/Modules/CodeData.lean b/Compiler/Modules/CodeData.lean new file mode 100644 index 000000000..e8191f987 --- /dev/null +++ b/Compiler/Modules/CodeData.lean @@ -0,0 +1,64 @@ +import Compiler.ABI.Frame +import Compiler.Modules.Create2SSTORE2 + +namespace Compiler.Modules.CodeData + +open Compiler.ABI.Frame +open Compiler.ECM +open Compiler.Yul + +structure CodeDataWrite where + salt : YulExpr + value : YulExpr := YulExpr.lit 0 + initcodeOffset : YulExpr + initcodeSize : YulExpr + payload : FrameLayout + deriving Repr + +structure CodeDataRead where + pointer : YulExpr + destOffset : YulExpr + codeOffset : YulExpr + size : YulExpr + payload : FrameLayout + deriving Repr + +def trustSurface : List String := + [ "CREATE2 address derivation is trusted at the EVM boundary" + , "SSTORE2 pointer code layout is trusted as code-as-data" + , "extcodecopy reads immutable deployed code bytes into caller-owned memory" + , "ABI frame layout is typed by Compiler.ABI.Frame before write/read lowering" ] + +def writeTyped (resultVar base : String) (write : CodeDataWrite) : Except String (List YulStmt) := do + if !layoutSourcesSupported write.payload then + throw "CodeData write payload has unsupported frame source" + let prelude := + match write.payload.mode with + | .pointer => spillPayloadToMemory base write.payload + | .inlineWords => [] + let deploy ← (Compiler.Modules.Create2SSTORE2.deployModule resultVar).compile {} + [write.value, write.initcodeOffset, write.initcodeSize, write.salt] + pure (prelude ++ deploy) + +def readTyped (read : CodeDataRead) : Except String (List YulStmt) := do + if !layoutSourcesSupported read.payload then + throw "CodeData read payload has unsupported frame source" + (Compiler.Modules.Create2SSTORE2.readCodeModule).compile {} + [read.pointer, read.destOffset, read.codeOffset, read.size] + +def roundtripShape (resultVar base : String) (write : CodeDataWrite) (read : CodeDataRead) : + Except String (List YulStmt) := do + let w ← writeTyped resultVar base write + let r ← readTyped read + pure (w ++ r) + +def hasCreate2AndExtcodecopy (stmts : List YulStmt) : Bool := + let hasCreate2 := stmts.any fun + | .let_ _ (.call "create2" _) => true + | _ => false + let hasExtcodecopy := stmts.any fun + | .expr (.call "extcodecopy" _) => true + | _ => false + hasCreate2 && hasExtcodecopy + +end Compiler.Modules.CodeData diff --git a/Compiler/Modules/CodeDataTest.lean b/Compiler/Modules/CodeDataTest.lean new file mode 100644 index 000000000..c953dfe1f --- /dev/null +++ b/Compiler/Modules/CodeDataTest.lean @@ -0,0 +1,38 @@ +import Compiler.Modules.CodeData + +namespace Compiler.Modules.CodeDataTest + +open Compiler.ABI.Frame +open Compiler.CompilationModel +open Compiler.Modules.CodeData +open Compiler.Yul + +private def assert (label : String) (ok : Bool) : IO Unit := do + if !ok then + throw (IO.userError s!"CodeData test failed: {label}") + IO.println s!"ok: {label}" + +private def payload := layout + [ { name := "blob", ty := .bytes, source := .memory } + , { name := "meta", ty := .tuple [.bytes32, .uint256], source := .calldata } ] + +#eval! do + let write : CodeDataWrite := + { salt := YulExpr.ident "salt" + initcodeOffset := YulExpr.ident "init" + initcodeSize := YulExpr.ident "size" + payload } + let read : CodeDataRead := + { pointer := YulExpr.ident "ptr" + destOffset := YulExpr.ident "dest" + codeOffset := YulExpr.lit 1 + size := YulExpr.ident "size" + payload } + let roundtrip ← + match roundtripShape "storedPtr" "sstore2" write read with + | .ok stmts => pure stmts + | .error err => throw (IO.userError err) + assert "typed roundtrip has create2 and extcodecopy" (hasCreate2AndExtcodecopy roundtrip) + assert "trust surface is explicit" (trustSurface.length == 4) + +end Compiler.Modules.CodeDataTest diff --git a/Compiler/Modules/README.md b/Compiler/Modules/README.md index f2bcb7be3..c7487af67 100644 --- a/Compiler/Modules/README.md +++ b/Compiler/Modules/README.md @@ -16,6 +16,7 @@ structure that the compiler can plug in without modification. | `Callbacks.lean` | `callback` | `Stmt.callback` | | `Calls.lean` | `withReturn`, `callWithValue`, `callWithValueBytes`, `bubblingValueCall`, `bubblingValueCallNoOutput` | `Stmt.externalCallWithReturn`; generic `call{value:v}` adapter calls; handwritten low-level `call{value: ...}` wrappers | | `Create2SSTORE2.lean` | `create2Deploy`, `sstore2ReadCode` | handwritten CREATE2 deployment and SSTORE2-style `extcodecopy` code-as-data reads | +| `CodeData.lean` | `writeTyped`, `readTyped`, `roundtripShape` | typed CREATE2/SSTORE2 code-as-data facade using `Compiler.ABI.Frame` layouts and an explicit trust surface | ## Usage diff --git a/Compiler/Proofs/GeneratedTransition.lean b/Compiler/Proofs/GeneratedTransition.lean new file mode 100644 index 000000000..52bc72fd7 --- /dev/null +++ b/Compiler/Proofs/GeneratedTransition.lean @@ -0,0 +1,149 @@ +import Compiler.CompilationModel + +namespace Compiler.Proofs.GeneratedTransition + +open Compiler.CompilationModel + +structure TransitionSummary where + reads : List String := [] + writes : List String := [] + guards : List String := [] + events : List String := [] + deriving Repr, BEq, Inhabited + +private def dedup (xs : List String) : List String := + xs.foldl (fun acc x => if acc.contains x then acc else acc ++ [x]) [] + +private def merge (a b : TransitionSummary) : TransitionSummary := + { reads := dedup (a.reads ++ b.reads) + writes := dedup (a.writes ++ b.writes) + guards := dedup (a.guards ++ b.guards) + events := dedup (a.events ++ b.events) } + +private partial def exprReads : Expr → List String + | .storage field => [field] + | .storageAddr field => [field] + | .mapping field key => field :: exprReads key + | .mapping2 field key1 key2 => field :: exprReads key1 ++ exprReads key2 + | .mappingUint field key => field :: exprReads key + | .mappingChain field keys => field :: keys.flatMap exprReads + | .structMember field key _ => field :: exprReads key + | .structMember2 field key1 key2 _ => field :: exprReads key1 ++ exprReads key2 + | .mload a | .tload a | .calldataload a | .extcodesize a | .returndataOptionalBoolAt a + | .storageArrayElement _ a | .memoryArrayElement _ a => + exprReads a + | .keccak256 a b | .add a b | .sub a b | .mul a b | .div a b | .sdiv a b + | .mod a b | .smod a b | .eq a b | .ge a b | .gt a b | .sgt a b | .lt a b + | .slt a b | .le a b | .logicalAnd a b | .logicalOr a b | .bitAnd a b + | .bitOr a b | .bitXor a b | .shl a b | .shr a b | .sar a b | .byte a b + | .signextend a b | .ceilDiv a b | .wMulDown a b | .wDivUp a b + | .min a b | .max a b => + exprReads a ++ exprReads b + | .call gas target value inOffset inSize outOffset outSize => + [gas, target, value, inOffset, inSize, outOffset, outSize].flatMap exprReads + | .staticcall gas target inOffset inSize outOffset outSize + | .delegatecall gas target inOffset inSize outOffset outSize => + [gas, target, inOffset, inSize, outOffset, outSize].flatMap exprReads + | .bitNot a | .logicalNot a => exprReads a + | .externalCall _ args | .internalCall _ args => args.flatMap exprReads + | .intrinsic _ _ _ args => args.flatMap exprReads + | .forkIfAtLeast _ t e => exprReads t ++ exprReads e + | .arrayLength name => [name ++ ".length"] + | .arrayElement name idx => name :: exprReads idx + | .memoryArrayLength name => [name ++ ".length"] + | .arrayElementWord name idx _ _ => name :: exprReads idx + | .arrayElementDynamicWord name idx _ => name :: exprReads idx + | .arrayElementDynamicDataOffset name idx => name :: exprReads idx + | .arrayElementDynamicMemberLength name idx _ => name :: exprReads idx + | .arrayElementDynamicMemberDataOffset name idx _ => name :: exprReads idx + | .arrayElementDynamicMemberElement name idx _ memberIdx => name :: (exprReads idx ++ exprReads memberIdx) + | .paramDynamicHeadWord name _ => [name] + | .paramDynamicStaticComposite name _ => [name] + | .paramDynamicMemberLength name _ => [name] + | .paramDynamicMemberDataOffset name _ => [name] + | .paramDynamicMemberElement name _ _ => [name] + | .storageArrayLength field => [field ++ ".length"] + | .dynamicBytesEq lhs rhs => [lhs, rhs] + | .ite c t e => exprReads c ++ exprReads t ++ exprReads e + | .adtConstruct _ _ args => args.flatMap exprReads + | .adtTag _ field => [field] + | .adtField _ _ _ _ field => [field] + | .mulDivDown a b c | .mulDivUp a b c | .mulDiv512Down a b c | .mulDiv512Up a b c => + exprReads a ++ exprReads b ++ exprReads c + | _ => [] + +mutual +private partial def stmtSummary : Stmt → TransitionSummary + | .setStorage field value | .setStorageAddr field value => + { reads := dedup (exprReads value), writes := [field] } + | .setStorageWord field offset value => + { reads := dedup (exprReads value), writes := [field ++ "+" ++ toString offset] } + | .storageArrayPush field value => + { reads := dedup ((field ++ ".length") :: exprReads value), writes := [field] } + | .storageArrayPop field => + { reads := [field ++ ".length"], writes := [field] } + | .setStorageArrayElement field idx value => + { reads := dedup (exprReads idx ++ exprReads value), writes := [field ++ "[]"] } + | .setMapping field key value | .setMappingUint field key value => + { reads := dedup (exprReads key ++ exprReads value), writes := [field] } + | .setMappingWord field key offset value => + { reads := dedup (exprReads key ++ exprReads value), writes := [field ++ "+" ++ toString offset] } + | .setMappingPackedWord field key offset _ value => + { reads := dedup (field :: (exprReads key ++ exprReads value)), writes := [field ++ "+" ++ toString offset] } + | .setMapping2 field key1 key2 value => + { reads := dedup (exprReads key1 ++ exprReads key2 ++ exprReads value), writes := [field] } + | .setMapping2Word field key1 key2 offset value => + { reads := dedup (exprReads key1 ++ exprReads key2 ++ exprReads value), writes := [field ++ "+" ++ toString offset] } + | .setMappingChain field keys value => + { reads := dedup (keys.flatMap exprReads ++ exprReads value), writes := [field] } + | .setStructMember field key member value => + { reads := dedup (exprReads key ++ exprReads value), writes := [field ++ "." ++ member] } + | .setStructMember2 field key1 key2 member value => + { reads := dedup (exprReads key1 ++ exprReads key2 ++ exprReads value), writes := [field ++ "." ++ member] } + | .require cond label | .requireError cond label _ => + { reads := dedup (exprReads cond), guards := [label] } + | .revertError label args => + { reads := dedup (args.flatMap exprReads), guards := [label] } + | .emit eventName args => + { reads := dedup (args.flatMap exprReads), events := [eventName] } + | .rawLog topics dataOffset dataSize => + { reads := dedup (topics.flatMap exprReads ++ exprReads dataOffset ++ exprReads dataSize), events := ["rawLog"] } + | .externalCallBind results externalName args => + { reads := dedup (args.flatMap exprReads), writes := results.map ("local:" ++ ·), events := ["external:" ++ externalName] } + | .tryExternalCallBind success results externalName args => + { reads := dedup (args.flatMap exprReads), writes := ("local:" ++ success) :: results.map ("local:" ++ ·), events := ["external:" ++ externalName] } + | .ecm mod args => + { reads := if mod.readsState then ["ecm:" ++ mod.name] else dedup (args.flatMap exprReads) + writes := if mod.writesState then ["ecm:" ++ mod.name] else [] + events := ["ecm:" ++ mod.name] } + | .ite cond t e => + merge { reads := dedup (exprReads cond), guards := ["branch"] } (merge (stmtsSummary t) (stmtsSummary e)) + | .forEach name count body => + merge { reads := dedup (exprReads count), guards := ["loop:" ++ name] } (stmtsSummary body) + | .letVar _ value | .assignVar _ value | .return value | .mstore _ value | .tstore _ value => + { reads := dedup (exprReads value) } + | .returnValues values => { reads := dedup (values.flatMap exprReads) } + | .returnArray name | .returnBytes name | .returnStorageWords name => { reads := [name] } + | .calldatacopy dest src size | .returndataCopy dest src size => + { reads := dedup (exprReads dest ++ exprReads src ++ exprReads size) } + | .internalCall _ args | .internalCallAssign _ _ args => + { reads := dedup (args.flatMap exprReads) } + | .unsafeBlock reason body => + merge { guards := ["unsafe:" ++ reason] } (stmtsSummary body) + | .unsafeYul _ => { events := ["unsafeYul"] } + | .matchAdt adtName scrutinee branches => + merge { reads := dedup (exprReads scrutinee), guards := ["match:" ++ adtName] } + (branches.foldl (fun acc (_, _, body) => merge acc (stmtsSummary body)) {}) + | .stop | .revertReturndata => {} + +private partial def stmtsSummary (stmts : List Stmt) : TransitionSummary := + stmts.foldl (fun acc stmt => merge acc (stmtSummary stmt)) {} +end + +def extract (stmts : List Stmt) : TransitionSummary := + stmts.foldl (fun acc stmt => merge acc (stmtSummary stmt)) {} + +def enoughForMidnightRcfTotalUnits (summary : TransitionSummary) : Bool := + !summary.reads.isEmpty || !summary.writes.isEmpty || !summary.guards.isEmpty || !summary.events.isEmpty + +end Compiler.Proofs.GeneratedTransition diff --git a/Compiler/Proofs/GeneratedTransitionTest.lean b/Compiler/Proofs/GeneratedTransitionTest.lean new file mode 100644 index 000000000..e08bcdf73 --- /dev/null +++ b/Compiler/Proofs/GeneratedTransitionTest.lean @@ -0,0 +1,24 @@ +import Compiler.Proofs.GeneratedTransition + +namespace Compiler.Proofs.GeneratedTransitionTest + +open Compiler.CompilationModel +open Compiler.Proofs.GeneratedTransition + +private def assert (label : String) (ok : Bool) : IO Unit := do + if !ok then + throw (IO.userError s!"GeneratedTransition test failed: {label}") + IO.println s!"ok: {label}" + +#eval! do + let summary := extract + [ Stmt.require (Expr.gt (Expr.storage "totalUnits") (Expr.literal 0)) "nonzero" + , Stmt.setMapping "position" (Expr.param "borrower") (Expr.storage "totalUnits") + , Stmt.emit "Take" [Expr.param "borrower", Expr.storage "totalUnits"] ] + assert "extracts reads" (summary.reads.contains "totalUnits") + assert "extracts writes" (summary.writes.contains "position") + assert "extracts guards" (summary.guards.contains "nonzero") + assert "extracts events" (summary.events.contains "Take") + assert "non-empty summary can feed later Midnight RCF/totalUnits" (enoughForMidnightRcfTotalUnits summary) + +end Compiler.Proofs.GeneratedTransitionTest From 918faadae369943ac11a851c302b2e4a28ec3cd7 Mon Sep 17 00:00:00 2001 From: Thomas Marchand Date: Thu, 4 Jun 2026 15:47:05 +0200 Subject: [PATCH 2/3] Fix stack-safe ABI lowering review issues --- Compiler/ABI/Frame.lean | 8 ++--- Compiler/ABI/FrameTest.lean | 12 +++++++ Compiler/Lowering/StackSafeAbi.lean | 42 +++++++++++++++--------- Compiler/Lowering/StackSafeAbiTest.lean | 43 +++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 20 deletions(-) diff --git a/Compiler/ABI/Frame.lean b/Compiler/ABI/Frame.lean index 4d78aab11..82120c2db 100644 --- a/Compiler/ABI/Frame.lean +++ b/Compiler/ABI/Frame.lean @@ -101,15 +101,15 @@ def spillPayloadToMemory (base : String) (l : FrameLayout) : List YulStmt := def pointerArgs (base : String) (l : FrameLayout) : List YulExpr := [YulExpr.ident (ptrName base), YulExpr.lit (frameSizeBytes l)] -private partial def inlineArgsFrom (idx : Nat) : List FrameField → List YulExpr +private partial def inlineArgsFrom : List FrameField → List YulExpr | [] => [] | field :: rest => (List.range (fieldHeadWords field)).map (fun wordIdx => - materializeSourceWord field (idx + wordIdx)) ++ - inlineArgsFrom (idx + 1) rest + materializeSourceWord field wordIdx) ++ + inlineArgsFrom rest def inlineArgs (l : FrameLayout) : List YulExpr := - inlineArgsFrom 0 l.fields + inlineArgsFrom l.fields def loweredArgs (base : String) (l : FrameLayout) : List YulExpr := match l.mode with diff --git a/Compiler/ABI/FrameTest.lean b/Compiler/ABI/FrameTest.lean index 485eeb59c..d086e057a 100644 --- a/Compiler/ABI/FrameTest.lean +++ b/Compiler/ABI/FrameTest.lean @@ -4,6 +4,7 @@ namespace Compiler.ABI.FrameTest open Compiler.ABI.Frame open Compiler.CompilationModel +open Compiler.Yul private def assert (label : String) (ok : Bool) : IO Unit := do if !ok then @@ -21,6 +22,14 @@ private def sourceFields : List FrameField := , { name := "x", ty := .bytes32, source := .code } , { name := "s", ty := .uint256, source := .storage } ] +private def inlineFields : List FrameField := + [ { name := "pair", ty := .tuple [.uint256, .bytes32], source := .calldata } + , { name := "amount", ty := .uint256, source := .calldata } ] + +private def calldataLoadName? : YulExpr → Option String + | .call "calldataload" [.ident name] => some name + | _ => none + #eval! do let takeLayout := layout takeFields assert "nested struct supported" (supportsNestedStructs takeLayout) @@ -30,5 +39,8 @@ private def sourceFields : List FrameField := let srcLayout := layout sourceFields assert "calldata/memory/code/storage sources supported" (layoutSourcesSupported srcLayout) assert "dynamic source frame is pointer mode" (srcLayout.mode == FramePassMode.pointer) + let inlineNames := (inlineArgs (layout inlineFields)).filterMap calldataLoadName? + assert "inline source words are indexed per field" + (inlineNames == ["__abi_frame_src_pair_0", "__abi_frame_src_pair_1", "__abi_frame_src_amount_0"]) end Compiler.ABI.FrameTest diff --git a/Compiler/Lowering/StackSafeAbi.lean b/Compiler/Lowering/StackSafeAbi.lean index 747c1b876..a364e5d9f 100644 --- a/Compiler/Lowering/StackSafeAbi.lean +++ b/Compiler/Lowering/StackSafeAbi.lean @@ -11,6 +11,13 @@ structure LoweredFrame where layout : FrameLayout deriving Repr +def eventNameTopicWord (eventName : String) : Nat := + UInt64.toNat (hash eventName) + +def inlinePayloadToScratch (words : List YulExpr) : List YulStmt := + words.zipIdx.map fun (word, idx) => + YulStmt.expr (YulExpr.call "mstore" [YulExpr.lit (idx * 32), word]) + def lowerFrameSpilled (base : String) (fields : List FrameField) : Except String LoweredFrame := do let l := layout fields if !layoutSourcesSupported l then @@ -21,29 +28,32 @@ def lowerFrameSpilled (base : String) (fields : List FrameField) : Except String | .inlineWords => [] pure { prologue, args := loweredArgs base l, layout := l } -def lowerEvent (eventName : String) (fields : List FrameField) : Except String (List YulStmt) := do - let lowered ← lowerFrameSpilled eventName fields +def lowerFrameAsMemoryPayload (base : String) (fields : List FrameField) : Except String (List YulStmt × List YulExpr × FrameLayout) := do + let lowered ← lowerFrameSpilled base fields match lowered.layout.mode with | .pointer => - pure (lowered.prologue ++ - [YulStmt.expr (YulExpr.call "log1" (lowered.args ++ [YulExpr.call "keccak256" [YulExpr.str eventName, YulExpr.lit 0]]))]) + pure (lowered.prologue, lowered.args, lowered.layout) | .inlineWords => - pure (lowered.prologue ++ - [YulStmt.expr (YulExpr.call "log1" [YulExpr.lit 0, YulExpr.lit 0, YulExpr.call "keccak256" [YulExpr.str eventName, YulExpr.lit 0]])]) + pure (lowered.prologue ++ inlinePayloadToScratch lowered.args, + [YulExpr.lit 0, YulExpr.lit (lowered.layout.headWords * 32)], + lowered.layout) + +def lowerEventWithTopic (base : String) (topic0 : YulExpr) (fields : List FrameField) : Except String (List YulStmt) := do + let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload base fields + pure (prologue ++ + [YulStmt.expr (YulExpr.call "log1" (payloadArgs ++ [topic0]))]) + +def lowerEvent (eventName : String) (fields : List FrameField) : Except String (List YulStmt) := do + lowerEventWithTopic eventName (YulExpr.lit (eventNameTopicWord eventName)) fields def lowerExternalCall (callName : String) (target value : YulExpr) (fields : List FrameField) : Except String (List YulStmt) := do - let lowered ← lowerFrameSpilled callName fields - let callArgs := - match lowered.layout.mode with - | .pointer => [YulExpr.call "gas" [], target, value] ++ lowered.args ++ [YulExpr.lit 0, YulExpr.lit 0] - | .inlineWords => [YulExpr.call "gas" [], target, value, YulExpr.lit 0, YulExpr.lit 0, YulExpr.lit 0, YulExpr.lit 0] - pure (lowered.prologue ++ [YulStmt.let_ ("__" ++ callName ++ "_ok") (YulExpr.call "call" callArgs)]) + let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload callName fields + let callArgs := [YulExpr.call "gas" [], target, value] ++ payloadArgs ++ [YulExpr.lit 0, YulExpr.lit 0] + pure (prologue ++ [YulStmt.let_ ("__" ++ callName ++ "_ok") (YulExpr.call "call" callArgs)]) def lowerDynamicReturn (returnName : String) (fields : List FrameField) : Except String (List YulStmt) := do - let lowered ← lowerFrameSpilled returnName fields - match lowered.layout.mode with - | .pointer => pure (lowered.prologue ++ [YulStmt.expr (YulExpr.call "return" lowered.args)]) - | .inlineWords => pure (lowered.prologue ++ [YulStmt.expr (YulExpr.call "return" [YulExpr.lit 0, YulExpr.lit 32])]) + let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload returnName fields + pure (prologue ++ [YulStmt.expr (YulExpr.call "return" payloadArgs)]) def usesPointerAbi (stmts : List YulStmt) : Bool := stmts.any fun stmt => diff --git a/Compiler/Lowering/StackSafeAbiTest.lean b/Compiler/Lowering/StackSafeAbiTest.lean index a1ecba2e3..93dd3b42d 100644 --- a/Compiler/Lowering/StackSafeAbiTest.lean +++ b/Compiler/Lowering/StackSafeAbiTest.lean @@ -17,6 +17,34 @@ private def bigDynamicPayload : List FrameField := , { name := "toMarket", ty := .tuple [.address, .uint256, .uint256, .address], source := .calldata } , { name := "takes", ty := .array (.tuple [.address, .uint256, .bytes]), source := .calldata } ] +private def smallStaticPayload : List FrameField := + [ { name := "id", ty := .bytes32, source := .calldata } + , { name := "amount", ty := .uint256, source := .calldata } ] + +private def countMstores : List YulStmt → Nat := + List.length ∘ List.filter (fun stmt => + match stmt with + | .expr (.call "mstore" _) => true + | _ => false) + +private def returnsBytes (bytes : Nat) : List YulStmt → Bool := + fun stmts => stmts.any fun stmt => + match stmt with + | .expr (.call "return" [.lit 0, .lit n]) => n == bytes + | _ => false + +private def callsWithInputBytes (bytes : Nat) : List YulStmt → Bool := + fun stmts => stmts.any fun stmt => + match stmt with + | .let_ _ (.call "call" [_gas, _target, _value, .lit 0, .lit n, _out, _outSize]) => n == bytes + | _ => false + +private def logsWithDataBytes (bytes : Nat) : List YulStmt → Bool := + fun stmts => stmts.any fun stmt => + match stmt with + | .expr (.call "log1" [.lit 0, .lit n, .lit topic]) => n == bytes && topic != 0 + | _ => false + #eval! do match lowerEvent "Take" bigDynamicPayload with | .ok ev => assert "event lowering uses memory pointer" (usesPointerAbi ev) @@ -30,5 +58,20 @@ private def bigDynamicPayload : List FrameField := match lowerFrameSpilled "toMarket" bigDynamicPayload with | .ok lowered => assert "frame-spilled lowering allocates memory early" (!lowered.prologue.isEmpty && lowered.args.length == 2) | .error err => throw (IO.userError err) + match lowerEvent "SmallStatic" smallStaticPayload with + | .ok ev => + assert "inline event lowering stores payload" (countMstores ev == 2) + assert "inline event lowering logs payload bytes" (logsWithDataBytes 64 ev) + | .error err => throw (IO.userError err) + match lowerExternalCall "smallCallback" (YulExpr.ident "target") (YulExpr.lit 0) smallStaticPayload with + | .ok call => + assert "inline external-call lowering stores payload" (countMstores call == 2) + assert "inline external-call lowering passes payload bytes" (callsWithInputBytes 64 call) + | .error err => throw (IO.userError err) + match lowerDynamicReturn "smallReturn" smallStaticPayload with + | .ok ret => + assert "inline dynamic return lowering stores payload" (countMstores ret == 2) + assert "inline dynamic return lowering returns payload bytes" (returnsBytes 64 ret) + | .error err => throw (IO.userError err) end Compiler.Lowering.StackSafeAbiTest From 0df4e11b7db499f5d962197c51632ee6026fb266 Mon Sep 17 00:00:00 2001 From: Thomas Marchand Date: Thu, 4 Jun 2026 15:54:01 +0200 Subject: [PATCH 3/3] Wire ABI frame payloads through lowering --- Compiler/ABI/Frame.lean | 52 ++++++++++++++++++++----- Compiler/ABI/FrameTest.lean | 9 +++-- Compiler/Lowering/StackSafeAbi.lean | 13 +------ Compiler/Lowering/StackSafeAbiTest.lean | 2 +- Compiler/Modules/CodeData.lean | 17 ++++---- Compiler/Modules/CodeDataTest.lean | 11 ++++-- 6 files changed, 70 insertions(+), 34 deletions(-) diff --git a/Compiler/ABI/Frame.lean b/Compiler/ABI/Frame.lean index 82120c2db..5e4fecd55 100644 --- a/Compiler/ABI/Frame.lean +++ b/Compiler/ABI/Frame.lean @@ -22,6 +22,8 @@ structure FrameField where name : String ty : ParamType source : FrameSource + sourceBase : String := "" + tailBytes : Nat := 0 deriving Repr, BEq structure FrameLayout where @@ -60,7 +62,14 @@ def layoutSourcesSupported (l : FrameLayout) : Bool := l.fields.all (fun field => sourceCanMaterializeEarly field.source) def frameSizeBytes (l : FrameLayout) : Nat := - l.headWords * 32 + l.fields.foldl (fun acc field => acc + fieldHeadWords field * 32 + + (if isDynamicParamType field.ty then field.tailBytes else 0)) 0 + +def fieldPayloadWords (field : FrameField) : Nat := + fieldHeadWords field + if isDynamicParamType field.ty then (field.tailBytes + 31) / 32 else 0 + +def frameAllocBytes (l : FrameLayout) : Nat := + l.fields.foldl (fun acc field => acc + fieldPayloadWords field * 32) 0 def ptrName (base : String) : String := "__abi_frame_" ++ base @@ -72,25 +81,39 @@ def allocateFrame (base : String) (l : FrameLayout) : List YulStmt := [ YulStmt.let_ (ptrName base) (YulExpr.call "mload" [YulExpr.lit 64]) , YulStmt.expr (YulExpr.call "mstore" [ YulExpr.lit 64 - , YulExpr.call "add" [YulExpr.ident (ptrName base), YulExpr.lit (frameSizeBytes l)] ])] + , YulExpr.call "add" [YulExpr.ident (ptrName base), YulExpr.lit (frameAllocBytes l)] ])] + +def sourceBaseName (field : FrameField) : String := + if field.sourceBase.isEmpty then field.name else field.sourceBase + +private def sourceByteOffset (field : FrameField) (idx : Nat) : YulExpr := + if idx == 0 then + YulExpr.ident (sourceBaseName field) + else + YulExpr.call "add" [YulExpr.ident (sourceBaseName field), YulExpr.lit (idx * 32)] + +private def sourceStorageSlot (field : FrameField) (idx : Nat) : YulExpr := + if idx == 0 then + YulExpr.ident (sourceBaseName field) + else + YulExpr.call "add" [YulExpr.ident (sourceBaseName field), YulExpr.lit idx] private def materializeSourceWord (field : FrameField) (idx : Nat) : YulExpr := - let name := fieldWordName "src" field idx match field.source with - | .calldata => YulExpr.call "calldataload" [YulExpr.ident name] - | .memory => YulExpr.call "mload" [YulExpr.ident name] - | .code => YulExpr.call "mload" [YulExpr.ident name] - | .storage => YulExpr.call "sload" [YulExpr.ident name] + | .calldata => YulExpr.call "calldataload" [sourceByteOffset field idx] + | .memory => YulExpr.call "mload" [sourceByteOffset field idx] + | .code => YulExpr.call "mload" [sourceByteOffset field idx] + | .storage => YulExpr.call "sload" [sourceStorageSlot field idx] partial def spillField (base : String) (offsetWords : Nat) (field : FrameField) : List YulStmt := - (List.range (fieldHeadWords field)).map fun idx => + (List.range (fieldPayloadWords field)).map fun idx => YulStmt.expr (YulExpr.call "mstore" [ YulExpr.call "add" [YulExpr.ident (ptrName base), YulExpr.lit ((offsetWords + idx) * 32)] , materializeSourceWord field idx ]) partial def spillFields (base : String) (offsetWords : Nat) : List FrameField → List YulStmt | [] => [] - | field :: rest => spillField base offsetWords field ++ spillFields base (offsetWords + fieldHeadWords field) rest + | field :: rest => spillField base offsetWords field ++ spillFields base (offsetWords + fieldPayloadWords field) rest /-- Materialize a typed ABI frame into memory before lowering calls/logs/returns. Large or dynamic payloads are then passed as `(ptr, size)` instead of as a @@ -101,6 +124,10 @@ def spillPayloadToMemory (base : String) (l : FrameLayout) : List YulStmt := def pointerArgs (base : String) (l : FrameLayout) : List YulExpr := [YulExpr.ident (ptrName base), YulExpr.lit (frameSizeBytes l)] +def inlinePayloadToScratch (words : List YulExpr) : List YulStmt := + words.zipIdx.map fun (word, idx) => + YulStmt.expr (YulExpr.call "mstore" [YulExpr.lit (idx * 32), word]) + private partial def inlineArgsFrom : List FrameField → List YulExpr | [] => [] | field :: rest => @@ -116,6 +143,13 @@ def loweredArgs (base : String) (l : FrameLayout) : List YulExpr := | .pointer => pointerArgs base l | .inlineWords => inlineArgs l +def materializePayloadToMemory (base : String) (l : FrameLayout) : List YulStmt × List YulExpr := + match l.mode with + | .pointer => + (spillPayloadToMemory base l, pointerArgs base l) + | .inlineWords => + (inlinePayloadToScratch (inlineArgs l), [YulExpr.lit 0, YulExpr.lit (frameSizeBytes l)]) + def containsDynamicArrayOrBytes (l : FrameLayout) : Bool := l.fields.any fun field => match field.ty with diff --git a/Compiler/ABI/FrameTest.lean b/Compiler/ABI/FrameTest.lean index d086e057a..e86850935 100644 --- a/Compiler/ABI/FrameTest.lean +++ b/Compiler/ABI/FrameTest.lean @@ -14,11 +14,11 @@ private def assert (label : String) (ok : Bool) : IO Unit := do private def takeFields : List FrameField := [ { name := "offer", ty := .tuple [.address, .uint256, .tuple [.bytes32, .uint256]], source := .calldata } , { name := "units", ty := .uint256, source := .calldata } - , { name := "ratifierData", ty := .bytes, source := .calldata } ] + , { name := "ratifierData", ty := .bytes, source := .calldata, tailBytes := 96 } ] private def sourceFields : List FrameField := [ { name := "c", ty := .uint256, source := .calldata } - , { name := "m", ty := .bytes, source := .memory } + , { name := "m", ty := .bytes, source := .memory, tailBytes := 64 } , { name := "x", ty := .bytes32, source := .code } , { name := "s", ty := .uint256, source := .storage } ] @@ -28,6 +28,7 @@ private def inlineFields : List FrameField := private def calldataLoadName? : YulExpr → Option String | .call "calldataload" [.ident name] => some name + | .call "calldataload" [.call "add" [.ident name, _]] => some name | _ => none #eval! do @@ -36,11 +37,13 @@ private def calldataLoadName? : YulExpr → Option String assert "dynamic bytes/arrays force pointer mode" (takeLayout.mode == FramePassMode.pointer) assert "Take frame passes pointer pair" ((loweredArgs "take" takeLayout).length == 2) assert "Take spills early to memory" ((spillPayloadToMemory "take" takeLayout).length > 2) + assert "dynamic tail contributes to pointer payload size" (frameSizeBytes takeLayout == 288) + assert "dynamic tail contributes to allocated words" (frameAllocBytes takeLayout == 288) let srcLayout := layout sourceFields assert "calldata/memory/code/storage sources supported" (layoutSourcesSupported srcLayout) assert "dynamic source frame is pointer mode" (srcLayout.mode == FramePassMode.pointer) let inlineNames := (inlineArgs (layout inlineFields)).filterMap calldataLoadName? assert "inline source words are indexed per field" - (inlineNames == ["__abi_frame_src_pair_0", "__abi_frame_src_pair_1", "__abi_frame_src_amount_0"]) + (inlineNames == ["pair", "pair", "amount"]) end Compiler.ABI.FrameTest diff --git a/Compiler/Lowering/StackSafeAbi.lean b/Compiler/Lowering/StackSafeAbi.lean index a364e5d9f..852c9ef15 100644 --- a/Compiler/Lowering/StackSafeAbi.lean +++ b/Compiler/Lowering/StackSafeAbi.lean @@ -14,10 +14,6 @@ structure LoweredFrame where def eventNameTopicWord (eventName : String) : Nat := UInt64.toNat (hash eventName) -def inlinePayloadToScratch (words : List YulExpr) : List YulStmt := - words.zipIdx.map fun (word, idx) => - YulStmt.expr (YulExpr.call "mstore" [YulExpr.lit (idx * 32), word]) - def lowerFrameSpilled (base : String) (fields : List FrameField) : Except String LoweredFrame := do let l := layout fields if !layoutSourcesSupported l then @@ -30,13 +26,8 @@ def lowerFrameSpilled (base : String) (fields : List FrameField) : Except String def lowerFrameAsMemoryPayload (base : String) (fields : List FrameField) : Except String (List YulStmt × List YulExpr × FrameLayout) := do let lowered ← lowerFrameSpilled base fields - match lowered.layout.mode with - | .pointer => - pure (lowered.prologue, lowered.args, lowered.layout) - | .inlineWords => - pure (lowered.prologue ++ inlinePayloadToScratch lowered.args, - [YulExpr.lit 0, YulExpr.lit (lowered.layout.headWords * 32)], - lowered.layout) + let (prologue, args) := materializePayloadToMemory base lowered.layout + pure (prologue, args, lowered.layout) def lowerEventWithTopic (base : String) (topic0 : YulExpr) (fields : List FrameField) : Except String (List YulStmt) := do let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload base fields diff --git a/Compiler/Lowering/StackSafeAbiTest.lean b/Compiler/Lowering/StackSafeAbiTest.lean index 93dd3b42d..22ea333df 100644 --- a/Compiler/Lowering/StackSafeAbiTest.lean +++ b/Compiler/Lowering/StackSafeAbiTest.lean @@ -15,7 +15,7 @@ private def assert (label : String) (ok : Bool) : IO Unit := do private def bigDynamicPayload : List FrameField := [ { name := "toId", ty := .bytes32, source := .calldata } , { name := "toMarket", ty := .tuple [.address, .uint256, .uint256, .address], source := .calldata } - , { name := "takes", ty := .array (.tuple [.address, .uint256, .bytes]), source := .calldata } ] + , { name := "takes", ty := .array (.tuple [.address, .uint256, .bytes]), source := .calldata, tailBytes := 128 } ] private def smallStaticPayload : List FrameField := [ { name := "id", ty := .bytes32, source := .calldata } diff --git a/Compiler/Modules/CodeData.lean b/Compiler/Modules/CodeData.lean index e8191f987..e44f2201b 100644 --- a/Compiler/Modules/CodeData.lean +++ b/Compiler/Modules/CodeData.lean @@ -10,8 +10,6 @@ open Compiler.Yul structure CodeDataWrite where salt : YulExpr value : YulExpr := YulExpr.lit 0 - initcodeOffset : YulExpr - initcodeSize : YulExpr payload : FrameLayout deriving Repr @@ -32,12 +30,17 @@ def trustSurface : List String := def writeTyped (resultVar base : String) (write : CodeDataWrite) : Except String (List YulStmt) := do if !layoutSourcesSupported write.payload then throw "CodeData write payload has unsupported frame source" - let prelude := - match write.payload.mode with - | .pointer => spillPayloadToMemory base write.payload - | .inlineWords => [] + let (prelude, payloadArgs) := materializePayloadToMemory base write.payload + let initcodeOffset ← + match payloadArgs with + | [offset, _size] => pure offset + | _ => throw "CodeData write expected a memory payload pointer and size" + let initcodeSize ← + match payloadArgs with + | [_offset, size] => pure size + | _ => throw "CodeData write expected a memory payload pointer and size" let deploy ← (Compiler.Modules.Create2SSTORE2.deployModule resultVar).compile {} - [write.value, write.initcodeOffset, write.initcodeSize, write.salt] + [write.value, initcodeOffset, initcodeSize, write.salt] pure (prelude ++ deploy) def readTyped (read : CodeDataRead) : Except String (List YulStmt) := do diff --git a/Compiler/Modules/CodeDataTest.lean b/Compiler/Modules/CodeDataTest.lean index c953dfe1f..f205ecfce 100644 --- a/Compiler/Modules/CodeDataTest.lean +++ b/Compiler/Modules/CodeDataTest.lean @@ -13,14 +13,18 @@ private def assert (label : String) (ok : Bool) : IO Unit := do IO.println s!"ok: {label}" private def payload := layout - [ { name := "blob", ty := .bytes, source := .memory } + [ { name := "blob", ty := .bytes, source := .memory, tailBytes := 96 } , { name := "meta", ty := .tuple [.bytes32, .uint256], source := .calldata } ] +private def deployUsesPayloadBuffer : List YulStmt → Bool := + fun stmts => stmts.any fun stmt => + match stmt with + | .let_ _ (.call "create2" [_value, .ident "__abi_frame_sstore2", .lit 192, _salt]) => true + | _ => false + #eval! do let write : CodeDataWrite := { salt := YulExpr.ident "salt" - initcodeOffset := YulExpr.ident "init" - initcodeSize := YulExpr.ident "size" payload } let read : CodeDataRead := { pointer := YulExpr.ident "ptr" @@ -33,6 +37,7 @@ private def payload := layout | .ok stmts => pure stmts | .error err => throw (IO.userError err) assert "typed roundtrip has create2 and extcodecopy" (hasCreate2AndExtcodecopy roundtrip) + assert "typed write deploys the materialized payload buffer" (deployUsesPayloadBuffer roundtrip) assert "trust surface is explicit" (trustSurface.length == 4) end Compiler.Modules.CodeDataTest