Skip to content

Commit 435f4ff

Browse files
committed
fix: address review feedback - Decimal precision, gas buffer, chain ID check, web3 optional dep
1 parent 78a812e commit 435f4ff

2 files changed

Lines changed: 70 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ test = [
4545
"pytest>=8.4.2",
4646
"pytest-asyncio>=1.2.0",
4747
]
48+
spraay = ["web3>=6.0.0"]
4849

4950

5051
[tool.pyink]

src/google/adk_community/tools/spraay/spraay_tools.py

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
import logging
3131
import os
32+
from decimal import Decimal
3233
from typing import Optional
3334

3435
from google.adk_community.tools.spraay.constants import (
@@ -134,8 +135,20 @@ def spraay_batch_eth(
134135
account = _get_account()
135136
contract_address = _get_contract_address()
136137

138+
# Verify chain ID to avoid sending to wrong network
139+
connected_chain_id = w3.eth.chain_id
140+
if connected_chain_id != BASE_CHAIN_ID:
141+
return {
142+
"status": "error",
143+
"error": (
144+
f"Chain ID mismatch: connected to {connected_chain_id}, "
145+
f"expected {BASE_CHAIN_ID} (Base). "
146+
"Check your RPC configuration."
147+
),
148+
}
149+
137150
checksummed = _validate_recipients(recipients)
138-
amount_wei = w3.to_wei(amount_per_recipient_eth, "ether")
151+
amount_wei = w3.to_wei(Decimal(amount_per_recipient_eth), "ether")
139152

140153
if amount_wei <= 0:
141154
return {"status": "error", "error": "Amount must be greater than 0."}
@@ -161,7 +174,8 @@ def spraay_batch_eth(
161174
}
162175
)
163176

164-
tx["gas"] = w3.eth.estimate_gas(tx)
177+
# Add 10% gas buffer to prevent 'out of gas' errors
178+
tx["gas"] = int(w3.eth.estimate_gas(tx) * 1.1)
165179
signed = account.sign_transaction(tx)
166180
tx_hash = w3.eth.send_raw_transaction(signed.raw_transaction)
167181

@@ -214,12 +228,26 @@ def spraay_batch_token(
214228
account = _get_account()
215229
contract_address = _get_contract_address()
216230

231+
# Verify chain ID to avoid sending to wrong network
232+
connected_chain_id = w3.eth.chain_id
233+
if connected_chain_id != BASE_CHAIN_ID:
234+
return {
235+
"status": "error",
236+
"error": (
237+
f"Chain ID mismatch: connected to {connected_chain_id}, "
238+
f"expected {BASE_CHAIN_ID} (Base). "
239+
"Check your RPC configuration."
240+
),
241+
}
242+
217243
checksummed = _validate_recipients(recipients)
218244
token_addr = w3.to_checksum_address(token_address)
219245
spraay_addr = w3.to_checksum_address(contract_address)
220246

221-
# Convert human-readable amount to token units
222-
amount_units = int(float(amount_per_recipient) * (10**token_decimals))
247+
# Use Decimal for precise token amount conversion
248+
amount_units = int(
249+
Decimal(amount_per_recipient) * Decimal(10**token_decimals)
250+
)
223251
if amount_units <= 0:
224252
return {"status": "error", "error": "Amount must be greater than 0."}
225253

@@ -246,7 +274,8 @@ def spraay_batch_token(
246274
"gas": 0,
247275
}
248276
)
249-
approve_tx["gas"] = w3.eth.estimate_gas(approve_tx)
277+
# Add 10% gas buffer
278+
approve_tx["gas"] = int(w3.eth.estimate_gas(approve_tx) * 1.1)
250279
signed_approve = account.sign_transaction(approve_tx)
251280
approve_hash = w3.eth.send_raw_transaction(
252281
signed_approve.raw_transaction
@@ -268,7 +297,8 @@ def spraay_batch_token(
268297
"gas": 0,
269298
}
270299
)
271-
tx["gas"] = w3.eth.estimate_gas(tx)
300+
# Add 10% gas buffer
301+
tx["gas"] = int(w3.eth.estimate_gas(tx) * 1.1)
272302
signed = account.sign_transaction(tx)
273303
tx_hash = w3.eth.send_raw_transaction(signed.raw_transaction)
274304

@@ -318,6 +348,18 @@ def spraay_batch_eth_variable(
318348
account = _get_account()
319349
contract_address = _get_contract_address()
320350

351+
# Verify chain ID to avoid sending to wrong network
352+
connected_chain_id = w3.eth.chain_id
353+
if connected_chain_id != BASE_CHAIN_ID:
354+
return {
355+
"status": "error",
356+
"error": (
357+
f"Chain ID mismatch: connected to {connected_chain_id}, "
358+
f"expected {BASE_CHAIN_ID} (Base). "
359+
"Check your RPC configuration."
360+
),
361+
}
362+
321363
checksummed = _validate_recipients(recipients)
322364

323365
if len(amounts_eth) != len(checksummed):
@@ -329,7 +371,7 @@ def spraay_batch_eth_variable(
329371
),
330372
}
331373

332-
amounts_wei = [w3.to_wei(a, "ether") for a in amounts_eth]
374+
amounts_wei = [w3.to_wei(Decimal(a), "ether") for a in amounts_eth]
333375
if any(a <= 0 for a in amounts_wei):
334376
return {"status": "error", "error": "All amounts must be greater than 0."}
335377

@@ -353,7 +395,8 @@ def spraay_batch_eth_variable(
353395
"gas": 0,
354396
}
355397
)
356-
tx["gas"] = w3.eth.estimate_gas(tx)
398+
# Add 10% gas buffer
399+
tx["gas"] = int(w3.eth.estimate_gas(tx) * 1.1)
357400
signed = account.sign_transaction(tx)
358401
tx_hash = w3.eth.send_raw_transaction(signed.raw_transaction)
359402

