diff --git a/src/agents/run.py b/src/agents/run.py
index 58eef335e0..eb7fab5c0f 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -1234,72 +1234,82 @@ async def _run_single_turn_streamed(
         )
 
         # 1. Stream the output events
-        async for event in model.stream_response(
-            filtered.instructions,
-            filtered.input,
-            model_settings,
-            all_tools,
-            output_schema,
-            handoffs,
-            get_model_tracing_impl(
-                run_config.tracing_disabled, run_config.trace_include_sensitive_data
-            ),
-            previous_response_id=previous_response_id,
-            conversation_id=conversation_id,
-            prompt=prompt_config,
-        ):
-            # Emit the raw event ASAP
-            streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
-
-            if isinstance(event, ResponseCompletedEvent):
-                usage = (
-                    Usage(
-                        requests=1,
-                        input_tokens=event.response.usage.input_tokens,
-                        output_tokens=event.response.usage.output_tokens,
-                        total_tokens=event.response.usage.total_tokens,
-                        input_tokens_details=event.response.usage.input_tokens_details,
-                        output_tokens_details=event.response.usage.output_tokens_details,
+        try:
+            async for event in model.stream_response(
+                filtered.instructions,
+                filtered.input,
+                model_settings,
+                all_tools,
+                output_schema,
+                handoffs,
+                get_model_tracing_impl(
+                    run_config.tracing_disabled, run_config.trace_include_sensitive_data
+                ),
+                previous_response_id=previous_response_id,
+                conversation_id=conversation_id,
+                prompt=prompt_config,
+            ):
+                # Emit the raw event ASAP
+                streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
+
+                if isinstance(event, ResponseCompletedEvent):
+                    usage = (
+                        Usage(
+                            requests=1,
+                            input_tokens=event.response.usage.input_tokens,
+                            output_tokens=event.response.usage.output_tokens,
+                            total_tokens=event.response.usage.total_tokens,
+                            input_tokens_details=event.response.usage.input_tokens_details,
+                            output_tokens_details=event.response.usage.output_tokens_details,
+                        )
+                        if event.response.usage
+                        else Usage()
                     )
-                    if event.response.usage
-                    else Usage()
-                )
-                final_response = ModelResponse(
-                    output=event.response.output,
-                    usage=usage,
-                    response_id=event.response.id,
-                )
-                context_wrapper.usage.add(usage)
-
-            if isinstance(event, ResponseOutputItemDoneEvent):
-                output_item = event.item
-
-                if isinstance(output_item, _TOOL_CALL_TYPES):
-                    call_id: str | None = getattr(
-                        output_item, "call_id", getattr(output_item, "id", None)
+                    final_response = ModelResponse(
+                        output=event.response.output,
+                        usage=usage,
+                        response_id=event.response.id,
                     )
+                    context_wrapper.usage.add(usage)
 
-                    if call_id and call_id not in emitted_tool_call_ids:
-                        emitted_tool_call_ids.add(call_id)
+                if isinstance(event, ResponseOutputItemDoneEvent):
+                    output_item = event.item
 
-                        tool_item = ToolCallItem(
-                            raw_item=cast(ToolCallItemTypes, output_item),
-                            agent=agent,
-                        )
-                        streamed_result._event_queue.put_nowait(
-                            RunItemStreamEvent(item=tool_item, name="tool_called")
+                    if isinstance(output_item, _TOOL_CALL_TYPES):
+                        call_id: str | None = getattr(
+                            output_item, "call_id", getattr(output_item, "id", None)
                         )
 
-                elif isinstance(output_item, ResponseReasoningItem):
-                    reasoning_id: str | None = getattr(output_item, "id", None)
+                        if call_id and call_id not in emitted_tool_call_ids:
+                            emitted_tool_call_ids.add(call_id)
 
-                    if reasoning_id and reasoning_id not in emitted_reasoning_item_ids:
-                        emitted_reasoning_item_ids.add(reasoning_id)
+                            tool_item = ToolCallItem(
+                                raw_item=cast(ToolCallItemTypes, output_item),
+                                agent=agent,
+                            )
+                            streamed_result._event_queue.put_nowait(
+                                RunItemStreamEvent(item=tool_item, name="tool_called")
+                            )
 
-                        reasoning_item = ReasoningItem(raw_item=output_item, agent=agent)
-                        streamed_result._event_queue.put_nowait(
-                            RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created")
-                        )
+                    elif isinstance(output_item, ResponseReasoningItem):
+                        reasoning_id: str | None = getattr(output_item, "id", None)
+
+                        if reasoning_id and reasoning_id not in emitted_reasoning_item_ids:
+                            emitted_reasoning_item_ids.add(reasoning_id)
+
+                            reasoning_item = ReasoningItem(raw_item=output_item, agent=agent)
+                            streamed_result._event_queue.put_nowait(
+                                RunItemStreamEvent(
+                                    item=reasoning_item, name="reasoning_item_created"
+                                )
+                            )
+        except Exception:
+            # Track that a request was made, even if we didn't get usage info from the response
+            # This ensures usage tracking is not completely lost when streaming fails
+            # (Issue #1973)
+            if final_response is None:
+                context_wrapper.usage.add(Usage(requests=1))
+            raise
 
         # Call hook just after the model response is finalized.
         if final_response is not None:
diff --git a/tests/test_usage_tracking_on_error.py b/tests/test_usage_tracking_on_error.py
new file mode 100644
index 0000000000..158b675adf
--- /dev/null
+++ b/tests/test_usage_tracking_on_error.py
@@ -0,0 +1,143 @@
+"""Test that usage tracking works correctly when streaming fails.
+
+This addresses Issue #1973: Usage tracking lost when streaming fails mid-request.
+"""
+
+import pytest
+
+from agents import Agent, Runner
+
+from .fake_model import FakeModel
+
+
+@pytest.mark.asyncio
+async def test_usage_tracking_requests_on_streaming_error():
+    """Test that at least the request count is tracked when streaming fails.
+
+    This addresses Issue #1973: When the model raises an error during streaming,
+    we should track that a request was made, even if token counts are unavailable.
+    """
+    model = FakeModel()
+
+    # Simulate a streaming failure (e.g., context window exceeded, connection drop)
+    model.set_next_output(RuntimeError("Context window exceeded"))
+
+    agent = Agent(
+        name="test_agent",
+        model=model,
+    )
+
+    # Run the agent and expect it to fail
+    with pytest.raises(RuntimeError):
+        result = Runner.run_streamed(agent, input="Test input that consumes tokens")
+        async for _ in result.stream_events():
+            pass
+
+    # FIXED: Request count should be tracked even when streaming fails
+    assert result.context_wrapper.usage.requests == 1, "Request count should be tracked on error"
+
+    # Token counts are unavailable when streaming fails before ResponseCompletedEvent
+    assert result.context_wrapper.usage.input_tokens == 0
+    assert result.context_wrapper.usage.output_tokens == 0
+    assert result.context_wrapper.usage.total_tokens == 0
+
+
+@pytest.mark.asyncio
+async def test_usage_tracking_preserved_on_success():
+    """Test that normal usage tracking still works correctly after the fix.
+
+    This ensures our fix doesn't break the normal case where streaming succeeds.
+ """ + from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + + from agents.usage import Usage + + from .test_responses import get_text_message + + model = FakeModel() + + # Set custom usage to verify it's tracked correctly + model.set_hardcoded_usage( + Usage( + requests=1, + input_tokens=100, + output_tokens=50, + total_tokens=150, + input_tokens_details=InputTokensDetails(cached_tokens=10), + output_tokens_details=OutputTokensDetails(reasoning_tokens=5), + ) + ) + + # Simulate successful streaming + model.set_next_output([get_text_message("Success")]) + + agent = Agent( + name="test_agent", + model=model, + ) + + result = Runner.run_streamed(agent, input="Test input") + async for _ in result.stream_events(): + pass + + # Usage should be tracked correctly in the success case + assert result.context_wrapper.usage.requests == 1 + assert result.context_wrapper.usage.input_tokens == 100 + assert result.context_wrapper.usage.output_tokens == 50 + assert result.context_wrapper.usage.total_tokens == 150 + # Note: FakeModel doesn't fully support token_details, so we only test the main counts + + +@pytest.mark.asyncio +async def test_usage_tracking_multi_turn_with_error(): + """Test usage tracking across multiple turns when an error occurs. + + This ensures that usage from successful turns is preserved even when a later turn fails. + """ + import json + + from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + + from agents.usage import Usage + + from .test_responses import get_function_tool, get_function_tool_call + + model = FakeModel() + + # First turn: successful with usage + model.set_hardcoded_usage( + Usage( + requests=1, + input_tokens=100, + output_tokens=50, + total_tokens=150, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ) + ) + + agent = Agent( + name="test_agent", + model=model, + tools=[get_function_tool("test_tool", "tool_result")], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: successful tool call + [get_function_tool_call("test_tool", json.dumps({"arg": "value"}))], + # Second turn: error + RuntimeError("API error on second turn"), + ] + ) + + with pytest.raises(RuntimeError): + result = Runner.run_streamed(agent, input="Test input") + async for _ in result.stream_events(): + pass + + # Usage should include first turn's usage + second turn's request count + assert result.context_wrapper.usage.requests == 2, "Should track both turns" + assert result.context_wrapper.usage.input_tokens == 100, "Should preserve first turn's tokens" + assert result.context_wrapper.usage.output_tokens == 50, "Should preserve first turn's tokens" + assert result.context_wrapper.usage.total_tokens == 150, "Should preserve first turn's tokens"