|
| 1 | +""" |
| 2 | +Unit tests for RSA encryption in IntersectClient |
| 3 | +
|
| 4 | +Tests cover the client-side RSA encryption flow where a client: |
| 5 | +1. Fetches the service's public key before sending encrypted messages |
| 6 | +2. Encrypts outgoing messages using the service's public key |
| 7 | +3. Decrypts incoming messages using its own private key |
| 8 | +""" |
| 9 | +from unittest.mock import MagicMock, patch |
| 10 | +from uuid import uuid4 |
| 11 | + |
| 12 | +import pytest |
| 13 | +from pydantic import BaseModel |
| 14 | +from cryptography.hazmat.primitives.asymmetric import rsa |
| 15 | +from cryptography.hazmat.primitives import serialization |
| 16 | +from cryptography.hazmat.backends import default_backend |
| 17 | + |
| 18 | +from intersect_sdk._internal.encryption import ( |
| 19 | + intersect_payload_decrypt, |
| 20 | + intersect_payload_encrypt, |
| 21 | +) |
| 22 | +from intersect_sdk._internal.encryption.models import IntersectEncryptedPayload |
| 23 | +from intersect_sdk._internal.messages.userspace import create_userspace_message_headers |
| 24 | +from intersect_sdk._internal.encryption.models import IntersectEncryptionPublicKey |
| 25 | +from intersect_sdk.client import IntersectClient |
| 26 | +from intersect_sdk.core_definitions import IntersectDataHandler |
| 27 | +from intersect_sdk.shared_callback_definitions import IntersectDirectMessageParams |
| 28 | + |
| 29 | + |
| 30 | +class SampleMessage(BaseModel): |
| 31 | + """Simple test message for encryption/decryption testing""" |
| 32 | + result: str = "" |
| 33 | + |
| 34 | + |
| 35 | +@pytest.fixture |
| 36 | +def rsa_keypair(): |
| 37 | + """Generate RSA key pair for testing""" |
| 38 | + private_key = rsa.generate_private_key( |
| 39 | + public_exponent=65537, |
| 40 | + key_size=2048, |
| 41 | + backend=default_backend(), |
| 42 | + ) |
| 43 | + public_key = private_key.public_key() |
| 44 | + return private_key, public_key |
| 45 | + |
| 46 | + |
| 47 | +@pytest.fixture |
| 48 | +def public_key_pem(rsa_keypair): |
| 49 | + """Get PEM encoded public key""" |
| 50 | + _, public_key = rsa_keypair |
| 51 | + public_pem = public_key.public_bytes( |
| 52 | + encoding=serialization.Encoding.PEM, |
| 53 | + format=serialization.PublicFormat.SubjectPublicKeyInfo, |
| 54 | + ).decode() |
| 55 | + return public_pem |
| 56 | + |
| 57 | + |
| 58 | +@pytest.fixture |
| 59 | +def mock_client(rsa_keypair): |
| 60 | + """Create a mock IntersectClient for testing""" |
| 61 | + mock_cli = MagicMock(spec=IntersectClient) |
| 62 | + private_key, _ = rsa_keypair |
| 63 | + mock_cli._private_key = private_key |
| 64 | + mock_cli._hierarchy = MagicMock() |
| 65 | + mock_cli._hierarchy.hierarchy_string.return_value = "test.org-fac.sys-client" |
| 66 | + mock_cli._hierarchy.hierarchy_string.side_effect = lambda sep='.': "test.org-fac.sys-client".replace('.', sep) if sep != '.' else "test.org-fac.sys-client" |
| 67 | + mock_cli._terminate_after_initial_messages = False |
| 68 | + mock_cli._campaign_id = uuid4() |
| 69 | + mock_cli._user_callback = MagicMock() |
| 70 | + |
| 71 | + # Mock data plane manager |
| 72 | + mock_cli._data_plane_manager = MagicMock() |
| 73 | + |
| 74 | + # Mock control plane manager |
| 75 | + mock_cli._control_plane_manager = MagicMock() |
| 76 | + |
| 77 | + return mock_cli |
| 78 | + |
| 79 | + |
| 80 | +def test_send_message_encryption_scheme_none_skips_encryption(mock_client): |
| 81 | + """Test that encryption_scheme='NONE' skips encryption for outgoing messages""" |
| 82 | + params = IntersectDirectMessageParams( |
| 83 | + destination="test.org-fac.sys-service", |
| 84 | + operation="test_op", |
| 85 | + payload='{"test": "data"}', |
| 86 | + content_type="application/json", |
| 87 | + encryption_scheme="NONE", |
| 88 | + ) |
| 89 | + |
| 90 | + mock_client._data_plane_manager.outgoing_message_data_handler.return_value = ( |
| 91 | + b"serialized_data" |
| 92 | + ) |
| 93 | + |
| 94 | + # Call the send method |
| 95 | + IntersectClient._send_userspace_message(mock_client, params) |
| 96 | + |
| 97 | + # Verify publish_message was called |
| 98 | + assert mock_client._control_plane_manager.publish_message.called |
| 99 | + call_args = mock_client._control_plane_manager.publish_message.call_args |
| 100 | + # The payload should not be encrypted (not an IntersectEncryptedPayload instance) |
| 101 | + payload = call_args[0][1] |
| 102 | + assert payload == b"serialized_data" |
| 103 | + |
| 104 | +def test_send_message_rsa_fetches_public_key(mock_client, public_key_pem): |
| 105 | + """Test that RSA encryption fetches the public key before encrypting""" |
| 106 | + params = IntersectDirectMessageParams( |
| 107 | + destination="test.org-fac.sys-service", |
| 108 | + operation="test_op", |
| 109 | + payload='{"test": "data"}', |
| 110 | + content_type="application/json", |
| 111 | + encryption_scheme="RSA", |
| 112 | + ) |
| 113 | + |
| 114 | + # Mock the public key fetch to return the service's public key |
| 115 | + mock_client._fetch_service_public_key = MagicMock(return_value=public_key_pem) |
| 116 | + mock_client._data_plane_manager.outgoing_message_data_handler.return_value = ( |
| 117 | + b"encrypted_payload" |
| 118 | + ) |
| 119 | + |
| 120 | + # Call the send method |
| 121 | + IntersectClient._send_userspace_message(mock_client, params) |
| 122 | + |
| 123 | + # Verify that public key was fetched |
| 124 | + mock_client._fetch_service_public_key.assert_called_once_with("test.org-fac.sys-service") |
| 125 | + |
| 126 | +def test_send_message_rsa_encryption_payload_structure(mock_client, public_key_pem): |
| 127 | + """Test that encrypted payload has the expected structure""" |
| 128 | + params = IntersectDirectMessageParams( |
| 129 | + destination="test.org-fac.sys-service", |
| 130 | + operation="test_op", |
| 131 | + payload='{"message": "hello", "value": 42}', |
| 132 | + content_type="application/json", |
| 133 | + encryption_scheme="RSA", |
| 134 | + ) |
| 135 | + |
| 136 | + mock_client._fetch_service_public_key = MagicMock(return_value=public_key_pem) |
| 137 | + |
| 138 | + captured_data = None |
| 139 | + def capture_handler(data, content_type, handler): |
| 140 | + nonlocal captured_data |
| 141 | + captured_data = data |
| 142 | + return b"encrypted_to_broker" |
| 143 | + |
| 144 | + mock_client._data_plane_manager.outgoing_message_data_handler = capture_handler |
| 145 | + |
| 146 | + # Call the send method |
| 147 | + IntersectClient._send_userspace_message(mock_client, params) |
| 148 | + |
| 149 | + # Verify the payload structure |
| 150 | + assert isinstance(captured_data, IntersectEncryptedPayload) |
| 151 | + assert captured_data.key |
| 152 | + assert captured_data.initial_vector |
| 153 | + assert captured_data.data |
| 154 | + |
| 155 | +def test_send_message_rsa_encryption_fails_without_public_key(mock_client): |
| 156 | + """Test that sending fails if public key cannot be fetched""" |
| 157 | + params = IntersectDirectMessageParams( |
| 158 | + destination="test.org-fac.sys-service", |
| 159 | + operation="test_op", |
| 160 | + payload='{"test": "data"}', |
| 161 | + content_type="application/json", |
| 162 | + encryption_scheme="RSA", |
| 163 | + ) |
| 164 | + |
| 165 | + # Mock the public key fetch to return None (failure) |
| 166 | + mock_client._fetch_service_public_key = MagicMock(return_value=None) |
| 167 | + |
| 168 | + # Call the send method |
| 169 | + IntersectClient._send_userspace_message(mock_client, params) |
| 170 | + |
| 171 | + # Verify that outgoing_message_data_handler was NOT called (message sending failed) |
| 172 | + assert not mock_client._data_plane_manager.outgoing_message_data_handler.called |
| 173 | + |
| 174 | +def test_receive_message_encryption_scheme_none_skips_decryption(mock_client): |
| 175 | + """Test that encryption_scheme='NONE' skips decryption for incoming messages""" |
| 176 | + # Create message headers for an unencrypted message |
| 177 | + headers = create_userspace_message_headers( |
| 178 | + source="test.org-fac.sys-service", |
| 179 | + destination=mock_client._hierarchy.hierarchy_string('.'), |
| 180 | + data_handler=IntersectDataHandler.MESSAGE, |
| 181 | + operation_id="test_op", |
| 182 | + campaign_id=uuid4(), |
| 183 | + request_id=uuid4(), |
| 184 | + encryption_scheme="NONE", |
| 185 | + ) |
| 186 | + |
| 187 | + payload = b'{"test": "data"}' |
| 188 | + mock_client._data_plane_manager.incoming_message_data_handler.return_value = payload |
| 189 | + mock_client._user_callback = MagicMock() |
| 190 | + |
| 191 | + # Call the handle method |
| 192 | + IntersectClient._handle_userspace_message( |
| 193 | + mock_client, payload, "application/json", headers |
| 194 | + ) |
| 195 | + |
| 196 | + # Verify that incoming_message_data_handler was called |
| 197 | + assert mock_client._data_plane_manager.incoming_message_data_handler.called |
| 198 | + |
| 199 | +def test_receive_message_rsa_decryption_uses_private_key(mock_client, rsa_keypair): |
| 200 | + """Test that RSA decryption uses the client's private key for incoming messages""" |
| 201 | + private_key, _ = rsa_keypair |
| 202 | + mock_client._private_key = private_key |
| 203 | + |
| 204 | + # Create a simple message to encrypt |
| 205 | + # The client expects a wrapper model with 'data' field containing the JSON string |
| 206 | + test_message_json = '{"data": "{\\"result\\": \\"data from service\\"}"}' |
| 207 | + key_payload = IntersectEncryptionPublicKey( |
| 208 | + public_key=private_key.public_key().public_bytes( |
| 209 | + encoding=serialization.Encoding.PEM, |
| 210 | + format=serialization.PublicFormat.SubjectPublicKeyInfo, |
| 211 | + ).decode() |
| 212 | + ) |
| 213 | + |
| 214 | + encrypted_payload = intersect_payload_encrypt( |
| 215 | + key_payload=key_payload, |
| 216 | + unencrypted_model=test_message_json, |
| 217 | + ) |
| 218 | + |
| 219 | + # Create message headers for an encrypted message |
| 220 | + headers = create_userspace_message_headers( |
| 221 | + source="test.org-fac.sys-service", |
| 222 | + destination=mock_client._hierarchy.hierarchy_string('.'), |
| 223 | + data_handler=IntersectDataHandler.MESSAGE, |
| 224 | + operation_id="test_op", |
| 225 | + campaign_id=uuid4(), |
| 226 | + request_id=uuid4(), |
| 227 | + encryption_scheme="RSA", |
| 228 | + ) |
| 229 | + |
| 230 | + # Mock the incoming handler to return the encrypted payload |
| 231 | + mock_client._data_plane_manager.incoming_message_data_handler.return_value = ( |
| 232 | + encrypted_payload |
| 233 | + ) |
| 234 | + mock_client._user_callback = MagicMock() |
| 235 | + |
| 236 | + # Call the handle method |
| 237 | + IntersectClient._handle_userspace_message( |
| 238 | + mock_client, b"ignored_payload", "application/json", headers |
| 239 | + ) |
| 240 | + |
| 241 | + # Verify that incoming_message_data_handler was called |
| 242 | + assert mock_client._data_plane_manager.incoming_message_data_handler.called |
| 243 | + |
| 244 | +def test_receive_public_key_response_stores_in_cache(mock_client): |
| 245 | + """Test that public key responses are cached for use in key fetching""" |
| 246 | + request_id = uuid4() |
| 247 | + public_key_data = {"public_key": "-----BEGIN PUBLIC KEY-----\ntest\n-----END PUBLIC KEY-----\n"} |
| 248 | + |
| 249 | + # Create message headers for a public key response |
| 250 | + headers = create_userspace_message_headers( |
| 251 | + source="test.org-fac.sys-service", |
| 252 | + destination=mock_client._hierarchy.hierarchy_string('.'), |
| 253 | + data_handler=IntersectDataHandler.MESSAGE, |
| 254 | + operation_id="intersect_sdk.get_public_key", |
| 255 | + campaign_id=uuid4(), |
| 256 | + request_id=request_id, |
| 257 | + encryption_scheme="NONE", |
| 258 | + ) |
| 259 | + |
| 260 | + payload = b'{"public_key": "test_key"}' |
| 261 | + mock_client._data_plane_manager.incoming_message_data_handler.return_value = public_key_data |
| 262 | + |
| 263 | + # Call the handle method |
| 264 | + IntersectClient._handle_userspace_message( |
| 265 | + mock_client, payload, "application/json", headers |
| 266 | + ) |
| 267 | + |
| 268 | + # Verify that the response was cached using message_id as key |
| 269 | + assert hasattr(mock_client, '_public_key_responses') |
| 270 | + message_id = headers['message_id'] # Extract from the headers dict |
| 271 | + assert message_id in mock_client._public_key_responses |
| 272 | + assert mock_client._public_key_responses[message_id] == public_key_data |
| 273 | + |
| 274 | +def test_send_message_rsa_encryption_roundtrip(mock_client, rsa_keypair, public_key_pem): |
| 275 | + """Test that encrypted message can be decrypted""" |
| 276 | + private_key, _ = rsa_keypair |
| 277 | + class LocalMessage(BaseModel): |
| 278 | + text: str |
| 279 | + |
| 280 | + original_message = LocalMessage(text="Test message from client") |
| 281 | + params = IntersectDirectMessageParams( |
| 282 | + destination="test.org-fac.sys-service", |
| 283 | + operation="test_op", |
| 284 | + payload=original_message, |
| 285 | + content_type="application/json", |
| 286 | + encryption_scheme="RSA", |
| 287 | + ) |
| 288 | + |
| 289 | + mock_client._fetch_service_public_key = MagicMock(return_value=public_key_pem) |
| 290 | + |
| 291 | + captured_data = None |
| 292 | + def capture_handler(data, content_type, handler): |
| 293 | + nonlocal captured_data |
| 294 | + captured_data = data |
| 295 | + return b"encrypted_to_broker" |
| 296 | + |
| 297 | + mock_client._data_plane_manager.outgoing_message_data_handler = capture_handler |
| 298 | + |
| 299 | + # Send the message (which encrypts it) |
| 300 | + IntersectClient._send_userspace_message(mock_client, params) |
| 301 | + |
| 302 | + # Decrypt it to verify correctness |
| 303 | + decrypted = intersect_payload_decrypt( |
| 304 | + rsa_private_key=private_key, |
| 305 | + encrypted_payload=captured_data, |
| 306 | + model=LocalMessage, |
| 307 | + ) |
| 308 | + |
| 309 | + # The decrypted payload should match the original |
| 310 | + assert decrypted.model.text == original_message.text |
0 commit comments