Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions Compiler/ABI/Frame.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import Compiler.CompilationModel.AbiTypeLayout
import Compiler.Yul.Ast

namespace Compiler.ABI.Frame

open Compiler.CompilationModel
open Compiler.Yul

inductive FrameSource
| calldata
| memory
| code
| storage
deriving Repr, BEq

inductive FramePassMode
| inlineWords
| pointer
deriving Repr, BEq

structure FrameField where
name : String
ty : ParamType
source : FrameSource
sourceBase : String := ""
tailBytes : Nat := 0
deriving Repr, BEq

structure FrameLayout where
fields : List FrameField
headWords : Nat
hasDynamic : Bool
mode : FramePassMode
deriving Repr, BEq

def spillThresholdWords : Nat := 4

def fieldHeadWords (field : FrameField) : Nat :=
paramParentHeadWords field.ty

def fieldsHeadWords (fields : List FrameField) : Nat :=
fields.foldl (fun acc field => acc + fieldHeadWords field) 0

def fieldsHaveDynamic (fields : List FrameField) : Bool :=
fields.any (fun field => isDynamicParamType field.ty)

def shouldPassByPointer (fields : List FrameField) : Bool :=
fieldsHaveDynamic fields || spillThresholdWords < fieldsHeadWords fields

def layout (fields : List FrameField) : FrameLayout :=
let headWords := fieldsHeadWords fields
let hasDynamic := fieldsHaveDynamic fields
{ fields
headWords
hasDynamic
mode := if hasDynamic || spillThresholdWords < headWords then .pointer else .inlineWords }

def sourceCanMaterializeEarly : FrameSource → Bool
| .calldata | .memory | .code | .storage => true

def layoutSourcesSupported (l : FrameLayout) : Bool :=
l.fields.all (fun field => sourceCanMaterializeEarly field.source)

def frameSizeBytes (l : FrameLayout) : Nat :=
l.fields.foldl (fun acc field => acc + fieldHeadWords field * 32 +
(if isDynamicParamType field.ty then field.tailBytes else 0)) 0

def fieldPayloadWords (field : FrameField) : Nat :=
fieldHeadWords field + if isDynamicParamType field.ty then (field.tailBytes + 31) / 32 else 0

def frameAllocBytes (l : FrameLayout) : Nat :=
l.fields.foldl (fun acc field => acc + fieldPayloadWords field * 32) 0

def ptrName (base : String) : String :=
"__abi_frame_" ++ base

def fieldWordName (base : String) (field : FrameField) (idx : Nat) : String :=
ptrName base ++ "_" ++ field.name ++ "_" ++ toString idx

def allocateFrame (base : String) (l : FrameLayout) : List YulStmt :=
[ YulStmt.let_ (ptrName base) (YulExpr.call "mload" [YulExpr.lit 64])
, YulStmt.expr (YulExpr.call "mstore"
[ YulExpr.lit 64
, YulExpr.call "add" [YulExpr.ident (ptrName base), YulExpr.lit (frameAllocBytes l)] ])]

def sourceBaseName (field : FrameField) : String :=
if field.sourceBase.isEmpty then field.name else field.sourceBase

private def sourceByteOffset (field : FrameField) (idx : Nat) : YulExpr :=
if idx == 0 then
YulExpr.ident (sourceBaseName field)
else
YulExpr.call "add" [YulExpr.ident (sourceBaseName field), YulExpr.lit (idx * 32)]

private def sourceStorageSlot (field : FrameField) (idx : Nat) : YulExpr :=
if idx == 0 then
YulExpr.ident (sourceBaseName field)
else
YulExpr.call "add" [YulExpr.ident (sourceBaseName field), YulExpr.lit idx]

private def materializeSourceWord (field : FrameField) (idx : Nat) : YulExpr :=
match field.source with
| .calldata => YulExpr.call "calldataload" [sourceByteOffset field idx]
| .memory => YulExpr.call "mload" [sourceByteOffset field idx]
| .code => YulExpr.call "mload" [sourceByteOffset field idx]
| .storage => YulExpr.call "sload" [sourceStorageSlot field idx]

