Skip to content

Commit 990d197

Browse files
Adds encryption to service w/ testing + fixes
1 parent 5915bc2 commit 990d197

8 files changed

Lines changed: 502 additions & 61 deletions

File tree

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from .aes_cypher import AESCipher
2-
from .client_encryption import intersect_client_encryption
3-
from .service_decryption import intersect_service_decryption
2+
from .client_encryption import intersect_payload_encrypt
3+
from .service_decryption import intersect_payload_decrypt
44

55

66
__all__ = (
77
'AESCipher',
8-
'intersect_client_encryption',
9-
'intersect_service_decryption',
8+
'intersect_payload_encrypt',
9+
'intersect_payload_decrypt',
1010
)

src/intersect_sdk/_internal/encryption/client_encryption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .models import IntersectEncryptedPayload, IntersectEncryptionPublicKey
1515
from ..logger import logger
1616

17-
def intersect_client_encryption(
17+
def intersect_payload_encrypt(
1818
key_payload: IntersectEncryptionPublicKey,
1919
unencrypted_model: str,
2020
) -> IntersectEncryptedPayload:

src/intersect_sdk/_internal/encryption/service_decryption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..logger import logger
1717

1818

19-
def intersect_service_decryption(
19+
def intersect_payload_decrypt(
2020
rsa_private_key: rsa.RSAPrivateKey,
2121
encrypted_payload: IntersectEncryptedPayload,
2222
model: Type[BaseModel],

src/intersect_sdk/capability/universal_capability/universal_capability.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
"""
77

88
import datetime
9-
from cryptography.hazmat.primitives.asymmetric import rsa
10-
from cryptography.hazmat.primitives import serialization
119
import os
1210
import time
13-
from typing import final
11+
from typing import Dict, final
1412

1513
import psutil
1614

@@ -39,24 +37,6 @@ def __init__(self) -> None: # noqa: D107
3937
self.process = psutil.Process(os.getpid())
4038
"""psutil.Process caches most functions it calls after it calls the function once, so just save the object itself"""
4139

42-
# Generate a key pair for encryption
43-
self._private_key: rsa.RSAPrivateKey = rsa.generate_private_key(
44-
public_exponent=65537, key_size=2048
45-
)
46-
self._public_key = self._private_key.public_key()
47-
48-
# Get the PEM encoded public key
49-
self._public_key_pem = self._public_key.public_bytes(
50-
encoding=serialization.Encoding.PEM,
51-
format=serialization.PublicFormat.SubjectPublicKeyInfo,
52-
).decode()
53-
54-
self._private_key.private_bytes(
55-
encoding=serialization.Encoding.PEM,
56-
format=serialization.PrivateFormat.TraditionalOpenSSL,
57-
encryption_algorithm=serialization.NoEncryption(),
58-
)
59-
6040

6141
@intersect_status
6242
def system_capability(self) -> IntersectCoreStatus:

src/intersect_sdk/service.py

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import time
2020
from collections import defaultdict
21+
from cryptography.hazmat.primitives.asymmetric import rsa
22+
from cryptography.hazmat.primitives import serialization
2123
from threading import Lock
2224
from types import MappingProxyType
2325
from typing import TYPE_CHECKING, Any, Literal
@@ -31,6 +33,8 @@
3133
ControlPlaneManager,
3234
)
3335
from ._internal.data_plane.data_plane_manager import DataPlaneManager
36+
from ._internal.encryption import intersect_payload_decrypt, intersect_payload_encrypt
37+
from ._internal.encryption.models import IntersectEncryptionPublicKey
3438
from ._internal.exceptions import IntersectApplicationError, IntersectError
3539
from ._internal.generic_serializer import GENERIC_MESSAGE_SERIALIZER
3640
from ._internal.interfaces import IntersectEventObserver
@@ -308,6 +312,24 @@ def __init__(
308312
self._client_channel_name, {self._handle_client_message}, persist=True
309313
)
310314

315+
# Generate a key pair for encryption
316+
self._private_key: rsa.RSAPrivateKey = rsa.generate_private_key(
317+
public_exponent=65537, key_size=2048
318+
)
319+
self._public_key = self._private_key.public_key()
320+
321+
# Get the PEM encoded public key
322+
self._public_key_pem = self._public_key.public_bytes(
323+
encoding=serialization.Encoding.PEM,
324+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
325+
).decode()
326+
327+
self._private_key.private_bytes(
328+
encoding=serialization.Encoding.PEM,
329+
format=serialization.PrivateFormat.TraditionalOpenSSL,
330+
encryption_algorithm=serialization.NoEncryption(),
331+
)
332+
311333
def _get_capability(self, target: str) -> IntersectBaseCapabilityImplementation | None:
312334
for cap in self.capabilities:
313335
if cap.intersect_sdk_capability_name == target:
@@ -825,7 +847,11 @@ def _handle_service_message_inner(
825847
)
826848
match headers.encryption_scheme:
827849
case 'RSA':
828-
# TODO - decrypt request_params here
850+
request_params = intersect_payload_decrypt(
851+
rsa_private_key=self._private_key,
852+
encrypted_payload=request_params,
853+
model=bytes,
854+
)
829855
pass
830856
case _:
831857
pass
@@ -922,7 +948,11 @@ def _handle_client_message(
922948
# error messages should never be encrypted
923949
match headers.encryption_scheme:
924950
case 'RSA':
925-
# TODO - decrypt message here
951+
msg_payload = intersect_payload_decrypt(
952+
rsa_private_key=self._private_key,
953+
encrypted_payload=msg_payload,
954+
model=bytes,
955+
)
926956
pass
927957
case _:
928958
pass
@@ -954,7 +984,7 @@ def _send_client_message(self, request_id: UUID, params: IntersectDirectMessageP
954984
if params.content_type == 'application/json':
955985
request = GENERIC_MESSAGE_SERIALIZER.dump_json(params.payload, warnings=False)
956986
else:
957-
if not isinstance(params.content_type, bytes):
987+
if not isinstance(params.payload, bytes):
958988
logger.error(
959989
'service-to-service message must be bytes if content-type is not application/json'
960990
)
@@ -963,8 +993,55 @@ def _send_client_message(self, request_id: UUID, params: IntersectDirectMessageP
963993

964994
match params.encryption_scheme:
965995
case 'RSA':
966-
# TODO encrypt message
967-
pass
996+
# Get the public key from the destination service
997+
# The destination service has the public key through its universal capability
998+
public_key_request = IntersectDirectMessageParams(
999+
destination=params.destination,
1000+
operation='intersect_sdk.get_public_key',
1001+
payload=None,
1002+
)
1003+
1004+
# Create the external request to fetch the public key
1005+
public_key_request_id = self.create_external_request(
1006+
public_key_request,
1007+
response_handler=None,
1008+
timeout=30.0, # shorter timeout for key fetching
1009+
)
1010+
1011+
# Poll for the response with timeout
1012+
start_time = time.time()
1013+
timeout = 30.0
1014+
poll_interval = 0.1
1015+
1016+
public_key_payload = None
1017+
while time.time() - start_time < timeout:
1018+
extreq = self._get_external_request(public_key_request_id)
1019+
if extreq and extreq.request_state == 'received':
1020+
public_key_payload = extreq.response_payload
1021+
extreq.request_state = 'finalized'
1022+
break
1023+
time.sleep(poll_interval)
1024+
1025+
if public_key_payload is None:
1026+
logger.error(
1027+
f'Failed to retrieve public key from {params.destination} within timeout'
1028+
)
1029+
return False
1030+
1031+
# public_key_payload should be an IntersectEncryptionPublicKey instance (as a dict)
1032+
# Parse it if it's a dict, otherwise use it as-is
1033+
if isinstance(public_key_payload, dict):
1034+
key_payload = IntersectEncryptionPublicKey(**public_key_payload)
1035+
else:
1036+
key_payload = public_key_payload
1037+
1038+
# Now encrypt the request using the retrieved public key
1039+
# Convert bytes back to string if needed (since intersect_payload_encrypt expects a string)
1040+
unencrypted_string = request if isinstance(request, str) else request.decode()
1041+
request = intersect_payload_encrypt(
1042+
key_payload=key_payload,
1043+
unencrypted_model=unencrypted_string,
1044+
)
9681045
case _:
9691046
pass
9701047

tests/fixtures/example_schema.json

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,45 @@
99
"defaultContentType": "application/json",
1010
"capabilities": {
1111
"intersect_sdk": {
12-
"endpoints": {},
12+
"endpoints": {
13+
"get_public_key": {
14+
"publish": {
15+
"message": {
16+
"schemaFormat": "application/vnd.aai.asyncapi+json;version=2.6.0",
17+
"contentType": "application/json",
18+
"encryption_schemes": [
19+
"NONE",
20+
"RSA"
21+
],
22+
"traits": {
23+
"$ref": "#/components/messageTraits/commonHeaders"
24+
},
25+
"payload": {
26+
"additionalProperties": {
27+
"type": "string"
28+
},
29+
"type": "object",
30+
"title": "get_public_key"
31+
}
32+
},
33+
"description": "Returns the public key for clients / services to use for encryption"
34+
},
35+
"subscribe": {
36+
"message": {
37+
"schemaFormat": "application/vnd.aai.asyncapi+json;version=2.6.0",
38+
"contentType": "application/json",
39+
"encryption_schemes": [
40+
"NONE",
41+
"RSA"
42+
],
43+
"traits": {
44+
"$ref": "#/components/messageTraits/commonHeaders"
45+
}
46+
},
47+
"description": "Returns the public key for clients / services to use for encryption"
48+
}
49+
}
50+
},
1351
"events": {},
1452
"status": {
1553
"$ref": "#/components/schemas/IntersectCoreStatus"
@@ -1473,4 +1511,4 @@
14731511
}
14741512
}
14751513
}
1476-
}
1514+
}

0 commit comments

Comments
 (0)