Skip to content

Commit b965b42

Browse files
committed
fix: extract chain ID helper, wrap web3 import in _get_account, fix tests for CI
1 parent 435f4ff commit b965b42

2 files changed

Lines changed: 72 additions & 71 deletions

File tree

src/google/adk_community/tools/spraay/spraay_tools.py

Lines changed: 25 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@ def _get_web3():
6464

6565
def _get_account():
6666
"""Get the signing account from environment."""
67-
from web3 import Account
67+
try:
68+
from web3 import Account
69+
except ImportError:
70+
raise ImportError(
71+
"web3 is required for Spraay tools. Install with: pip install web3"
72+
)
6873

6974
private_key = os.environ.get("SPRAAY_PRIVATE_KEY")
7075
if not private_key:
@@ -105,6 +110,21 @@ def _calculate_fee(total_wei: int) -> int:
105110
return (total_wei * SPRAAY_FEE_BPS) // 10000
106111

107112

113+
def _verify_chain_id(w3) -> None:
114+
"""Verify the connected chain ID matches Base.
115+
116+
Raises:
117+
ValueError: If chain ID doesn't match BASE_CHAIN_ID.
118+
"""
119+
connected_chain_id = w3.eth.chain_id
120+
if connected_chain_id != BASE_CHAIN_ID:
121+
raise ValueError(
122+
f"Chain ID mismatch: connected to {connected_chain_id}, "
123+
f"expected {BASE_CHAIN_ID} (Base). "
124+
"Check your RPC configuration."
125+
)
126+
127+
108128
def spraay_batch_eth(
109129
recipients: list[str],
110130
amount_per_recipient_eth: str,
@@ -135,17 +155,7 @@ def spraay_batch_eth(
135155
account = _get_account()
136156
contract_address = _get_contract_address()
137157

138-
# Verify chain ID to avoid sending to wrong network
139-
connected_chain_id = w3.eth.chain_id
140-
if connected_chain_id != BASE_CHAIN_ID:
141-
return {
142-
"status": "error",
143-
"error": (
144-
f"Chain ID mismatch: connected to {connected_chain_id}, "
145-
f"expected {BASE_CHAIN_ID} (Base). "
146-
"Check your RPC configuration."
147-
),
148-
}
158+
_verify_chain_id(w3)
149159

150160
checksummed = _validate_recipients(recipients)
151161
amount_wei = w3.to_wei(Decimal(amount_per_recipient_eth), "ether")
@@ -228,17 +238,7 @@ def spraay_batch_token(
228238
account = _get_account()
229239
contract_address = _get_contract_address()
230240

231-
# Verify chain ID to avoid sending to wrong network
232-
connected_chain_id = w3.eth.chain_id
233-
if connected_chain_id != BASE_CHAIN_ID:
234-
return {
235-
"status": "error",
236-
"error": (
237-
f"Chain ID mismatch: connected to {connected_chain_id}, "
238-
f"expected {BASE_CHAIN_ID} (Base). "
239-
"Check your RPC configuration."
240-
),
241-
}
241+
_verify_chain_id(w3)
242242

243243
checksummed = _validate_recipients(recipients)
244244
token_addr = w3.to_checksum_address(token_address)
@@ -348,17 +348,7 @@ def spraay_batch_eth_variable(
348348
account = _get_account()
349349
contract_address = _get_contract_address()
350350

351-
# Verify chain ID to avoid sending to wrong network
352-
connected_chain_id = w3.eth.chain_id
353-
if connected_chain_id != BASE_CHAIN_ID:
354-
return {
355-
"status": "error",
356-
"error": (
357-
f"Chain ID mismatch: connected to {connected_chain_id}, "
358-
f"expected {BASE_CHAIN_ID} (Base). "
359-
"Check your RPC configuration."
360-
),
361-
}
351+
_verify_chain_id(w3)
362352

363353
checksummed = _validate_recipients(recipients)
364354

@@ -447,17 +437,7 @@ def spraay_batch_token_variable(
447437
account = _get_account()
448438
contract_address = _get_contract_address()
449439

450-
# Verify chain ID to avoid sending to wrong network
451-
connected_chain_id = w3.eth.chain_id
452-
if connected_chain_id != BASE_CHAIN_ID:
453-
return {
454-
"status": "error",
455-
"error": (
456-
f"Chain ID mismatch: connected to {connected_chain_id}, "
457-
f"expected {BASE_CHAIN_ID} (Base). "
458-
"Check your RPC configuration."
459-
),
460-
}
440+
_verify_chain_id(w3)
461441

462442
checksummed = _validate_recipients(recipients)
463443
token_addr = w3.to_checksum_address(token_address)

tests/unittests/tools/spraay/test_spraay_tools.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
"""Unit tests for Spraay batch payment tools."""
1616

1717
import os
18+
import sys
1819
import unittest
1920
from unittest.mock import MagicMock, patch
2021

2122
from google.adk_community.tools.spraay.constants import (
23+
BASE_CHAIN_ID,
2224
MAX_RECIPIENTS,
2325
SPRAAY_CONTRACT_ADDRESS,
2426
SPRAAY_FEE_BPS,
@@ -39,31 +41,36 @@
3941
TOKEN_ADDR = "0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913" # USDC on Base
4042

4143

44+
def _make_mock_w3():
45+
"""Create a mock Web3 instance with chain_id set to Base."""
46+
mock_w3 = MagicMock()
47+
mock_w3.eth.chain_id = BASE_CHAIN_ID
48+
return mock_w3
49+
50+
4251
class TestValidateRecipients(unittest.TestCase):
4352
"""Tests for recipient address validation."""
4453

45-
@patch("google.adk_community.tools.spraay.spraay_tools.Web3")
46-
def test_valid_addresses(self, mock_web3_class):
47-
"""Valid addresses should be checksummed and returned."""
48-
mock_web3_class.is_address.return_value = True
49-
mock_web3_class.to_checksum_address.side_effect = lambda x: x
50-
result = _validate_recipients([ADDR_1, ADDR_2])
51-
self.assertEqual(len(result), 2)
52-
5354
def test_empty_list(self):
5455
"""Empty recipient list should raise ValueError."""
5556
with self.assertRaises(ValueError):
5657
_validate_recipients([])
5758

58-
@patch("google.adk_community.tools.spraay.spraay_tools.Web3")
59-
def test_too_many_recipients(self, mock_web3_class):
59+
def test_too_many_recipients(self):
6060
"""More than MAX_RECIPIENTS should raise ValueError."""
61-
mock_web3_class.is_address.return_value = True
6261
addresses = [f"0x{'0' * 39}{i:01x}" for i in range(MAX_RECIPIENTS + 1)]
6362
with self.assertRaises(ValueError):
6463
_validate_recipients(addresses)
6564

66-
@patch("google.adk_community.tools.spraay.spraay_tools.Web3")
65+
@patch("google.adk_community.tools.spraay.spraay_tools.Web3", create=True)
66+
def test_valid_addresses(self, mock_web3_class):
67+
"""Valid addresses should be checksummed and returned."""
68+
mock_web3_class.is_address.return_value = True
69+
mock_web3_class.to_checksum_address.side_effect = lambda x: x
70+
result = _validate_recipients([ADDR_1, ADDR_2])
71+
self.assertEqual(len(result), 2)
72+
73+
@patch("google.adk_community.tools.spraay.spraay_tools.Web3", create=True)
6774
def test_invalid_address(self, mock_web3_class):
6875
"""Invalid address should raise ValueError."""
6976
mock_web3_class.is_address.return_value = False
@@ -94,8 +101,11 @@ def test_small_amount(self):
94101
class TestSpraayBatchEth(unittest.TestCase):
95102
"""Tests for spraay_batch_eth function."""
96103

97-
def test_missing_private_key(self):
104+
@patch("google.adk_community.tools.spraay.spraay_tools._get_web3")
105+
def test_missing_private_key(self, mock_web3):
98106
"""Should return error if SPRAAY_PRIVATE_KEY is not set."""
107+
mock_w3 = _make_mock_w3()
108+
mock_web3.return_value = mock_w3
99109
with patch.dict(os.environ, {}, clear=True):
100110
result = spraay_batch_eth([ADDR_1], "0.01")
101111
self.assertEqual(result["status"], "error")
@@ -105,7 +115,7 @@ def test_missing_private_key(self):
105115
@patch("google.adk_community.tools.spraay.spraay_tools._get_account")
106116
def test_zero_amount_returns_error(self, mock_account, mock_web3):
107117
"""Zero ETH amount should return error."""
108-
mock_w3 = MagicMock()
118+
mock_w3 = _make_mock_w3()
109119
mock_w3.to_wei.return_value = 0
110120
mock_web3.return_value = mock_w3
111121
mock_account.return_value = MagicMock()
@@ -124,27 +134,35 @@ def test_mismatched_lengths(self):
124134
with patch(
125135
"google.adk_community.tools.spraay.spraay_tools._get_web3"
126136
) as mock_web3:
127-
mock_w3 = MagicMock()
128-
mock_w3.to_wei.side_effect = lambda x, _: int(float(x) * 10**18)
137+
mock_w3 = _make_mock_w3()
138+
mock_w3.to_wei.side_effect = lambda x, _: int(float(str(x)) * 10**18)
129139
mock_web3.return_value = mock_w3
130140

131141
with patch(
132-
"google.adk_community.tools.spraay.spraay_tools._validate_recipients"
133-
) as mock_validate:
134-
mock_validate.return_value = [ADDR_1, ADDR_2]
142+
"google.adk_community.tools.spraay.spraay_tools._get_account"
143+
) as mock_account:
144+
mock_account.return_value = MagicMock()
145+
146+
with patch(
147+
"google.adk_community.tools.spraay.spraay_tools._validate_recipients"
148+
) as mock_validate:
149+
mock_validate.return_value = [ADDR_1, ADDR_2]
135150

136-
result = spraay_batch_eth_variable(
137-
[ADDR_1, ADDR_2], ["0.1"]
138-
)
139-
self.assertEqual(result["status"], "error")
140-
self.assertIn("must match", result["error"])
151+
result = spraay_batch_eth_variable(
152+
[ADDR_1, ADDR_2], ["0.1"]
153+
)
154+
self.assertEqual(result["status"], "error")
155+
self.assertIn("must match", result["error"])
141156

142157

143158
class TestSpraayBatchToken(unittest.TestCase):
144159
"""Tests for spraay_batch_token function."""
145160

146-
def test_missing_private_key(self):
161+
@patch("google.adk_community.tools.spraay.spraay_tools._get_web3")
162+
def test_missing_private_key(self, mock_web3):
147163
"""Should return error if SPRAAY_PRIVATE_KEY is not set."""
164+
mock_w3 = _make_mock_w3()
165+
mock_web3.return_value = mock_w3
148166
with patch.dict(os.environ, {}, clear=True):
149167
result = spraay_batch_token(TOKEN_ADDR, [ADDR_1], "10")
150168
self.assertEqual(result["status"], "error")
@@ -154,8 +172,11 @@ def test_missing_private_key(self):
154172
class TestSpraayBatchTokenVariable(unittest.TestCase):
155173
"""Tests for spraay_batch_token_variable function."""
156174

157-
def test_missing_private_key(self):
175+
@patch("google.adk_community.tools.spraay.spraay_tools._get_web3")
176+
def test_missing_private_key(self, mock_web3):
158177
"""Should return error if SPRAAY_PRIVATE_KEY is not set."""
178+
mock_w3 = _make_mock_w3()
179+
mock_web3.return_value = mock_w3
159180
with patch.dict(os.environ, {}, clear=True):
160181
result = spraay_batch_token_variable(
161182
TOKEN_ADDR, [ADDR_1], ["10"]

0 commit comments

Comments
 (0)