partial def spillField (base : String) (offsetWords : Nat) (field : FrameField) : List YulStmt :=
(List.range (fieldPayloadWords field)).map fun idx =>
YulStmt.expr (YulExpr.call "mstore"
[ YulExpr.call "add" [YulExpr.ident (ptrName base), YulExpr.lit ((offsetWords + idx) * 32)]
, materializeSourceWord field idx ])
Comment thread
cursor[bot] marked this conversation as resolved.

partial def spillFields (base : String) (offsetWords : Nat) : List FrameField → List YulStmt
| [] => []
| field :: rest => spillField base offsetWords field ++ spillFields base (offsetWords + fieldPayloadWords field) rest

/-- Materialize a typed ABI frame into memory before lowering calls/logs/returns.
Large or dynamic payloads are then passed as `(ptr, size)` instead of as a
long list of Yul values. -/
def spillPayloadToMemory (base : String) (l : FrameLayout) : List YulStmt :=
allocateFrame base l ++ spillFields base 0 l.fields

def pointerArgs (base : String) (l : FrameLayout) : List YulExpr :=
[YulExpr.ident (ptrName base), YulExpr.lit (frameSizeBytes l)]
Comment thread
cursor[bot] marked this conversation as resolved.

def inlinePayloadToScratch (words : List YulExpr) : List YulStmt :=
words.zipIdx.map fun (word, idx) =>
YulStmt.expr (YulExpr.call "mstore" [YulExpr.lit (idx * 32), word])

private partial def inlineArgsFrom : List FrameField → List YulExpr
| [] => []
| field :: rest =>
(List.range (fieldHeadWords field)).map (fun wordIdx =>
materializeSourceWord field wordIdx) ++
inlineArgsFrom rest

def inlineArgs (l : FrameLayout) : List YulExpr :=
inlineArgsFrom l.fields

def loweredArgs (base : String) (l : FrameLayout) : List YulExpr :=
match l.mode with
| .pointer => pointerArgs base l
| .inlineWords => inlineArgs l

def materializePayloadToMemory (base : String) (l : FrameLayout) : List YulStmt × List YulExpr :=
match l.mode with
| .pointer =>
(spillPayloadToMemory base l, pointerArgs base l)
| .inlineWords =>
(inlinePayloadToScratch (inlineArgs l), [YulExpr.lit 0, YulExpr.lit (frameSizeBytes l)])

def containsDynamicArrayOrBytes (l : FrameLayout) : Bool :=
l.fields.any fun field =>
match field.ty with
| .array _ | .bytes | .string => true
| _ => isDynamicParamType field.ty

def supportsNestedStructs (l : FrameLayout) : Bool :=
l.fields.any fun field =>
match field.ty with
| .tuple elems => elems.any (fun ty => match ty with | .tuple _ => true | _ => false)
| _ => false

end Compiler.ABI.Frame
49 changes: 49 additions & 0 deletions Compiler/ABI/FrameTest.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import Compiler.ABI.Frame

namespace Compiler.ABI.FrameTest

open Compiler.ABI.Frame
open Compiler.CompilationModel
open Compiler.Yul

private def assert (label : String) (ok : Bool) : IO Unit := do
if !ok then
throw (IO.userError s!"frame test failed: {label}")
IO.println s!"ok: {label}"

private def takeFields : List FrameField :=
[ { name := "offer", ty := .tuple [.address, .uint256, .tuple [.bytes32, .uint256]], source := .calldata }
, { name := "units", ty := .uint256, source := .calldata }
, { name := "ratifierData", ty := .bytes, source := .calldata, tailBytes := 96 } ]

private def sourceFields : List FrameField :=
[ { name := "c", ty := .uint256, source := .calldata }
, { name := "m", ty := .bytes, source := .memory, tailBytes := 64 }
, { name := "x", ty := .bytes32, source := .code }
, { name := "s", ty := .uint256, source := .storage } ]

private def inlineFields : List FrameField :=
[ { name := "pair", ty := .tuple [.uint256, .bytes32], source := .calldata }
, { name := "amount", ty := .uint256, source := .calldata } ]

