1818
1919import time
2020from collections import defaultdict
21+ from cryptography .hazmat .primitives .asymmetric import rsa
22+ from cryptography .hazmat .primitives import serialization
2123from threading import Lock
2224from types import MappingProxyType
2325from typing import TYPE_CHECKING , Any , Literal
3133 ControlPlaneManager ,
3234)
3335from ._internal .data_plane .data_plane_manager import DataPlaneManager
36+ from ._internal .encryption import intersect_payload_decrypt , intersect_payload_encrypt
37+ from ._internal .encryption .models import IntersectEncryptionPublicKey
3438from ._internal .exceptions import IntersectApplicationError , IntersectError
3539from ._internal .generic_serializer import GENERIC_MESSAGE_SERIALIZER
3640from ._internal .interfaces import IntersectEventObserver
@@ -308,6 +312,24 @@ def __init__(
308312 self ._client_channel_name , {self ._handle_client_message }, persist = True
309313 )
310314
315+ # Generate a key pair for encryption
316+ self ._private_key : rsa .RSAPrivateKey = rsa .generate_private_key (
317+ public_exponent = 65537 , key_size = 2048
318+ )
319+ self ._public_key = self ._private_key .public_key ()
320+
321+ # Get the PEM encoded public key
322+ self ._public_key_pem = self ._public_key .public_bytes (
323+ encoding = serialization .Encoding .PEM ,
324+ format = serialization .PublicFormat .SubjectPublicKeyInfo ,
325+ ).decode ()
326+
327+ self ._private_key .private_bytes (
328+ encoding = serialization .Encoding .PEM ,
329+ format = serialization .PrivateFormat .TraditionalOpenSSL ,
330+ encryption_algorithm = serialization .NoEncryption (),
331+ )
332+
311333 def _get_capability (self , target : str ) -> IntersectBaseCapabilityImplementation | None :
312334 for cap in self .capabilities :
313335 if cap .intersect_sdk_capability_name == target :
@@ -825,7 +847,11 @@ def _handle_service_message_inner(
825847 )
826848 match headers .encryption_scheme :
827849 case 'RSA' :
828- # TODO - decrypt request_params here
850+ request_params = intersect_payload_decrypt (
851+ rsa_private_key = self ._private_key ,
852+ encrypted_payload = request_params ,
853+ model = bytes ,
854+ )
829855 pass
830856 case _:
831857 pass
@@ -922,7 +948,11 @@ def _handle_client_message(
922948 # error messages should never be encrypted
923949 match headers .encryption_scheme :
924950 case 'RSA' :
925- # TODO - decrypt message here
951+ msg_payload = intersect_payload_decrypt (
952+ rsa_private_key = self ._private_key ,
953+ encrypted_payload = msg_payload ,
954+ model = bytes ,
955+ )
926956 pass
927957 case _:
928958 pass
@@ -954,7 +984,7 @@ def _send_client_message(self, request_id: UUID, params: IntersectDirectMessageP
954984 if params .content_type == 'application/json' :
955985 request = GENERIC_MESSAGE_SERIALIZER .dump_json (params .payload , warnings = False )
956986 else :
957- if not isinstance (params .content_type , bytes ):
987+ if not isinstance (params .payload , bytes ):
958988 logger .error (
959989 'service-to-service message must be bytes if content-type is not application/json'
960990 )
@@ -963,8 +993,55 @@ def _send_client_message(self, request_id: UUID, params: IntersectDirectMessageP
963993
964994 match params .encryption_scheme :
965995 case 'RSA' :
966- # TODO encrypt message
967- pass
996+ # Get the public key from the destination service
997+ # The destination service has the public key through its universal capability
998+ public_key_request = IntersectDirectMessageParams (
999+ destination = params .destination ,
1000+ operation = 'intersect_sdk.get_public_key' ,
1001+ payload = None ,
1002+ )
1003+
1004+ # Create the external request to fetch the public key
1005+ public_key_request_id = self .create_external_request (
1006+ public_key_request ,
1007+ response_handler = None ,
1008+ timeout = 30.0 , # shorter timeout for key fetching
1009+ )
1010+
1011+ # Poll for the response with timeout
1012+ start_time = time .time ()
1013+ timeout = 30.0
1014+ poll_interval = 0.1
1015+
1016+ public_key_payload = None
1017+ while time .time () - start_time < timeout :
1018+ extreq = self ._get_external_request (public_key_request_id )
1019+ if extreq and extreq .request_state == 'received' :
1020+ public_key_payload = extreq .response_payload
1021+ extreq .request_state = 'finalized'
1022+ break
1023+ time .sleep (poll_interval )
1024+
1025+ if public_key_payload is None :
1026+ logger .error (
1027+ f'Failed to retrieve public key from { params .destination } within timeout'
1028+ )
1029+ return False
1030+
1031+ # public_key_payload should be an IntersectEncryptionPublicKey instance (as a dict)
1032+ # Parse it if it's a dict, otherwise use it as-is
1033+ if isinstance (public_key_payload , dict ):
1034+ key_payload = IntersectEncryptionPublicKey (** public_key_payload )
1035+ else :
1036+ key_payload = public_key_payload
1037+
1038+ # Now encrypt the request using the retrieved public key
1039+ # Convert bytes back to string if needed (since intersect_payload_encrypt expects a string)
1040+ unencrypted_string = request if isinstance (request , str ) else request .decode ()
1041+ request = intersect_payload_encrypt (
1042+ key_payload = key_payload ,
1043+ unencrypted_model = unencrypted_string ,
1044+ )
9681045 case _:
9691046 pass
9701047
0 commit comments