diff --git a/src/minisweagent/models/portkey_model.py b/src/minisweagent/models/portkey_model.py index 6d23aef00..9950a2eb8 100644 --- a/src/minisweagent/models/portkey_model.py +++ b/src/minisweagent/models/portkey_model.py @@ -167,6 +167,18 @@ def _calculate_cost(self, response) -> dict[str, float]: f"Completion tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}" ) completion_tokens = 0 + if total_tokens is None: + logger.warning( + f"Total tokens are None for model {self.config.model_name}. Setting to sum of prompt and completion tokens. Full response: {response_for_cost_calc.model_dump()}" + ) + + # Defense: prompt_tokens and completion_tokens are guaranteed to be non-None here + # because they were set to 0 above if they were None. But we use get_token() to + # be extra safe in case of unexpected API responses. + def get_token(val): + return val if val is not None else 0 + + total_tokens = get_token(prompt_tokens) + get_token(completion_tokens) if total_tokens - prompt_tokens - completion_tokens != 0: # This is most likely related to how portkey treats cached tokens: It doesn't count them towards the prompt tokens (?) 
def test_portkey_model_total_tokens_none():
    """Query succeeds when the response usage reports ``total_tokens`` as None.

    The mocked Portkey response carries prompt and completion token counts but
    no total. The test checks that ``PortkeyModel.query`` still returns the
    parsed tool-call action and the (mocked) cost.
    """
    # A single bash tool call carried by the assistant message.
    tool_call = MagicMock()
    tool_call.id = "call_999"
    tool_call.function.name = "bash"
    tool_call.function.arguments = json.dumps({"command": "echo test"})

    message = MagicMock()
    message.tool_calls = [tool_call]
    message.content = None
    message.model_dump.return_value = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{"id": "call_999", "function": {"name": "bash", "arguments": '{"command": "echo test"}'}}],
    }

    choice = MagicMock()
    choice.message = message

    # Simulate the bug: prompt/completion counts are present, total_tokens is None.
    usage = MagicMock()
    usage.prompt_tokens = 10
    usage.completion_tokens = 20
    usage.total_tokens = None

    response = MagicMock()
    response.choices = [choice]
    response.model_dump.return_value = {"test": "response"}
    # model_copy hands back the same mock so the configured usage stays attached.
    response.model_copy.return_value = response
    response.usage = usage

    client = MagicMock()
    client.chat.completions.create.return_value = response
    portkey_cls = MagicMock(return_value=client)

    with patch("minisweagent.models.portkey_model.Portkey", portkey_cls), patch.dict(
        os.environ, {"PORTKEY_API_KEY": "test-key"}
    ), patch("minisweagent.models.portkey_model.litellm.cost_calculator.completion_cost") as mock_cost:
        mock_cost.return_value = 0.01

        model = PortkeyModel(model_name="gpt-4o")

        # This should not raise a TypeError
        result = model.query([{"role": "user", "content": "test"}])

        # Verify the result is still correct
        assert result["extra"]["actions"] == [{"command": "echo test", "tool_call_id": "call_999"}]
        assert result["extra"]["cost"] == 0.01