Skip to content

Commit de28291

Browse files
authored
fix(langchain): avoid double-counting cached input tokens (#445)
Only fold cache tokens into the prompt and total metrics when LangChain reports cache tokens separately from input tokens. This preserves Anthropic-style cache normalization while avoiding double-counting for OpenAI-style responses, where cached tokens are already included in `input_tokens`. Adds regression coverage for OpenAI cached-token metrics. Resolves: https://linear.app/braintrustdata/issue/BT-5310/langchain-callback-double-counts-cache-tokens-for-openai-after-pr-411
1 parent 3651b0b commit de28291

2 files changed

Lines changed: 57 additions & 4 deletions

File tree

py/src/braintrust/integrations/langchain/callbacks.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,24 @@ def _get_model_name_from_response(response: LLMResult) -> str | None:
617617
return model_name
618618

619619

620+
def _cache_tokens_are_separate_from_input_tokens(input_token_details: dict[str, Any]) -> bool:
    """Tell whether cache tokens are reported apart from ``input_tokens``.

    LangChain provider packages follow two cache-token conventions:

    - OpenAI-style responses report cache reads as a subset of ``input_tokens``.
    - Anthropic-style responses report cache reads/creation separately from
      ``input_tokens``.

    The decision is based purely on which keys are present, not on the provider
    name, so any LangChain integration using the same "separate cache tokens"
    schema gets normalized, while providers that only expose ``cache_read`` as
    an input-token detail do not get double-counted.

    Args:
        input_token_details: The ``input_token_details`` mapping taken from a
            response's usage metadata.

    Returns:
        ``True`` if any cache-creation key is present in the mapping (whatever
        its value, including ``0``), ``False`` otherwise.
    """
    # Only presence matters: a key mapped to 0 still identifies the schema.
    separate_schema_markers = (
        "cache_creation",
        "ephemeral_5m_input_tokens",
        "ephemeral_1h_input_tokens",
    )
    for marker in separate_schema_markers:
        if marker in input_token_details:
            return True
    return False
636+
637+
620638
def _get_metrics_from_response(response: LLMResult):
621639
metrics = {}
622640

@@ -646,10 +664,14 @@ def _get_metrics_from_response(response: LLMResult):
646664
# langchain-anthropic >= 1.4.0 maps cache_creation_input_tokens to
647665
# ephemeral tier fields (ephemeral_5m_input_tokens, ephemeral_1h_input_tokens)
648666
# rather than the top-level cache_creation field. Sum both for compat.
649-
cache_creation = input_token_details.get("cache_creation") or (
650-
input_token_details.get("ephemeral_5m_input_tokens", 0)
651-
+ input_token_details.get("ephemeral_1h_input_tokens", 0)
652-
)
667+
cache_creation = input_token_details.get("cache_creation")
668+
if not cache_creation and (
669+
"ephemeral_5m_input_tokens" in input_token_details
670+
or "ephemeral_1h_input_tokens" in input_token_details
671+
):
672+
cache_creation = input_token_details.get("ephemeral_5m_input_tokens", 0) + input_token_details.get(
673+
"ephemeral_1h_input_tokens", 0
674+
)
653675

654676
if cache_read is not None:
655677
metrics["prompt_cached_tokens"] = cache_read
@@ -665,6 +687,7 @@ def _get_metrics_from_response(response: LLMResult):
665687
and prompt_tokens is not None
666688
and completion_tokens is not None
667689
and total_tokens == prompt_tokens + completion_tokens
690+
and _cache_tokens_are_separate_from_input_tokens(input_token_details)
668691
):
669692
metrics["prompt_tokens"] = prompt_tokens + cache_tokens
670693
metrics["total_tokens"] = total_tokens + cache_tokens

py/src/braintrust/integrations/langchain/test_callbacks.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
import pytest
99
from braintrust import logger
1010
from braintrust.integrations.langchain import BraintrustCallbackHandler
11+
from braintrust.integrations.langchain.callbacks import _get_metrics_from_response
1112
from braintrust.logger import flush
1213
from braintrust.test_helpers import init_test_logger
1314
from langchain_core.callbacks import BaseCallbackHandler
1415
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
16+
from langchain_core.outputs import ChatGeneration, LLMResult
1517
from langchain_core.prompts import ChatPromptTemplate
1618
from langchain_core.prompts.prompt import PromptTemplate
1719
from langchain_core.runnables import RunnableMap, RunnableSerializable
@@ -906,6 +908,34 @@ def test_streaming_ttft(logger_memory_logger):
906908
)
907909

908910

911+
def test_openai_cached_tokens_are_not_folded_into_prompt_tokens():
    """Usage metadata carrying only ``cache_read`` must keep its raw token totals."""
    # OpenAI-shaped usage: ``cache_read`` is the sole input-token detail, so the
    # 500 cached tokens must not be added on top of input_tokens/total_tokens.
    openai_usage = {
        "input_tokens": 1000,
        "output_tokens": 200,
        "total_tokens": 1200,
        "input_token_details": {"cache_read": 500},
    }
    reply = AIMessage(
        content="Done",
        response_metadata={"model_name": "gpt-4o-mini-2024-07-18"},
        usage_metadata=openai_usage,
    )
    llm_result = LLMResult(generations=[[ChatGeneration(message=reply)]])

    expected_metrics = {
        "prompt_tokens": 1000,
        "completion_tokens": 200,
        "total_tokens": 1200,
        "prompt_cached_tokens": 500,
    }
    observed_metrics = _get_metrics_from_response(llm_result)

    assert observed_metrics == expected_metrics
937+
938+
909939
@pytest.mark.vcr
910940
def test_prompt_caching_tokens(logger_memory_logger):
911941
from langchain_anthropic import ChatAnthropic

0 commit comments

Comments
 (0)