Skip to content

Commit 3406ef4

Browse files
authored
fix(mistral): report usage metrics in streaming mode (#1697)
1 parent 32d703c commit 3406ef4

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

src/strands/models/mistral.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -496,8 +496,8 @@ async def stream(
496496

497497
yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
498498

499-
if hasattr(chunk, "usage"):
500-
yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
499+
if hasattr(chunk, "data") and hasattr(chunk.data, "usage") and chunk.data.usage:
500+
yield self.format_chunk({"chunk_type": "metadata", "data": chunk.data.usage})
501501

502502
except Exception as e:
503503
if "rate" in str(e).lower() or "429" in str(e):

tests/strands/models/test_mistral.py

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -451,9 +451,9 @@ async def test_stream(mistral_client, model, agenerator, alist, captured_warning
451451
delta=unittest.mock.Mock(content="test stream", tool_calls=None),
452452
finish_reason="end_turn",
453453
)
454-
]
454+
],
455+
usage=mock_usage,
455456
),
456-
usage=mock_usage,
457457
)
458458

459459
mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event]))
@@ -476,6 +476,30 @@ async def test_stream(mistral_client, model, agenerator, alist, captured_warning
476476
assert len(captured_warnings) == 0
477477

478478

479+
@pytest.mark.asyncio
480+
async def test_stream_no_usage(mistral_client, model, agenerator, alist):
481+
mock_event = unittest.mock.Mock(
482+
data=unittest.mock.Mock(
483+
choices=[
484+
unittest.mock.Mock(
485+
delta=unittest.mock.Mock(content="test stream", tool_calls=None),
486+
finish_reason="end_turn",
487+
)
488+
],
489+
usage=None,
490+
),
491+
)
492+
493+
mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event]))
494+
495+
messages = [{"role": "user", "content": [{"text": "test"}]}]
496+
response = model.stream(messages, None, None)
497+
498+
# Should complete without error and not yield a metadata chunk
499+
chunks = await alist(response)
500+
assert not any("metadata" in c for c in chunks if isinstance(c, dict))
501+
502+
479503
@pytest.mark.asyncio
480504
async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator, alist, captured_warnings):
481505
tool_choice = {"auto": {}}
@@ -492,9 +516,9 @@ async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator
492516
delta=unittest.mock.Mock(content="test stream", tool_calls=None),
493517
finish_reason="end_turn",
494518
)
495-
]
519+
],
520+
usage=mock_usage,
496521
),
497-
usage=mock_usage,
498522
)
499523

500524
mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event]))

tests_integ/models/test_model_mistral.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,11 @@ async def test_agent_stream_async(agent):
106106

107107
assert all(string in text for string in ["12:00", "sunny"])
108108

109+
assert result.metrics.accumulated_usage is not None
110+
assert result.metrics.accumulated_usage["inputTokens"] > 0
111+
assert result.metrics.accumulated_usage["outputTokens"] > 0
112+
assert result.metrics.accumulated_usage["totalTokens"] > 0
113+
109114

110115
def test_agent_structured_output(non_streaming_agent, weather):
111116
tru_weather = non_streaming_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")

0 commit comments

Comments
 (0)