Skip to content

Commit e22be50

Browse files
authored
fix: fixed issue in confidential chat (#30)
1 parent 0ea3bc9 commit e22be50

2 files changed

Lines changed: 34 additions & 7 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ __pycache__/
88
.DS_Store
99
pyrightconfig.json
1010
dist/
11-
.env
11+
.env
12+
main_test.py

src/atoma_sdk/confidential_chat.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -731,25 +731,51 @@ async def create_stream_async(
731731

732732
if utils.match_response(http_res, "200", "text/event-stream"):
733733
##################################################################################################
734-
# Create a decryption wrapper function for each chunk
735-
async def decrypt_chunk(raw_chunk):
734+
# Fix: Create a synchronous wrapper for the async decrypt_chunk function
735+
# This is needed because EventStreamAsync expects a sync function that returns the value directly
736+
def decrypt_chunk_wrapper(raw_chunk):
736737
try:
737-
encrypted_chunk = utils.unmarshal_json(raw_chunk, models.ConfidentialComputeResponse)
738+
# Parse the raw JSON string directly
739+
parsed_data = json.loads(raw_chunk)
740+
741+
# Create a simple class to hold the encrypted data with attributes
742+
class EncryptedData:
743+
def __init__(self, ciphertext, nonce, signature, response_hash):
744+
self.ciphertext = ciphertext
745+
self.nonce = nonce
746+
self.signature = signature
747+
self.response_hash = response_hash
748+
749+
# Instantiate this class with the data from the JSON
750+
encrypted_data = EncryptedData(
751+
ciphertext=parsed_data["data"]["ciphertext"],
752+
nonce=parsed_data["data"]["nonce"],
753+
signature=parsed_data["data"]["signature"],
754+
response_hash=parsed_data["data"]["response_hash"]
755+
)
756+
757+
# Now pass this properly structured object to decrypt_message
738758
decrypted_chunk = crypto_utils.decrypt_message(
739759
client_dh_private_key=client_dh_private_key,
740760
node_dh_public_key=node_dh_public_key,
741761
salt=salt,
742-
encrypted_message=encrypted_chunk.data
762+
encrypted_message=encrypted_data
743763
)
764+
744765
decrypted_json = json.loads(decrypted_chunk.decode('utf-8'))
745-
# Wrap the chunk in a StreamResponse to maintain consistent API
766+
767+
# Skip chunks with empty choices
768+
if not decrypted_json.get('choices'):
769+
return None
770+
771+
# Wrap the chunk in a StreamResponse
746772
return models.ChatCompletionStreamResponse(data=models.ChatCompletionChunk.model_validate(decrypted_json))
747773
except Exception as e:
748774
raise models.APIError(f"Failed to decrypt stream chunk: {str(e)}", 500, str(e), None)
749775

750776
return utils.eventstreaming.EventStreamAsync(
751777
http_res,
752-
decrypt_chunk,
778+
decrypt_chunk_wrapper,
753779
sentinel="[DONE]"
754780
)
755781
##################################################################################################

0 commit comments

Comments
 (0)