Skip to content

Commit e4b02f1

Browse files
committed
Fix incorrect message serialization in transaction signing
Added jup_python_sdk/tests/ directory with tests for signing, serialization roundtrip, and key loading.
1 parent 568652a commit e4b02f1

File tree

3 files changed

+163
-1
lines changed

3 files changed

+163
-1
lines changed

jup_python_sdk/clients/jupiter_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import base58
77
import httpx
8+
from solders.message import to_bytes_versioned
89
from solders.solders import Keypair, VersionedTransaction
910

1011

@@ -93,7 +94,7 @@ def _sign_versioned_transaction(
9394

9495
signers = list(versioned_transaction.signatures)
9596

96-
message_bytes = bytes(versioned_transaction.message)
97+
message_bytes = to_bytes_versioned(versioned_transaction.message)
9798
your_signature = wallet.sign_message(message_bytes)
9899
signers[wallet_index] = your_signature
99100

jup_python_sdk/tests/__init__.py

Whitespace-only changes.
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import base64
2+
import os
3+
from unittest.mock import patch
4+
5+
import base58
6+
import pytest
7+
from solders.hash import Hash
8+
from solders.instruction import AccountMeta, Instruction
9+
from solders.keypair import Keypair
10+
from solders.message import MessageV0, to_bytes_versioned
11+
from solders.signature import Signature
12+
from solders.system_program import TransferParams, transfer
13+
from solders.transaction import VersionedTransaction
14+
15+
from jup_python_sdk.clients.jupiter_client import JupiterClient
16+
17+
18+
@pytest.fixture
19+
def wallet():
20+
return Keypair()
21+
22+
23+
@pytest.fixture
24+
def client(wallet):
25+
env_var = "TEST_PRIVATE_KEY"
26+
os.environ[env_var] = base58.b58encode(bytes(wallet)).decode()
27+
c = JupiterClient(api_key=None, private_key_env_var=env_var, timeout=10)
28+
yield c
29+
c.close()
30+
del os.environ[env_var]
31+
32+
33+
def _build_versioned_tx(payer: Keypair, num_signers: int = 1):
34+
"""Build a VersionedTransaction with the given number of signer slots."""
35+
extra_signers = [Keypair() for _ in range(num_signers - 1)]
36+
all_signers = [payer] + extra_signers
37+
38+
# Build an instruction that requires all signers
39+
program_id = Keypair().pubkey()
40+
accounts = [
41+
AccountMeta(pubkey=s.pubkey(), is_signer=True, is_writable=True)
42+
for s in all_signers
43+
]
44+
ix = Instruction(program_id, bytes(), accounts)
45+
46+
blockhash = Hash.default()
47+
msg = MessageV0.try_compile(
48+
payer=payer.pubkey(),
49+
instructions=[ix],
50+
address_lookup_table_accounts=[],
51+
recent_blockhash=blockhash,
52+
)
53+
54+
# Create a fully signed tx then zero out signatures to simulate
55+
# an unsigned transaction from the API
56+
signed_tx = VersionedTransaction(msg, all_signers)
57+
signatures = list(signed_tx.signatures)
58+
zeroed = [Signature.default()] * len(signatures)
59+
return VersionedTransaction.populate(msg, zeroed)
60+
61+
62+
class TestSignVersionedTransaction:
63+
def test_single_signer(self, client, wallet):
64+
"""Signing works for a transaction with one signature slot."""
65+
tx = _build_versioned_tx(wallet, num_signers=1)
66+
67+
signed = client._sign_versioned_transaction(tx)
68+
69+
assert len(signed.signatures) == 1
70+
assert signed.signatures[0] != Signature.default()
71+
72+
def test_multi_signer_no_type_error(self, client, wallet):
73+
"""Signing works for transactions with multiple signature slots (issue #3)."""
74+
tx = _build_versioned_tx(wallet, num_signers=3)
75+
assert len(tx.signatures) == 3
76+
77+
signed = client._sign_versioned_transaction(tx)
78+
79+
assert len(signed.signatures) == 3
80+
# Wallet signature should be filled in
81+
wallet_index = list(tx.message.account_keys).index(wallet.pubkey())
82+
assert signed.signatures[wallet_index] != Signature.default()
83+
# Other slots remain default (unsigned)
84+
for i, sig in enumerate(signed.signatures):
85+
if i != wallet_index:
86+
assert sig == Signature.default()
87+
88+
def test_signature_is_valid(self, client, wallet):
89+
"""The produced signature actually verifies against the message."""
90+
tx = _build_versioned_tx(wallet, num_signers=1)
91+
92+
signed = client._sign_versioned_transaction(tx)
93+
94+
msg_bytes = to_bytes_versioned(signed.message)
95+
sig = signed.signatures[0]
96+
# Verify by re-signing and comparing
97+
expected_sig = wallet.sign_message(msg_bytes)
98+
assert sig == expected_sig
99+
100+
def test_uses_to_bytes_versioned(self, client, wallet):
101+
"""Signing uses to_bytes_versioned, not bytes(), for correct message serialization."""
102+
tx = _build_versioned_tx(wallet, num_signers=1)
103+
104+
with patch(
105+
"jup_python_sdk.clients.jupiter_client.to_bytes_versioned",
106+
wraps=to_bytes_versioned,
107+
) as mock_fn:
108+
client._sign_versioned_transaction(tx)
109+
mock_fn.assert_called_once()
110+
111+
def test_roundtrip_serialize_deserialize(self, client, wallet):
112+
"""A signed transaction can be serialized to base64 and back."""
113+
tx = _build_versioned_tx(wallet, num_signers=2)
114+
115+
signed = client._sign_versioned_transaction(tx)
116+
b64 = client._serialize_versioned_transaction(signed)
117+
118+
recovered = VersionedTransaction.from_bytes(base64.b64decode(b64))
119+
assert recovered.signatures == signed.signatures
120+
121+
122+
class TestSignBase64Transaction:
123+
def test_sign_base64_transaction(self, client, wallet):
124+
"""_sign_base64_transaction decodes, signs, and returns a VersionedTransaction."""
125+
tx = _build_versioned_tx(wallet, num_signers=1)
126+
b64 = base64.b64encode(bytes(tx)).decode()
127+
128+
signed = client._sign_base64_transaction(b64)
129+
130+
assert isinstance(signed, VersionedTransaction)
131+
assert signed.signatures[0] != Signature.default()
132+
133+
134+
class TestLoadPrivateKey:
135+
def test_base58_key(self, wallet):
136+
env_var = "TEST_PK_B58"
137+
os.environ[env_var] = base58.b58encode(bytes(wallet)).decode()
138+
c = JupiterClient(api_key=None, private_key_env_var=env_var, timeout=10)
139+
assert c._load_private_key_bytes() == bytes(wallet)
140+
del os.environ[env_var]
141+
142+
def test_uint8_array_key(self, wallet):
143+
env_var = "TEST_PK_ARR"
144+
arr = list(bytes(wallet))
145+
os.environ[env_var] = str(arr)
146+
c = JupiterClient(api_key=None, private_key_env_var=env_var, timeout=10)
147+
assert c._load_private_key_bytes() == bytes(wallet)
148+
del os.environ[env_var]
149+
150+
def test_invalid_key_raises(self):
151+
env_var = "TEST_PK_BAD"
152+
os.environ[env_var] = "not-a-valid-key!!!"
153+
c = JupiterClient(api_key=None, private_key_env_var=env_var, timeout=10)
154+
with pytest.raises(ValueError):
155+
c._load_private_key_bytes()
156+
del os.environ[env_var]
157+
158+
159+
class TestGetPublicKey:
160+
def test_returns_correct_pubkey(self, client, wallet):
161+
assert client._get_public_key() == str(wallet.pubkey())

0 commit comments

Comments
 (0)