Skip to content

Commit c9f63fe

Browse files
fix: enforce pympp client chain policy (#134)
* fix: enforce pympp client chain policy * chore: add changelog --------- Co-authored-by: Brendan Ryan <brendanjryan@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 71953bb commit c9f63fe

4 files changed

Lines changed: 120 additions & 110 deletions

File tree

.changelog/ugly-sheep-whisper.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
pympp: patch
3+
---
4+
5+
Fixed client chain policy enforcement to reject challenges that attempt to switch the client to a different chain. Clients pinned to a chain (via `chain_id` or `rpc_url`) now raise a `ValueError` immediately instead of silently following the challenge's `chainId`.

src/mpp/methods/tempo/client.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88
import asyncio
99
import time
1010
from dataclasses import dataclass, field
11-
from typing import TYPE_CHECKING
11+
from typing import TYPE_CHECKING, cast
1212

1313
from mpp import Challenge, Credential
1414
from mpp.methods.tempo._attribution import encode as encode_attribution
1515
from mpp.methods.tempo._defaults import (
1616
CHAIN_ID,
17-
CHAIN_RPC_URLS,
1817
RPC_URL,
1918
default_currency_for_chain,
2019
rpc_url_for_chain,
@@ -32,6 +31,7 @@
3231
DEFAULT_GAS_LIMIT = 1_000_000
3332
EXPIRING_NONCE_KEY = (1 << 256) - 1 # U256::MAX
3433
FEE_PAYER_VALID_BEFORE_SECS = 25
34+
_CHAIN_ID_UNSET = object()
3535

3636

3737
class TransactionError(Exception):
@@ -75,7 +75,9 @@ class TempoMethod:
7575
client_id: str | None = None
7676
_intents: dict[str, Intent] = field(default_factory=dict)
7777
_cached_chain_ids: dict[str, int] = field(default_factory=dict, init=False, repr=False)
78+
_chain_id_explicit: bool = field(default=False, init=False, repr=False)
7879
_chain_id_lock: asyncio.Lock | None = field(default=None, init=False, repr=False)
80+
_rpc_url_explicit: bool = field(default=False, init=False, repr=False)
7981

8082
@property
8183
def intents(self) -> dict[str, Intent]:
@@ -101,6 +103,21 @@ async def _get_chain_id(self, rpc_url: str) -> int:
101103
self._cached_chain_ids[rpc_url] = chain_id
102104
return chain_id
103105

106+
async def _resolve_expected_chain_id(self) -> int | None:
107+
"""Return the chain ID pinned by local client configuration.
108+
109+
A client may pin the chain explicitly via ``chain_id`` or implicitly by
110+
supplying a custom ``rpc_url``. In the latter case, trust the chain
111+
reported by that RPC instead of the server challenge.
112+
"""
113+
if self._rpc_url_explicit and not self._chain_id_explicit:
114+
return await self._get_chain_id(self.rpc_url)
115+
if self.chain_id is not None:
116+
return self.chain_id
117+
if self.rpc_url:
118+
return await self._get_chain_id(self.rpc_url)
119+
return None
120+
104121
async def create_credential(self, challenge: Challenge) -> Credential:
105122
"""Create a credential to satisfy the given challenge.
106123
@@ -149,10 +166,8 @@ async def create_credential(self, challenge: Challenge) -> Credential:
149166

150167
splits = method_details.get("splits") if isinstance(method_details, dict) else None
151168

152-
# Resolve RPC URL from challenge's chainId (like mppx), falling back
153-
# to the method-level rpc_url.
154169
rpc_url = self.rpc_url
155-
expected_chain_id: int | None = None
170+
expected_chain_id = await self._resolve_expected_chain_id()
156171
challenge_chain_id = (
157172
method_details.get("chainId") if isinstance(method_details, dict) else None
158173
)
@@ -162,18 +177,14 @@ async def create_credential(self, challenge: Challenge) -> Credential:
162177
except (TypeError, ValueError):
163178
pass
164179
else:
165-
resolved = CHAIN_RPC_URLS.get(parsed_chain_id)
166-
if resolved is not None:
167-
rpc_url = resolved
168-
# Only enforce mismatch check when we resolved to a known
169-
# RPC URL — for unknown chains we fall back to the user's
170-
# custom rpc_url and can't verify the chain ID.
180+
if expected_chain_id is not None and parsed_chain_id != expected_chain_id:
181+
raise ValueError(
182+
f"Challenge requests chain ID {parsed_chain_id}, "
183+
f"but client is restricted to {expected_chain_id}"
184+
)
185+
if expected_chain_id is None:
171186
expected_chain_id = parsed_chain_id
172187

173-
# Also check against the method-level chain_id if set.
174-
if expected_chain_id is None and self.chain_id is not None:
175-
expected_chain_id = self.chain_id
176-
177188
raw_tx, chain_id = await self._build_tempo_transfer(
178189
amount=request["amount"],
179190
currency=request["currency"],
@@ -280,7 +291,7 @@ async def _build_tempo_transfer(
280291
if expected_chain_id is not None and chain_id != expected_chain_id:
281292
raise TransactionError(
282293
f"Chain ID mismatch: RPC returned {chain_id}, "
283-
f"expected {expected_chain_id} from challenge"
294+
f"expected {expected_chain_id} from client policy"
284295
)
285296

286297
if awaiting_fee_payer:
@@ -371,7 +382,7 @@ def tempo(
371382
intents: dict[str, Intent],
372383
account: TempoAccount | None = None,
373384
fee_payer: TempoAccount | None = None,
374-
chain_id: int = CHAIN_ID,
385+
chain_id: int | None | object = _CHAIN_ID_UNSET,
375386
rpc_url: str | None = None,
376387
root_account: str | None = None,
377388
currency: str | None = None,
@@ -391,7 +402,9 @@ def tempo(
391402
chain_id: Tempo chain ID (default: 4217 for mainnet, use 42431
392403
for testnet). Resolves the RPC URL automatically from known chains.
393404
rpc_url: Tempo RPC endpoint URL. Overrides the URL resolved
394-
from ``chain_id``. Defaults to mainnet if neither is set.
405+
from ``chain_id``. When provided without ``chain_id``, the client
406+
pins itself to whatever chain that RPC reports. Defaults to mainnet
407+
if neither is set.
395408
root_account: Root account address for access key signing.
396409
currency: Default currency address for charges.
397410
recipient: Default recipient address for charges.
@@ -417,23 +430,35 @@ def tempo(
417430
intents={"charge": ChargeIntent()},
418431
)
419432
"""
433+
chain_id_explicit = chain_id is not _CHAIN_ID_UNSET
434+
resolved_chain_id: int | None
435+
if chain_id is _CHAIN_ID_UNSET:
436+
resolved_chain_id = CHAIN_ID
437+
else:
438+
resolved_chain_id = cast("int | None", chain_id)
439+
440+
rpc_url_explicit = rpc_url is not None
420441
if rpc_url is None:
421-
rpc_url = rpc_url_for_chain(chain_id)
442+
if resolved_chain_id is None:
443+
raise ValueError("chain_id or rpc_url is required")
444+
rpc_url = rpc_url_for_chain(resolved_chain_id)
422445

423446
if currency is None:
424-
currency = default_currency_for_chain(chain_id)
447+
currency = default_currency_for_chain(resolved_chain_id)
425448

426449
method = TempoMethod(
427450
account=account,
428451
fee_payer=fee_payer,
429452
rpc_url=rpc_url,
430-
chain_id=chain_id,
453+
chain_id=resolved_chain_id,
431454
root_account=root_account,
432455
currency=currency,
433456
recipient=recipient,
434457
decimals=decimals,
435458
client_id=client_id,
436459
)
460+
method._chain_id_explicit = chain_id_explicit
461+
method._rpc_url_explicit = rpc_url_explicit
437462
for intent in intents.values():
438463
if hasattr(intent, "rpc_url") and intent.rpc_url is None: # type: ignore[union-attr]
439464
intent.rpc_url = rpc_url # type: ignore[union-attr]

tests/test_client.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,37 @@ async def test_handles_multiple_www_authenticate_headers(self) -> None:
216216
assert len(inner.requests) == 2
217217
method.create_credential.assert_called_once()
218218

219+
@pytest.mark.asyncio
220+
async def test_does_not_retry_when_method_rejects_challenge(self) -> None:
221+
"""Should not send an Authorization retry when the method rejects the challenge."""
222+
challenge = Challenge(
223+
id="test-id",
224+
method="tempo",
225+
intent="charge",
226+
request={"amount": "1000", "methodDetails": {"chainId": 42431}},
227+
)
228+
www_auth = challenge.to_www_authenticate("example.com")
229+
230+
inner = MockTransport(
231+
[
232+
httpx.Response(402, headers={"www-authenticate": www_auth}),
233+
]
234+
)
235+
236+
method = MockMethod()
237+
method.create_credential.side_effect = ValueError(
238+
"Challenge requests chain ID 42431, but client is restricted to 4217"
239+
)
240+
transport = PaymentTransport(methods=[method], inner=inner)
241+
242+
request = httpx.Request("GET", "https://example.com")
243+
244+
with pytest.raises(ValueError, match="client is restricted to 4217"):
245+
await transport.handle_async_request(request)
246+
247+
assert len(inner.requests) == 1
248+
method.create_credential.assert_called_once()
249+
219250

220251
class TestClient:
221252
@pytest.mark.asyncio

0 commit comments

Comments
 (0)