private def calldataLoadName? : YulExpr → Option String
| .call "calldataload" [.ident name] => some name
| .call "calldataload" [.call "add" [.ident name, _]] => some name
| _ => none

#eval! do
let takeLayout := layout takeFields
assert "nested struct supported" (supportsNestedStructs takeLayout)
assert "dynamic bytes/arrays force pointer mode" (takeLayout.mode == FramePassMode.pointer)
assert "Take frame passes pointer pair" ((loweredArgs "take" takeLayout).length == 2)
assert "Take spills early to memory" ((spillPayloadToMemory "take" takeLayout).length > 2)
assert "dynamic tail contributes to pointer payload size" (frameSizeBytes takeLayout == 288)
assert "dynamic tail contributes to allocated words" (frameAllocBytes takeLayout == 288)
let srcLayout := layout sourceFields
assert "calldata/memory/code/storage sources supported" (layoutSourcesSupported srcLayout)
assert "dynamic source frame is pointer mode" (srcLayout.mode == FramePassMode.pointer)
let inlineNames := (inlineArgs (layout inlineFields)).filterMap calldataLoadName?
assert "inline source words are indexed per field"
(inlineNames == ["pair", "pair", "amount"])

end Compiler.ABI.FrameTest
57 changes: 57 additions & 0 deletions Compiler/Lowering/StackSafeAbi.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import Compiler.ABI.Frame

namespace Compiler.Lowering.StackSafeAbi

open Compiler.ABI.Frame
open Compiler.Yul

structure LoweredFrame where
prologue : List YulStmt
args : List YulExpr
layout : FrameLayout
deriving Repr

def eventNameTopicWord (eventName : String) : Nat :=
UInt64.toNat (hash eventName)

def lowerFrameSpilled (base : String) (fields : List FrameField) : Except String LoweredFrame := do
let l := layout fields
if !layoutSourcesSupported l then
throw s!"ABI frame '{base}' uses an unsupported source"
let prologue :=
match l.mode with
| .pointer => spillPayloadToMemory base l
| .inlineWords => []
pure { prologue, args := loweredArgs base l, layout := l }

def lowerFrameAsMemoryPayload (base : String) (fields : List FrameField) : Except String (List YulStmt × List YulExpr × FrameLayout) := do
let lowered ← lowerFrameSpilled base fields
let (prologue, args) := materializePayloadToMemory base lowered.layout
pure (prologue, args, lowered.layout)

def lowerEventWithTopic (base : String) (topic0 : YulExpr) (fields : List FrameField) : Except String (List YulStmt) := do
let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload base fields
pure (prologue ++
[YulStmt.expr (YulExpr.call "log1" (payloadArgs ++ [topic0]))])

def lowerEvent (eventName : String) (fields : List FrameField) : Except String (List YulStmt) := do
lowerEventWithTopic eventName (YulExpr.lit (eventNameTopicWord eventName)) fields

def lowerExternalCall (callName : String) (target value : YulExpr) (fields : List FrameField) : Except String (List YulStmt) := do
let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload callName fields
let callArgs := [YulExpr.call "gas" [], target, value] ++ payloadArgs ++ [YulExpr.lit 0, YulExpr.lit 0]
pure (prologue ++ [YulStmt.let_ ("__" ++ callName ++ "_ok") (YulExpr.call "call" callArgs)])

def lowerDynamicReturn (returnName : String) (fields : List FrameField) : Except String (List YulStmt) := do
let (prologue, payloadArgs, _) ← lowerFrameAsMemoryPayload returnName fields
pure (prologue ++ [YulStmt.expr (YulExpr.call "return" payloadArgs)])

def usesPointerAbi (stmts : List YulStmt) : Bool :=
stmts.any fun stmt =>
match stmt with
| .expr (.call "return" [_ptr, _size]) => true
| .expr (.call "log1" [_ptr, _size, _topic]) => true
| .let_ _ (.call "call" [_gas, _target, _value, _ptr, _size, _out, _outSize]) => true
| _ => false

end Compiler.Lowering.StackSafeAbi
77 changes: 77 additions & 0 deletions Compiler/Lowering/StackSafeAbiTest.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import Compiler.Lowering.StackSafeAbi

