|
12 | 12 |
|
13 | 13 | from __future__ import annotations |
14 | 14 |
|
| 15 | +from cryptography.hazmat.primitives.asymmetric import rsa |
| 16 | +from cryptography.hazmat.primitives import serialization |
15 | 17 | import time |
16 | 18 | from typing import TYPE_CHECKING |
17 | 19 | from uuid import uuid4 |
18 | 20 |
|
19 | 21 | from pydantic import ValidationError |
20 | 22 | from typing_extensions import Self, final |
21 | 23 |
|
22 | | -from intersect_sdk._internal.generic_serializer import GENERIC_MESSAGE_SERIALIZER |
| 24 | +from ._internal.encryption.models.public_key import IntersectEncryptionPublicKey |
| 25 | +from ._internal.encryption import ( |
| 26 | + intersect_payload_decrypt, |
| 27 | + intersect_payload_encrypt, |
| 28 | +) |
| 29 | +from ._internal.generic_serializer import GENERIC_MESSAGE_SERIALIZER |
23 | 30 |
|
24 | 31 | from ._internal.control_plane.control_plane_manager import ( |
25 | 32 | ControlPlaneManager, |
@@ -149,6 +156,24 @@ def __init__( |
149 | 156 |
|
150 | 157 | self._campaign_id = uuid4() |
151 | 158 |
|
| 159 | + # Generate a key pair for encryption |
| 160 | + self._private_key: rsa.RSAPrivateKey = rsa.generate_private_key( |
| 161 | + public_exponent=65537, key_size=2048 |
| 162 | + ) |
| 163 | + self._public_key = self._private_key.public_key() |
| 164 | + |
| 165 | + # Get the PEM encoded public key |
| 166 | + self._public_key_pem = self._public_key.public_bytes( |
| 167 | + encoding=serialization.Encoding.PEM, |
| 168 | + format=serialization.PublicFormat.SubjectPublicKeyInfo, |
| 169 | + ).decode() |
| 170 | + |
| 171 | + self._private_key.private_bytes( |
| 172 | + encoding=serialization.Encoding.PEM, |
| 173 | + format=serialization.PrivateFormat.TraditionalOpenSSL, |
| 174 | + encryption_algorithm=serialization.NoEncryption(), |
| 175 | + ) |
| 176 | + |
152 | 177 | @final |
153 | 178 | def startup(self) -> Self: |
154 | 179 | """This function connects the client to all INTERSECT systems. |
@@ -225,6 +250,69 @@ def considered_unrecoverable(self) -> bool: |
225 | 250 | """ |
226 | 251 | return self._control_plane_manager.considered_unrecoverable() |
227 | 252 |
|
| 253 | + def _fetch_service_public_key(self, destination: str) -> str | None: |
| 254 | + """Fetch the public key from a destination service via intersect_sdk.get_public_key. |
| 255 | + |
| 256 | + Args: |
| 257 | + destination: The hierarchy string of the destination service |
| 258 | + |
| 259 | + Returns: |
| 260 | + The PEM-encoded public key from the service, or None if fetching failed |
| 261 | + """ |
| 262 | + # Create a temporary request ID for tracking the public key request |
| 263 | + public_key_request_id = uuid4() |
| 264 | + |
| 265 | + # Build and send the public key request |
| 266 | + headers = create_userspace_message_headers( |
| 267 | + source=self._hierarchy.hierarchy_string('.'), |
| 268 | + destination=destination, |
| 269 | + operation_id='intersect_sdk.get_public_key', |
| 270 | + campaign_id=uuid4(), |
| 271 | + request_id=public_key_request_id, |
| 272 | + encryption_scheme='NONE', |
| 273 | + ) |
| 274 | + |
| 275 | + # Send the request to get the public key |
| 276 | + request_channel = f'{destination.replace(".", "/")}/request' |
| 277 | + self._control_plane_manager.publish_message( |
| 278 | + request_channel, b'null', 'application/json', headers, persist=False |
| 279 | + ) |
| 280 | + |
| 281 | + # Poll for the response with timeout |
| 282 | + start_time = time.time() |
| 283 | + timeout = 30.0 |
| 284 | + poll_interval = 0.1 |
| 285 | + response_data = None |
| 286 | + |
| 287 | + while time.time() - start_time < timeout: |
| 288 | + # Check if we have received a response for this request |
| 289 | + # We'll store responses temporarily in a dict keyed by request_id |
| 290 | + if hasattr(self, '_public_key_responses') and str(public_key_request_id) in self._public_key_responses: |
| 291 | + response_data = self._public_key_responses.pop(str(public_key_request_id)) |
| 292 | + break |
| 293 | + time.sleep(poll_interval) |
| 294 | + |
| 295 | + if response_data is None: |
| 296 | + logger.error(f'Failed to retrieve public key from {destination} within timeout') |
| 297 | + return None |
| 298 | + |
| 299 | + # Extract the public key from the response |
| 300 | + try: |
| 301 | + if isinstance(response_data, dict): |
| 302 | + public_key = response_data.get('public_key') |
| 303 | + else: |
| 304 | + # If it's already parsed as IntersectEncryptionPublicKey |
| 305 | + public_key = response_data.public_key if hasattr(response_data, 'public_key') else None |
| 306 | + |
| 307 | + if public_key is None: |
| 308 | + logger.error(f'No public_key field in response from {destination}') |
| 309 | + return None |
| 310 | + |
| 311 | + return public_key |
| 312 | + except (KeyError, AttributeError) as e: |
| 313 | + logger.error(f'Failed to extract public key from response: {e}') |
| 314 | + return None |
| 315 | + |
228 | 316 | def _handle_userspace_message( |
229 | 317 | self, payload: bytes, content_type: str, raw_headers: dict[str, str] |
230 | 318 | ) -> None: |
@@ -265,11 +353,30 @@ def _handle_userspace_message( |
265 | 353 | request_params = self._data_plane_manager.incoming_message_data_handler( |
266 | 354 | payload, headers.data_handler |
267 | 355 | ) |
| 356 | + |
| 357 | + # Check if this is a response to a public key request (before processing as user message) |
| 358 | + if headers.operation_id == 'intersect_sdk.get_public_key': |
| 359 | + if not hasattr(self, '_public_key_responses'): |
| 360 | + self._public_key_responses = {} |
| 361 | + self._public_key_responses[str(headers.message_id)] = request_params |
| 362 | + return |
| 363 | + |
268 | 364 | if not headers.has_error: |
269 | 365 | match headers.encryption_scheme: |
270 | 366 | case 'RSA': |
271 | | - # TODO - decrypt and reassign request_params here |
272 | | - pass |
| 367 | + from pydantic import BaseModel |
| 368 | + |
| 369 | + # Create a simple wrapper model for bytes deserialization |
| 370 | + class _BytesWrapper(BaseModel): |
| 371 | + data: str = "" |
| 372 | + |
| 373 | + decrypted = intersect_payload_decrypt( |
| 374 | + rsa_private_key=self._private_key, |
| 375 | + encrypted_payload=request_params, |
| 376 | + model=_BytesWrapper, |
| 377 | + ) |
| 378 | + # Extract the decrypted data and return as bytes |
| 379 | + request_params = decrypted.model.data.encode() |
273 | 380 | case _: |
274 | 381 | pass |
275 | 382 | if content_type == 'application/json': |
@@ -428,8 +535,20 @@ def _send_userspace_message(self, params: IntersectDirectMessageParams) -> None: |
428 | 535 | # TWO: encrypt message |
429 | 536 | match params.encryption_scheme: |
430 | 537 | case 'RSA': |
431 | | - # TODO reassign serialized_msg here to encrypted value |
432 | | - pass |
| 538 | + # Fetch the destination service's public key |
| 539 | + public_key_pem = self._fetch_service_public_key(params.destination) |
| 540 | + if public_key_pem is None: |
| 541 | + logger.error(f'Failed to fetch public key from {params.destination}') |
| 542 | + return |
| 543 | + |
| 544 | + # Convert bytes to string if needed (since intersect_payload_encrypt expects a string) |
| 545 | + unencrypted_string = serialized_msg if isinstance(serialized_msg, str) else serialized_msg.decode() |
| 546 | + serialized_msg = intersect_payload_encrypt( |
| 547 | + key_payload=IntersectEncryptionPublicKey( |
| 548 | + public_key=public_key_pem |
| 549 | + ), |
| 550 | + unencrypted_model=unencrypted_string, |
| 551 | + ) |
433 | 552 | case _: |
434 | 553 | pass |
435 | 554 |
|
|
0 commit comments