Skip to content

Commit b4c4b43

Browse files
Added encryption to client + fix to service
1 parent 990d197 commit b4c4b43

2 files changed

Lines changed: 127 additions & 6 deletions

File tree

src/intersect_sdk/client.py

Lines changed: 124 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,21 @@
1212

1313
from __future__ import annotations
1414

15+
from cryptography.hazmat.primitives.asymmetric import rsa
16+
from cryptography.hazmat.primitives import serialization
1517
import time
1618
from typing import TYPE_CHECKING
1719
from uuid import uuid4
1820

1921
from pydantic import ValidationError
2022
from typing_extensions import Self, final
2123

22-
from intersect_sdk._internal.generic_serializer import GENERIC_MESSAGE_SERIALIZER
24+
from ._internal.encryption.models.public_key import IntersectEncryptionPublicKey
25+
from ._internal.encryption import (
26+
intersect_payload_decrypt,
27+
intersect_payload_encrypt,
28+
)
29+
from ._internal.generic_serializer import GENERIC_MESSAGE_SERIALIZER
2330

2431
from ._internal.control_plane.control_plane_manager import (
2532
ControlPlaneManager,
@@ -149,6 +156,24 @@ def __init__(
149156

150157
self._campaign_id = uuid4()
151158

159+
# Generate a key pair for encryption
160+
self._private_key: rsa.RSAPrivateKey = rsa.generate_private_key(
161+
public_exponent=65537, key_size=2048
162+
)
163+
self._public_key = self._private_key.public_key()
164+
165+
# Get the PEM encoded public key
166+
self._public_key_pem = self._public_key.public_bytes(
167+
encoding=serialization.Encoding.PEM,
168+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
169+
).decode()
170+
171+
self._private_key.private_bytes(
172+
encoding=serialization.Encoding.PEM,
173+
format=serialization.PrivateFormat.TraditionalOpenSSL,
174+
encryption_algorithm=serialization.NoEncryption(),
175+
)
176+
152177
@final
153178
def startup(self) -> Self:
154179
"""This function connects the client to all INTERSECT systems.
@@ -225,6 +250,69 @@ def considered_unrecoverable(self) -> bool:
225250
"""
226251
return self._control_plane_manager.considered_unrecoverable()
227252

253+
def _fetch_service_public_key(self, destination: str) -> str | None:
254+
"""Fetch the public key from a destination service via intersect_sdk.get_public_key.
255+
256+
Args:
257+
destination: The hierarchy string of the destination service
258+
259+
Returns:
260+
The PEM-encoded public key from the service, or None if fetching failed
261+
"""
262+
# Create a temporary request ID for tracking the public key request
263+
public_key_request_id = uuid4()
264+
265+
# Build and send the public key request
266+
headers = create_userspace_message_headers(
267+
source=self._hierarchy.hierarchy_string('.'),
268+
destination=destination,
269+
operation_id='intersect_sdk.get_public_key',
270+
campaign_id=uuid4(),
271+
request_id=public_key_request_id,
272+
encryption_scheme='NONE',
273+
)
274+
275+
# Send the request to get the public key
276+
request_channel = f'{destination.replace(".", "/")}/request'
277+
self._control_plane_manager.publish_message(
278+
request_channel, b'null', 'application/json', headers, persist=False
279+
)
280+
281+
# Poll for the response with timeout
282+
start_time = time.time()
283+
timeout = 30.0
284+
poll_interval = 0.1
285+
response_data = None
286+
287+
while time.time() - start_time < timeout:
288+
# Check if we have received a response for this request
289+
# We'll store responses temporarily in a dict keyed by request_id
290+
if hasattr(self, '_public_key_responses') and str(public_key_request_id) in self._public_key_responses:
291+
response_data = self._public_key_responses.pop(str(public_key_request_id))
292+
break
293+
time.sleep(poll_interval)
294+
295+
if response_data is None:
296+
logger.error(f'Failed to retrieve public key from {destination} within timeout')
297+
return None
298+
299+
# Extract the public key from the response
300+
try:
301+
if isinstance(response_data, dict):
302+
public_key = response_data.get('public_key')
303+
else:
304+
# If it's already parsed as IntersectEncryptionPublicKey
305+
public_key = response_data.public_key if hasattr(response_data, 'public_key') else None
306+
307+
if public_key is None:
308+
logger.error(f'No public_key field in response from {destination}')
309+
return None
310+
311+
return public_key
312+
except (KeyError, AttributeError) as e:
313+
logger.error(f'Failed to extract public key from response: {e}')
314+
return None
315+
228316
def _handle_userspace_message(
229317
self, payload: bytes, content_type: str, raw_headers: dict[str, str]
230318
) -> None:
@@ -265,11 +353,30 @@ def _handle_userspace_message(
265353
request_params = self._data_plane_manager.incoming_message_data_handler(
266354
payload, headers.data_handler
267355
)
356+
357+
# Check if this is a response to a public key request (before processing as user message)
358+
if headers.operation_id == 'intersect_sdk.get_public_key':
359+
if not hasattr(self, '_public_key_responses'):
360+
self._public_key_responses = {}
361+
self._public_key_responses[str(headers.message_id)] = request_params
362+
return
363+
268364
if not headers.has_error:
269365
match headers.encryption_scheme:
270366
case 'RSA':
271-
# TODO - decrypt and reassign request_params here
272-
pass
367+
from pydantic import BaseModel
368+
369+
# Create a simple wrapper model for bytes deserialization
370+
class _BytesWrapper(BaseModel):
371+
data: str = ""
372+
373+
decrypted = intersect_payload_decrypt(
374+
rsa_private_key=self._private_key,
375+
encrypted_payload=request_params,
376+
model=_BytesWrapper,
377+
)
378+
# Extract the decrypted data and return as bytes
379+
request_params = decrypted.model.data.encode()
273380
case _:
274381
pass
275382
if content_type == 'application/json':
@@ -428,8 +535,20 @@ def _send_userspace_message(self, params: IntersectDirectMessageParams) -> None:
428535
# TWO: encrypt message
429536
match params.encryption_scheme:
430537
case 'RSA':
431-
# TODO reassign serialized_msg here to encrypted value
432-
pass
538+
# Fetch the destination service's public key
539+
public_key_pem = self._fetch_service_public_key(params.destination)
540+
if public_key_pem is None:
541+
logger.error(f'Failed to fetch public key from {params.destination}')
542+
return
543+
544+
# Convert bytes to string if needed (since intersect_payload_encrypt expects a string)
545+
unencrypted_string = serialized_msg if isinstance(serialized_msg, str) else serialized_msg.decode()
546+
serialized_msg = intersect_payload_encrypt(
547+
key_payload=IntersectEncryptionPublicKey(
548+
public_key=public_key_pem
549+
),
550+
unencrypted_model=unencrypted_string,
551+
)
433552
case _:
434553
pass
435554

src/intersect_sdk/service.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,9 @@ def _send_client_message(self, request_id: UUID, params: IntersectDirectMessageP
10391039
# Convert bytes back to string if needed (since intersect_payload_encrypt expects a string)
10401040
unencrypted_string = request if isinstance(request, str) else request.decode()
10411041
request = intersect_payload_encrypt(
1042-
key_payload=key_payload,
1042+
key_payload=IntersectEncryptionPublicKey(
1043+
public_key=key_payload.public_key,
1044+
),
10431045
unencrypted_model=unencrypted_string,
10441046
)
10451047
case _:

0 commit comments

Comments
 (0)