33import math
44from abc import ABC , abstractmethod
55from dataclasses import dataclass , field
6- from typing import Callable , ClassVar , Dict , Generator , List , Sequence , Type
6+ from typing import Any , Callable , ClassVar , Dict , Generator , List , Sequence , Type
77
88import pytest
99from pydantic import ConfigDict , Field
@@ -41,27 +41,44 @@ class BenchmarkCodeGenerator(ABC):
4141 attack_block : Bytecode
4242 setup : Bytecode = field (default_factory = Bytecode )
4343 cleanup : Bytecode = field (default_factory = Bytecode )
44+ tx_kwargs : Dict [str , Any ] = field (default_factory = dict )
45+ _contract_address : Address | None = None
4446
4547 @abstractmethod
46- def deploy_contracts (self , pre : Alloc , fork : Fork ) -> Address :
48+ def deploy_contracts (self , * , pre : Alloc , fork : Fork ) -> Address :
4749 """Deploy any contracts needed for the benchmark."""
4850 ...
4951
50- @abstractmethod
51- def generate_transaction (self , pre : Alloc , gas_limit : int ) -> Transaction :
52- """Generate a transaction with the specified gas limit."""
53- ...
52+ def generate_transaction (self , * , pre : Alloc , gas_benchmark_value : int ) -> Transaction :
53+ """Generate transaction that executes the looping contract."""
54+ assert self ._contract_address is not None
55+ if "gas_limit" not in self .tx_kwargs :
56+ self .tx_kwargs ["gas_limit" ] = gas_benchmark_value
57+
58+ return Transaction (
59+ to = self ._contract_address ,
60+ sender = pre .fund_eoa (),
61+ ** self .tx_kwargs ,
62+ )
5463
5564 def generate_repeated_code (
56- self , repeated_code : Bytecode , setup : Bytecode , cleanup : Bytecode , fork : Fork
65+ self ,
66+ * ,
67+ repeated_code : Bytecode ,
68+ setup : Bytecode | None = None ,
69+ cleanup : Bytecode | None = None ,
70+ fork : Fork ,
5771 ) -> Bytecode :
5872 """
5973 Calculate the maximum number of iterations that
6074 can fit in the code size limit.
6175 """
6276 assert len (repeated_code ) > 0 , "repeated_code cannot be empty"
6377 max_code_size = fork .max_code_size ()
64-
78+ if setup is None :
79+ setup = Bytecode ()
80+ if cleanup is None :
81+ cleanup = Bytecode ()
6582 overhead = len (setup ) + len (Op .JUMPDEST ) + len (cleanup ) + len (Op .JUMP (len (setup )))
6683 available_space = max_code_size - overhead
6784 max_iterations = available_space // len (repeated_code )
@@ -87,7 +104,7 @@ class BenchmarkTest(BaseTest):
87104
88105 model_config = ConfigDict (extra = "forbid" )
89106
90- pre : Alloc
107+ pre : Alloc = Field ( default_factory = Alloc )
91108 post : Alloc = Field (default_factory = Alloc )
92109 tx : Transaction | None = None
93110 blocks : List [Block ] | None = None
@@ -118,6 +135,14 @@ class BenchmarkTest(BaseTest):
118135 "blockchain_test_only" : "Only generate a blockchain test fixture" ,
119136 }
120137
138+ def model_post_init (self , __context : Any , / ) -> None :
139+ """
140+ Model post-init to assert that the custom pre-allocation was
141+ provided and the default was not used.
142+ """
143+ super ().model_post_init (__context )
144+ assert "pre" in self .model_fields_set , "pre allocation was not provided"
145+
121146 @classmethod
122147 def pytest_parameter_name (cls ) -> str :
123148 """
@@ -181,9 +206,11 @@ def generate_blocks_from_code_generator(self, fork: Fork) -> List[Block]:
181206 if self .code_generator is None :
182207 raise Exception ("Code generator is not set" )
183208
184- self .code_generator .deploy_contracts (self .pre , fork )
209+ self .code_generator .deploy_contracts (pre = self .pre , fork = fork )
185210 gas_limit = fork .transaction_gas_limit_cap () or self .gas_benchmark_value
186- benchmark_tx = self .code_generator .generate_transaction (self .pre , gas_limit )
211+ benchmark_tx = self .code_generator .generate_transaction (
212+ pre = self .pre , gas_benchmark_value = gas_limit
213+ )
187214
188215 execution_txs = self .split_transaction (benchmark_tx , gas_limit )
189216 execution_block = Block (txs = execution_txs )
0 commit comments