namespace Compiler.Lowering.StackSafeAbiTest

open Compiler.ABI.Frame
open Compiler.CompilationModel
open Compiler.Lowering.StackSafeAbi
open Compiler.Yul

private def assert (label : String) (ok : Bool) : IO Unit := do
if !ok then
throw (IO.userError s!"stack safe ABI test failed: {label}")
IO.println s!"ok: {label}"

private def bigDynamicPayload : List FrameField :=
[ { name := "toId", ty := .bytes32, source := .calldata }
, { name := "toMarket", ty := .tuple [.address, .uint256, .uint256, .address], source := .calldata }
, { name := "takes", ty := .array (.tuple [.address, .uint256, .bytes]), source := .calldata, tailBytes := 128 } ]

private def smallStaticPayload : List FrameField :=
[ { name := "id", ty := .bytes32, source := .calldata }
, { name := "amount", ty := .uint256, source := .calldata } ]

private def countMstores : List YulStmt → Nat :=
List.length ∘ List.filter (fun stmt =>
match stmt with
| .expr (.call "mstore" _) => true
| _ => false)

private def returnsBytes (bytes : Nat) : List YulStmt → Bool :=
fun stmts => stmts.any fun stmt =>
match stmt with
| .expr (.call "return" [.lit 0, .lit n]) => n == bytes
| _ => false

private def callsWithInputBytes (bytes : Nat) : List YulStmt → Bool :=
fun stmts => stmts.any fun stmt =>
match stmt with
| .let_ _ (.call "call" [_gas, _target, _value, .lit 0, .lit n, _out, _outSize]) => n == bytes
| _ => false

private def logsWithDataBytes (bytes : Nat) : List YulStmt → Bool :=
fun stmts => stmts.any fun stmt =>
match stmt with
| .expr (.call "log1" [.lit 0, .lit n, .lit topic]) => n == bytes && topic != 0
| _ => false

#eval! do
match lowerEvent "Take" bigDynamicPayload with
| .ok ev => assert "event lowering uses memory pointer" (usesPointerAbi ev)
| .error err => throw (IO.userError err)
match lowerExternalCall "callback" (YulExpr.ident "target") (YulExpr.lit 0) bigDynamicPayload with
| .ok call => assert "external-call lowering uses memory pointer" (usesPointerAbi call)
| .error err => throw (IO.userError err)
match lowerDynamicReturn "dynamicReturn" bigDynamicPayload with
| .ok ret => assert "dynamic return lowering uses memory pointer" (usesPointerAbi ret)
| .error err => throw (IO.userError err)
match lowerFrameSpilled "toMarket" bigDynamicPayload with
| .ok lowered => assert "frame-spilled lowering allocates memory early" (!lowered.prologue.isEmpty && lowered.args.length == 2)
| .error err => throw (IO.userError err)
match lowerEvent "SmallStatic" smallStaticPayload with
| .ok ev =>
assert "inline event lowering stores payload" (countMstores ev == 2)
assert "inline event lowering logs payload bytes" (logsWithDataBytes 64 ev)
| .error err => throw (IO.userError err)
match lowerExternalCall "smallCallback" (YulExpr.ident "target") (YulExpr.lit 0) smallStaticPayload with
| .ok call =>
assert "inline external-call lowering stores payload" (countMstores call == 2)
assert "inline external-call lowering passes payload bytes" (callsWithInputBytes 64 call)
| .error err => throw (IO.userError err)
match lowerDynamicReturn "smallReturn" smallStaticPayload with
| .ok ret =>
assert "inline dynamic return lowering stores payload" (countMstores ret == 2)
assert "inline dynamic return lowering returns payload bytes" (returnsBytes 64 ret)
| .error err => throw (IO.userError err)

end Compiler.Lowering.StackSafeAbiTest
2 changes: 2 additions & 0 deletions Compiler/Modules.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ import Compiler.Modules.ERC20
import Compiler.Modules.Hashing
import Compiler.Modules.Oracle
import Compiler.Modules.Precompiles
import Compiler.Modules.Create2SSTORE2
import Compiler.Modules.CodeData
Loading