Skip to content

Commit 9ff8f5d

Browse files
Adds tests for client encryption code
1 parent b4c4b43 commit 9ff8f5d

1 file changed

Lines changed: 310 additions & 0 deletions

File tree

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
"""
2+
Unit tests for RSA encryption in IntersectClient
3+
4+
Tests cover the client-side RSA encryption flow where a client:
5+
1. Fetches the service's public key before sending encrypted messages
6+
2. Encrypts outgoing messages using the service's public key
7+
3. Decrypts incoming messages using its own private key
8+
"""
9+
from unittest.mock import MagicMock, patch
10+
from uuid import uuid4
11+
12+
import pytest
13+
from pydantic import BaseModel
14+
from cryptography.hazmat.primitives.asymmetric import rsa
15+
from cryptography.hazmat.primitives import serialization
16+
from cryptography.hazmat.backends import default_backend
17+
18+
from intersect_sdk._internal.encryption import (
19+
intersect_payload_decrypt,
20+
intersect_payload_encrypt,
21+
)
22+
from intersect_sdk._internal.encryption.models import IntersectEncryptedPayload
23+
from intersect_sdk._internal.messages.userspace import create_userspace_message_headers
24+
from intersect_sdk._internal.encryption.models import IntersectEncryptionPublicKey
25+
from intersect_sdk.client import IntersectClient
26+
from intersect_sdk.core_definitions import IntersectDataHandler
27+
from intersect_sdk.shared_callback_definitions import IntersectDirectMessageParams
28+
29+
30+
class SampleMessage(BaseModel):
31+
"""Simple test message for encryption/decryption testing"""
32+
result: str = ""
33+
34+
35+
@pytest.fixture
36+
def rsa_keypair():
37+
"""Generate RSA key pair for testing"""
38+
private_key = rsa.generate_private_key(
39+
public_exponent=65537,
40+
key_size=2048,
41+
backend=default_backend(),
42+
)
43+
public_key = private_key.public_key()
44+
return private_key, public_key
45+
46+
47+
@pytest.fixture
48+
def public_key_pem(rsa_keypair):
49+
"""Get PEM encoded public key"""
50+
_, public_key = rsa_keypair
51+
public_pem = public_key.public_bytes(
52+
encoding=serialization.Encoding.PEM,
53+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
54+
).decode()
55+
return public_pem
56+
57+
58+
@pytest.fixture
59+
def mock_client(rsa_keypair):
60+
"""Create a mock IntersectClient for testing"""
61+
mock_cli = MagicMock(spec=IntersectClient)
62+
private_key, _ = rsa_keypair
63+
mock_cli._private_key = private_key
64+
mock_cli._hierarchy = MagicMock()
65+
mock_cli._hierarchy.hierarchy_string.return_value = "test.org-fac.sys-client"
66+
mock_cli._hierarchy.hierarchy_string.side_effect = lambda sep='.': "test.org-fac.sys-client".replace('.', sep) if sep != '.' else "test.org-fac.sys-client"
67+
mock_cli._terminate_after_initial_messages = False
68+
mock_cli._campaign_id = uuid4()
69+
mock_cli._user_callback = MagicMock()
70+
71+
# Mock data plane manager
72+
mock_cli._data_plane_manager = MagicMock()
73+
74+
# Mock control plane manager
75+
mock_cli._control_plane_manager = MagicMock()
76+
77+
return mock_cli
78+
79+
80+
def test_send_message_encryption_scheme_none_skips_encryption(mock_client):
81+
"""Test that encryption_scheme='NONE' skips encryption for outgoing messages"""
82+
params = IntersectDirectMessageParams(
83+
destination="test.org-fac.sys-service",
84+
operation="test_op",
85+
payload='{"test": "data"}',
86+
content_type="application/json",
87+
encryption_scheme="NONE",
88+
)
89+
90+
mock_client._data_plane_manager.outgoing_message_data_handler.return_value = (
91+
b"serialized_data"
92+
)
93+
94+
# Call the send method
95+
IntersectClient._send_userspace_message(mock_client, params)
96+
97+
# Verify publish_message was called
98+
assert mock_client._control_plane_manager.publish_message.called
99+
call_args = mock_client._control_plane_manager.publish_message.call_args
100+
# The payload should not be encrypted (not an IntersectEncryptedPayload instance)
101+
payload = call_args[0][1]
102+
assert payload == b"serialized_data"
103+
104+
def test_send_message_rsa_fetches_public_key(mock_client, public_key_pem):
105+
"""Test that RSA encryption fetches the public key before encrypting"""
106+
params = IntersectDirectMessageParams(
107+
destination="test.org-fac.sys-service",
108+
operation="test_op",
109+
payload='{"test": "data"}',
110+
content_type="application/json",
111+
encryption_scheme="RSA",
112+
)
113+
114+
# Mock the public key fetch to return the service's public key
115+
mock_client._fetch_service_public_key = MagicMock(return_value=public_key_pem)
116+
mock_client._data_plane_manager.outgoing_message_data_handler.return_value = (
117+
b"encrypted_payload"
118+
)
119+
120+
# Call the send method
121+
IntersectClient._send_userspace_message(mock_client, params)
122+
123+
# Verify that public key was fetched
124+
mock_client._fetch_service_public_key.assert_called_once_with("test.org-fac.sys-service")
125+
126+
def test_send_message_rsa_encryption_payload_structure(mock_client, public_key_pem):
127+
"""Test that encrypted payload has the expected structure"""
128+
params = IntersectDirectMessageParams(
129+
destination="test.org-fac.sys-service",
130+
operation="test_op",
131+
payload='{"message": "hello", "value": 42}',
132+
content_type="application/json",
133+
encryption_scheme="RSA",
134+
)
135+
136+
mock_client._fetch_service_public_key = MagicMock(return_value=public_key_pem)
137+
138+
captured_data = None
139+
def capture_handler(data, content_type, handler):
140+
nonlocal captured_data
141+
captured_data = data
142+
return b"encrypted_to_broker"
143+
144+
mock_client._data_plane_manager.outgoing_message_data_handler = capture_handler
145+
146+
# Call the send method
147+
IntersectClient._send_userspace_message(mock_client, params)
148+
149+
# Verify the payload structure
150+
assert isinstance(captured_data, IntersectEncryptedPayload)
151+
assert captured_data.key
152+
assert captured_data.initial_vector
153+
assert captured_data.data
154+
155+
def test_send_message_rsa_encryption_fails_without_public_key(mock_client):
156+
"""Test that sending fails if public key cannot be fetched"""
157+
params = IntersectDirectMessageParams(
158+
destination="test.org-fac.sys-service",
159+
operation="test_op",
160+
payload='{"test": "data"}',
161+
content_type="application/json",
162+
encryption_scheme="RSA",
163+
)
164+
165+
# Mock the public key fetch to return None (failure)
166+
mock_client._fetch_service_public_key = MagicMock(return_value=None)
167+
168+
# Call the send method
169+
IntersectClient._send_userspace_message(mock_client, params)
170+
171+
# Verify that outgoing_message_data_handler was NOT called (message sending failed)
172+
assert not mock_client._data_plane_manager.outgoing_message_data_handler.called
173+
174+
def test_receive_message_encryption_scheme_none_skips_decryption(mock_client):
175+
"""Test that encryption_scheme='NONE' skips decryption for incoming messages"""
176+
# Create message headers for an unencrypted message
177+
headers = create_userspace_message_headers(
178+
source="test.org-fac.sys-service",
179+
destination=mock_client._hierarchy.hierarchy_string('.'),
180+
data_handler=IntersectDataHandler.MESSAGE,
181+
operation_id="test_op",
182+
campaign_id=uuid4(),
183+
request_id=uuid4(),
184+
encryption_scheme="NONE",
185+
)
186+
187+
payload = b'{"test": "data"}'
188+
mock_client._data_plane_manager.incoming_message_data_handler.return_value = payload
189+
mock_client._user_callback = MagicMock()
190+
191+
# Call the handle method
192+
IntersectClient._handle_userspace_message(
193+
mock_client, payload, "application/json", headers
194+
)
195+
196+
# Verify that incoming_message_data_handler was called
197+
assert mock_client._data_plane_manager.incoming_message_data_handler.called
198+
199+
def test_receive_message_rsa_decryption_uses_private_key(mock_client, rsa_keypair):
200+
"""Test that RSA decryption uses the client's private key for incoming messages"""
201+
private_key, _ = rsa_keypair
202+
mock_client._private_key = private_key
203+
204+
# Create a simple message to encrypt
205+
# The client expects a wrapper model with 'data' field containing the JSON string
206+
test_message_json = '{"data": "{\\"result\\": \\"data from service\\"}"}'
207+
key_payload = IntersectEncryptionPublicKey(
208+
public_key=private_key.public_key().public_bytes(
209+
encoding=serialization.Encoding.PEM,
210+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
211+
).decode()
212+
)
213+
214+
encrypted_payload = intersect_payload_encrypt(
215+
key_payload=key_payload,
216+
unencrypted_model=test_message_json,
217+
)
218+
219+
# Create message headers for an encrypted message
220+
headers = create_userspace_message_headers(
221+
source="test.org-fac.sys-service",
222+
destination=mock_client._hierarchy.hierarchy_string('.'),
223+
data_handler=IntersectDataHandler.MESSAGE,
224+
operation_id="test_op",
225+
campaign_id=uuid4(),
226+
request_id=uuid4(),
227+
encryption_scheme="RSA",
228+
)
229+
230+
# Mock the incoming handler to return the encrypted payload
231+
mock_client._data_plane_manager.incoming_message_data_handler.return_value = (
232+
encrypted_payload
233+
)
234+
mock_client._user_callback = MagicMock()
235+
236+
# Call the handle method
237+
IntersectClient._handle_userspace_message(
238+
mock_client, b"ignored_payload", "application/json", headers
239+
)
240+
241+
# Verify that incoming_message_data_handler was called
242+
assert mock_client._data_plane_manager.incoming_message_data_handler.called
243+
244+
def test_receive_public_key_response_stores_in_cache(mock_client):
245+
"""Test that public key responses are cached for use in key fetching"""
246+
request_id = uuid4()
247+
public_key_data = {"public_key": "-----BEGIN PUBLIC KEY-----\ntest\n-----END PUBLIC KEY-----\n"}
248+
249+
# Create message headers for a public key response
250+
headers = create_userspace_message_headers(
251+
source="test.org-fac.sys-service",
252+
destination=mock_client._hierarchy.hierarchy_string('.'),
253+
data_handler=IntersectDataHandler.MESSAGE,
254+
operation_id="intersect_sdk.get_public_key",
255+
campaign_id=uuid4(),
256+
request_id=request_id,
257+
encryption_scheme="NONE",
258+
)
259+
260+
payload = b'{"public_key": "test_key"}'
261+
mock_client._data_plane_manager.incoming_message_data_handler.return_value = public_key_data
262+
263+
# Call the handle method
264+
IntersectClient._handle_userspace_message(
265+
mock_client, payload, "application/json", headers
266+
)
267+
268+
# Verify that the response was cached using message_id as key
269+
assert hasattr(mock_client, '_public_key_responses')
270+
message_id = headers['message_id'] # Extract from the headers dict
271+
assert message_id in mock_client._public_key_responses
272+
assert mock_client._public_key_responses[message_id] == public_key_data
273+
274+
def test_send_message_rsa_encryption_roundtrip(mock_client, rsa_keypair, public_key_pem):
275+
"""Test that encrypted message can be decrypted"""
276+
private_key, _ = rsa_keypair
277+
class LocalMessage(BaseModel):
278+
text: str
279+
280+
original_message = LocalMessage(text="Test message from client")
281+
params = IntersectDirectMessageParams(
282+
destination="test.org-fac.sys-service",
283+
operation="test_op",
284+
payload=original_message,
285+
content_type="application/json",
286+
encryption_scheme="RSA",
287+
)
288+
289+
mock_client._fetch_service_public_key = MagicMock(return_value=public_key_pem)
290+
291+
captured_data = None
292+
def capture_handler(data, content_type, handler):
293+
nonlocal captured_data
294+
captured_data = data
295+
return b"encrypted_to_broker"
296+
297+
mock_client._data_plane_manager.outgoing_message_data_handler = capture_handler
298+
299+
# Send the message (which encrypts it)
300+
IntersectClient._send_userspace_message(mock_client, params)
301+
302+
# Decrypt it to verify correctness
303+
decrypted = intersect_payload_decrypt(
304+
rsa_private_key=private_key,
305+
encrypted_payload=captured_data,
306+
model=LocalMessage,
307+
)
308+
309+
# The decrypted payload should match the original
310+
assert decrypted.model.text == original_message.text

0 commit comments

Comments
 (0)