@@ -404,6 +447,18 @@ def spraay_batch_token_variable(
404447
account = _get_account()
405448
contract_address = _get_contract_address()
406449

450+
# Verify chain ID to avoid sending to wrong network
451+
connected_chain_id = w3.eth.chain_id
452+
if connected_chain_id != BASE_CHAIN_ID:
453+
return {
454+
"status": "error",
455+
"error": (
456+
f"Chain ID mismatch: connected to {connected_chain_id}, "
457+
f"expected {BASE_CHAIN_ID} (Base). "
458+
"Check your RPC configuration."
459+
),
460+
}
461+
407462
checksummed = _validate_recipients(recipients)
408463
token_addr = w3.to_checksum_address(token_address)
409464
spraay_addr = w3.to_checksum_address(contract_address)
@@ -417,8 +472,9 @@ def spraay_batch_token_variable(
417472
),
418473
}
419474

475+
# Use Decimal for precise token amount conversion
420476
amounts_units = [
421-
int(float(a) * (10**token_decimals)) for a in amounts
477+
int(Decimal(a) * Decimal(10**token_decimals)) for a in amounts
422478
]
423479
if any(a <= 0 for a in amounts_units):
424480
return {"status": "error", "error": "All amounts must be greater than 0."}
@@ -446,7 +502,8 @@ def spraay_batch_token_variable(
446502
"gas": 0,
447503
}
448504
)
449-
approve_tx["gas"] = w3.eth.estimate_gas(approve_tx)
505+
# Add 10% gas buffer
506+
approve_tx["gas"] = int(w3.eth.estimate_gas(approve_tx) * 1.1)
450507
signed_approve = account.sign_transaction(approve_tx)
451508
approve_hash = w3.eth.send_raw_transaction(
452509
signed_approve.raw_transaction
@@ -468,7 +525,8 @@ def spraay_batch_token_variable(
468525
"gas": 0,
469526
}
470527
)
471-
tx["gas"] = w3.eth.estimate_gas(tx)
528+
# Add 10% gas buffer
529+
tx["gas"] = int(w3.eth.estimate_gas(tx) * 1.1)
472530
signed = account.sign_transaction(tx)
473531
tx_hash = w3.eth.send_raw_transaction(signed.raw_transaction)
474532

0 commit comments

Comments
 (0)