diff --git a/sdks/python/src/opik/integrations/bedrock/invoke_model/chunks_aggregator/claude.py b/sdks/python/src/opik/integrations/bedrock/invoke_model/chunks_aggregator/claude.py index 72fbdeb6126..db0780470c4 100644 --- a/sdks/python/src/opik/integrations/bedrock/invoke_model/chunks_aggregator/claude.py +++ b/sdks/python/src/opik/integrations/bedrock/invoke_model/chunks_aggregator/claude.py @@ -36,6 +36,8 @@ def aggregate(self, items: List[Dict[str, Any]]) -> Dict[str, Any]: stop_reason = None input_tokens = 0 output_tokens = 0 + cache_creation_input_tokens = 0 + cache_read_input_tokens = 0 for item in items: if "chunk" not in item: @@ -51,6 +53,10 @@ def aggregate(self, items: List[Dict[str, Any]]) -> Dict[str, Any]: usage = message.get("usage", {}) input_tokens = usage.get("input_tokens", 0) output_tokens = usage.get("output_tokens", 0) + cache_creation_input_tokens = usage.get( + "cache_creation_input_tokens", 0 + ) + cache_read_input_tokens = usage.get("cache_read_input_tokens", 0) LOGGER.debug( "Claude message_start: input_tokens=%d, output_tokens=%d", input_tokens, @@ -109,7 +115,12 @@ def aggregate(self, items: List[Dict[str, Any]]) -> Dict[str, Any]: # Convert to Bedrock usage format using shared converter bedrock_usage = usage_converters.anthropic_to_bedrock_usage( - {"input_tokens": input_tokens, "output_tokens": output_tokens} + { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cache_creation_input_tokens": cache_creation_input_tokens, + "cache_read_input_tokens": cache_read_input_tokens, + } ) # Return Claude's native format with Bedrock usage diff --git a/sdks/python/tests/unit/integrations/__init__.py b/sdks/python/tests/unit/integrations/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdks/python/tests/unit/integrations/bedrock/__init__.py b/sdks/python/tests/unit/integrations/bedrock/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdks/python/tests/unit/integrations/bedrock/test_claude_aggregator.py b/sdks/python/tests/unit/integrations/bedrock/test_claude_aggregator.py new file mode 100644 index 00000000000..18a7bbf354c --- /dev/null +++ b/sdks/python/tests/unit/integrations/bedrock/test_claude_aggregator.py @@ -0,0 +1,86 @@ +import json + + +from opik.integrations.bedrock.invoke_model.chunks_aggregator.claude import ( + ClaudeAggregator, +) + + +def _make_chunk(data: dict) -> dict: + return {"chunk": {"bytes": json.dumps(data).encode()}} + + +def test_claude_aggregator__cache_tokens_in_message_start__included_in_usage(): + aggregator = ClaudeAggregator() + + chunks = [ + _make_chunk( + { + "type": "message_start", + "message": { + "role": "assistant", + "usage": { + "input_tokens": 100, + "output_tokens": 0, + "cache_creation_input_tokens": 50, + "cache_read_input_tokens": 200, + }, + }, + } + ), + _make_chunk( + { + "type": "content_block_delta", + "delta": {"type": "text_delta", "text": "Hello"}, + } + ), + _make_chunk({"type": "content_block_stop"}), + _make_chunk( + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn"}, + "usage": {"output_tokens": 5}, + } + ), + _make_chunk({"type": "message_stop", "amazon-bedrock-invocationMetrics": {}}), + ] + + result = aggregator.aggregate(chunks) + + assert result["usage"]["cacheWriteInputTokens"] == 50 + assert result["usage"]["cacheReadInputTokens"] == 200 + assert result["usage"]["inputTokens"] == 100 + assert result["usage"]["outputTokens"] == 5 + + +def test_claude_aggregator__no_cache_tokens__defaults_to_zero(): + aggregator = ClaudeAggregator() + + chunks = [ + _make_chunk( + { + "type": "message_start", + "message": { + "role": "assistant", + "usage": { + "input_tokens": 10, + "output_tokens": 0, + }, + }, + } + ), + _make_chunk({"type": "content_block_stop"}), + _make_chunk( + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn"}, + "usage": {"output_tokens": 3}, + } + ), + _make_chunk({"type": "message_stop", "amazon-bedrock-invocationMetrics": {}}), + ] + + result = aggregator.aggregate(chunks) + + assert result["usage"]["cacheWriteInputTokens"] == 0 + assert result["usage"]["cacheReadInputTokens"] == 0