Skip to content

Commit efd16bb

Browse files
authored
fix: consolidate API response on streams (#26)
1 parent cd38a96 commit efd16bb

2 files changed

Lines changed: 22 additions & 19 deletions

File tree

src/atoma_sdk/confidential_chat.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def create(
173173
encrypted_message=encrypted_response
174174
)
175175
return utils.unmarshal_json(
176-
decrypted_response.decode('utf-8'),
176+
decrypted_response.decode('utf-8'),
177177
models.ChatCompletionResponse
178178
)
179179
except Exception as e:
@@ -362,7 +362,7 @@ async def create_async(
362362
encrypted_message=encrypted_response
363363
)
364364
return utils.unmarshal_json(
365-
decrypted_response.decode('utf-8'),
365+
decrypted_response.decode('utf-8'),
366366
models.ChatCompletionResponse
367367
)
368368
except Exception as e:
@@ -554,7 +554,8 @@ def decrypt_chunk(raw_chunk):
554554
# Skip chunks with empty choices
555555
if not decrypted_json.get('choices'):
556556
return None
557-
return models.ChatCompletionChunk.model_validate(decrypted_json)
557+
# Wrap the chunk in a StreamResponse to maintain consistent API
558+
return models.ChatCompletionStreamResponse(data=models.ChatCompletionChunk.model_validate(decrypted_json))
558559
except Exception as e:
559560
raise models.APIError(f"Failed to decrypt stream chunk: {str(e)}", 500, str(e), None)
560561

@@ -740,7 +741,9 @@ async def decrypt_chunk(raw_chunk):
740741
salt=salt,
741742
encrypted_message=encrypted_chunk.data
742743
)
743-
return utils.unmarshal_json(decrypted_chunk.decode('utf-8'), models.ChatCompletionStreamResponse)
744+
decrypted_json = json.loads(decrypted_chunk.decode('utf-8'))
745+
# Wrap the chunk in a StreamResponse to maintain consistent API
746+
return models.ChatCompletionStreamResponse(data=models.ChatCompletionChunk.model_validate(decrypted_json))
744747
except Exception as e:
745748
raise models.APIError(f"Failed to decrypt stream chunk: {str(e)}", 500, str(e), None)
746749

test/test_confidential_chat.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ def test_chat_completion_basic(client):
2626
]
2727
)
2828

29-
print(completion.choices[0].message)
30-
29+
print(completion.data.choices[0].message)
30+
3131
assert completion is not None
32-
assert len(completion.choices) > 0
33-
assert completion.choices[0].message.content is not None
32+
assert len(completion.data.choices) > 0
33+
assert completion.data.choices[0].message.content is not None
3434

3535
def test_chat_completion_with_system_message(client):
3636
completion = client.confidential_chat.create(
@@ -40,10 +40,10 @@ def test_chat_completion_with_system_message(client):
4040
{"role": "user", "content": "Hello!"}
4141
]
4242
)
43-
43+
4444
assert completion is not None
45-
assert len(completion.choices) > 0
46-
assert completion.choices[0].message.content is not None
45+
assert len(completion.data.choices) > 0
46+
assert completion.data.choices[0].message.content is not None
4747

4848
@pytest.mark.asyncio
4949
async def test_chat_completion_async(client):
@@ -54,10 +54,10 @@ async def test_chat_completion_async(client):
5454
{"role": "user", "content": "Hello!"}
5555
]
5656
)
57-
57+
5858
assert completion is not None
59-
assert len(completion.choices) > 0
60-
assert completion.choices[0].message.content is not None
59+
assert len(completion.data.choices) > 0
60+
assert completion.data.choices[0].message.content is not None
6161

6262
def test_chat_completion_stream(client):
6363
completion = client.confidential_chat.create_stream(
@@ -70,15 +70,15 @@ def test_chat_completion_stream(client):
7070

7171
# Verify we get a valid stream
7272
assert completion is not None
73-
73+
7474
# Process the stream and verify chunks
7575
chunk_count = 0
7676
for chunk in completion:
7777
assert chunk is not None
78-
assert chunk is not None
79-
assert len(chunk.choices) > 0
80-
assert chunk.choices[0].delta is not None
78+
assert chunk.data is not None
79+
assert len(chunk.data.choices) > 0
80+
assert chunk.data.choices[0].delta is not None
8181
chunk_count += 1
82-
82+
8383
# Verify we got multiple chunks
8484
assert chunk_count > 0

0 commit comments

Comments
 (0)