
Commit e68d9d8

mistral - latency
1 parent 8dbea7c commit e68d9d8

2 files changed: 51 additions & 39 deletions

src/strands/models/mistral.py (21 additions & 4 deletions)
@@ -6,6 +6,7 @@
 import base64
 import json
 import logging
+import time
 from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union
 
 import mistralai
@@ -334,7 +335,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
                 return {"messageStop": {"stopReason": reason}}
 
             case "metadata":
-                usage = event["data"]
+                usage = event["data"]["usage"]
+                metrics = event["data"]["metrics"]
                 return {
                     "metadata": {
                         "usage": {
@@ -343,7 +345,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
                             "totalTokens": usage.total_tokens,
                         },
                         "metrics": {
-                            "latencyMs": event.get("latency_ms", 0),
+                            "latencyMs": metrics["latency"] * 1000,
                         },
                     },
                 }
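
With this change the metadata event nests usage and metrics under "data", and the seconds-to-milliseconds conversion happens in one place. A minimal sketch of just this case, using types.SimpleNamespace as a stand-in for the SDK's usage object (format_metadata is an illustrative helper, not the real method):

from types import SimpleNamespace
from typing import Any

def format_metadata(event: dict[str, Any]) -> dict[str, Any]:
    # Mirrors the "metadata" case above: usage and metrics come out of "data".
    usage = event["data"]["usage"]
    metrics = event["data"]["metrics"]
    return {
        "metadata": {
            "usage": {
                "inputTokens": usage.prompt_tokens,
                "outputTokens": usage.completion_tokens,
                "totalTokens": usage.total_tokens,
            },
            # Internal latency is tracked in seconds; the event reports ms.
            "metrics": {"latencyMs": metrics["latency"] * 1000},
        },
    }

usage = SimpleNamespace(prompt_tokens=100, completion_tokens=50, total_tokens=150)
chunk = format_metadata({"data": {"usage": usage, "metrics": {"latency": 0.25}}})
assert chunk["metadata"]["metrics"]["latencyMs"] == 250
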
@@ -360,6 +362,8 @@ def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, Any]]:
         Yields:
             Formatted events that match the streaming format.
         """
+        start_time = time.time()
+
         yield {"chunk_type": "message_start"}
 
         content_started = False
@@ -389,7 +393,12 @@ def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, Any]]:
         yield {"chunk_type": "message_stop", "data": finish_reason}
 
         if hasattr(response, "usage") and response.usage:
-            yield {"chunk_type": "metadata", "data": response.usage}
+            end_time = time.time()
+            latency = end_time - start_time
+            yield {
+                "chunk_type": "metadata",
+                "data": {"usage": response.usage, "metrics": {"latency": latency}},
+            }
 
     @override
     async def stream(
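
The non-streaming handler brackets its work with time.time() and reports the difference, in seconds, on the final metadata event. Note that because _handle_non_streaming_response is a generator, the clock starts when iteration begins, not when the method is called. A self-contained sketch of the pattern (handle_response and fake_response are hypothetical stand-ins):

import time
from types import SimpleNamespace
from typing import Any, Iterable

def handle_response(response: Any) -> Iterable[dict[str, Any]]:
    # The clock starts on the first next(); generators run lazily.
    start_time = time.time()
    yield {"chunk_type": "message_start"}
    # ... content and stop events would be yielded here ...
    latency = time.time() - start_time  # seconds
    yield {
        "chunk_type": "metadata",
        "data": {"usage": response.usage, "metrics": {"latency": latency}},
    }

fake_response = SimpleNamespace(usage=SimpleNamespace(total_tokens=150))
events = list(handle_response(fake_response))
assert events[-1]["data"]["metrics"]["latency"] >= 0
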
@@ -434,6 +443,7 @@ async def stream(
 
             # Use the streaming API
             async with mistralai.Mistral(**self.client_args) as client:
+                start_time = time.time()
                 stream_response = await client.chat.stream_async(**request)
 
                 yield self.format_chunk({"chunk_type": "message_start"})
@@ -488,7 +498,14 @@ async def stream(
                         yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
 
                     if hasattr(chunk, "usage"):
-                        yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
+                        end_time = time.time()
+                        latency = end_time - start_time
+                        yield self.format_chunk(
+                            {
+                                "chunk_type": "metadata",
+                                "data": {"usage": chunk.usage, "metrics": {"latency": latency}},
+                            }
+                        )
 
         except Exception as e:
             if "rate" in str(e).lower() or "429" in str(e):

tests/strands/models/test_mistral.py (30 additions & 35 deletions)
@@ -68,6 +68,12 @@ class TestOutputModel(pydantic.BaseModel):
     return TestOutputModel
 
 
+@pytest.fixture
+def mock_time():
+    with unittest.mock.patch.object(strands.models.mistral, "time") as mock:
+        yield mock.time
+
+
 def test__init__model_configs(mistral_client, model_id, max_tokens):
     _ = mistral_client
 
@@ -380,38 +386,12 @@ def test_format_chunk_metadata(model):
 
     event = {
         "chunk_type": "metadata",
-        "data": mock_usage,
-        "latency_ms": 250,
-    }
-
-    actual_chunk = model.format_chunk(event)
-    exp_chunk = {
-        "metadata": {
-            "usage": {
-                "inputTokens": 100,
-                "outputTokens": 50,
-                "totalTokens": 150,
-            },
-            "metrics": {
-                "latencyMs": 250,
-            },
+        "data": {
+            "usage": mock_usage,
+            "metrics": {"latency": 0.001},
         },
     }
 
-    assert actual_chunk == exp_chunk
-
-
-def test_format_chunk_metadata_no_latency(model):
-    mock_usage = unittest.mock.Mock()
-    mock_usage.prompt_tokens = 100
-    mock_usage.completion_tokens = 50
-    mock_usage.total_tokens = 150
-
-    event = {
-        "chunk_type": "metadata",
-        "data": mock_usage,
-    }
-
     actual_chunk = model.format_chunk(event)
     exp_chunk = {
         "metadata": {
@@ -421,7 +401,7 @@ def test_format_chunk_metadata_no_latency(model):
                 "totalTokens": 150,
             },
             "metrics": {
-                "latencyMs": 0,
+                "latencyMs": 1,
            },
         },
     }
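
The separate no-latency test is dropped because metrics are now always present in the event. One subtlety in the consolidated test: the event's latency of 0.001 s goes through the × 1000 conversion as float arithmetic, yet the expected chunk uses the integer 1. That works because the product rounds to exactly 1.0, and 1.0 == 1 in Python:

assert 0.001 * 1000 == 1  # the float product is exactly 1.0, which equals int 1
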
@@ -437,7 +417,9 @@ def test_format_chunk_unknown(model):
 
 
 @pytest.mark.asyncio
-async def test_stream(mistral_client, model, agenerator, alist):
+async def test_stream(mistral_client, model, mock_time, agenerator, alist):
+    mock_time.side_effect = [0, 0.001]
+
     mock_usage = unittest.mock.Mock()
     mock_usage.prompt_tokens = 100
     mock_usage.completion_tokens = 50
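
The mock_time fixture above plus this side_effect script the clock: patch.object swaps the time module as seen from inside strands.models.mistral, and successive time.time() calls return 0, then 0.001. A self-contained sketch of the same trick, patching the current module as a stand-in for the module under test:

import sys
import time
import unittest.mock

def measure() -> float:
    # Toy stand-in for the code under test: brackets work with time.time().
    start = time.time()
    return time.time() - start

# Patch the `time` name on the module containing measure(), exactly as the
# fixture patches strands.models.mistral.
with unittest.mock.patch.object(sys.modules[__name__], "time") as mock:
    mock.time.side_effect = [0, 0.001]  # scripted return values, call by call
    assert measure() == 0.001
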
@@ -458,10 +440,8 @@ async def test_stream(mistral_client, model, agenerator, alist):
     mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event]))
 
     messages = [{"role": "user", "content": [{"text": "test"}]}]
-    response = model.stream(messages, None, None)
-
-    # Consume the response
-    await alist(response)
+    stream = model.stream(messages, None, None)
+    responses = await alist(stream)
 
     expected_request = {
         "model": "mistral-large-latest",
@@ -472,6 +452,21 @@
 
     mistral_client.chat.stream_async.assert_called_once_with(**expected_request)
 
+    tru_metadata = responses[-1]
+    exp_metadata = {
+        "metadata": {
+            "usage": {
+                "inputTokens": 100,
+                "outputTokens": 50,
+                "totalTokens": 150,
+            },
+            "metrics": {
+                "latencyMs": 1,
+            },
+        },
+    }
+    assert tru_metadata == exp_metadata
+
 
 @pytest.mark.asyncio
 async def test_stream_rate_limit_error(mistral_client, model, alist):
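
The final assertion ties the scripted clock to the formatted output: time.time() returns 0 when the stream starts and 0.001 when the usage chunk arrives, so:

start, end = 0, 0.001       # values supplied by mock_time.side_effect
latency = end - start       # seconds, as computed inside stream()
assert latency * 1000 == 1  # emitted as "latencyMs": 1 by format_chunk
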
