Skip to content

Commit 63f99a5

Browse files
fix(langchain): stop double-counting anthropic cache tokens in prompt totals
langchain-anthropic has folded cache read/creation tokens into usage_metadata input_tokens since 0.2.3 (versions before that don't emit input_token_details at all), and langchain-aws does the same — per the langchain-core UsageMetadata contract, input_token_details is a breakdown of input_tokens, not an addition to it. The cache normalization from #411/#445 detected "separate cache token accounting" by the presence of cache_creation/ephemeral_* detail keys, which langchain-anthropic always emits, so every cached Anthropic call had cache tokens added to prompt_tokens a second time. With a warm cache this roughly doubles reported prompt tokens (e.g. a real trace reported 75,387 prompt tokens for a 37,694-token request with 37,324 cache reads and 369 cache writes). Detect separate accounting arithmetically instead: only fold cache tokens into prompt/total when they exceed the reported prompt total, which is impossible under the UsageMetadata contract but is exactly the inconsistency the original normalization (BT-5150) was added to repair. Strengthen the VCR prompt-caching test to assert span prompt/total tokens equal the usage_metadata the model reported, and add unit coverage for the folded (Anthropic), subset (OpenAI), and separate (legacy) conventions. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
1 parent 82d86b1 commit 63f99a5

2 files changed

Lines changed: 108 additions & 25 deletions

File tree

py/src/braintrust/integrations/langchain/callbacks.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -617,24 +617,6 @@ def _get_model_name_from_response(response: LLMResult) -> str | None:
617617
return model_name
618618

619619

620-
def _cache_tokens_are_separate_from_input_tokens(input_token_details: dict[str, Any]) -> bool:
621-
# LangChain provider packages use different cache-token conventions:
622-
# - OpenAI-style responses report cache reads as a subset of input_tokens.
623-
# - Anthropic-style responses report cache reads/creation separately from input_tokens.
624-
#
625-
# Avoid provider-name checks here so any LangChain integration using the same
626-
# "separate cache tokens" schema gets normalized, while providers that only
627-
# expose cache_read as input-token detail do not get double-counted.
628-
return any(
629-
key in input_token_details
630-
for key in (
631-
"cache_creation",
632-
"ephemeral_5m_input_tokens",
633-
"ephemeral_1h_input_tokens",
634-
)
635-
)
636-
637-
638620
def _get_metrics_from_response(response: LLMResult):
639621
metrics = {}
640622

@@ -685,15 +667,16 @@ def _get_metrics_from_response(response: LLMResult):
685667
completion_tokens = metrics.get("completion_tokens")
686668
total_tokens = metrics.get("total_tokens")
687669
if prompt_tokens is not None and completion_tokens is not None:
688-
if (
689-
cache_tokens
690-
and total_tokens == prompt_tokens + completion_tokens
691-
and _cache_tokens_are_separate_from_input_tokens(input_token_details)
692-
):
670+
# LangChain's UsageMetadata contract makes input_token_details a
671+
# breakdown of input_tokens, so cache tokens already count toward
672+
# the prompt total (langchain-anthropic >= 0.2.3, langchain-aws,
673+
# langchain-openai all comply). Cache tokens exceeding the prompt
674+
# total means the integration reported uncached input only — fold
675+
# cache tokens back in so prompt/total stay internally consistent.
676+
if cache_tokens > prompt_tokens and total_tokens == prompt_tokens + completion_tokens:
693677
prompt_tokens += cache_tokens
694678
metrics["prompt_tokens"] = prompt_tokens
695-
if total_tokens is not None:
696-
metrics["total_tokens"] = total_tokens + cache_tokens
679+
metrics["total_tokens"] = total_tokens + cache_tokens
697680
metrics["tokens"] = prompt_tokens + completion_tokens
698681

699682
if not metrics or not any(metrics.values()):

py/src/braintrust/integrations/langchain/test_callbacks.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
import pytest
99
from braintrust import logger
1010
from braintrust.integrations.langchain import BraintrustCallbackHandler
11+
from braintrust.integrations.langchain.callbacks import _get_metrics_from_response
1112
from braintrust.logger import flush
1213
from braintrust.test_helpers import init_test_logger
1314
from langchain_core.callbacks import BaseCallbackHandler
1415
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
16+
from langchain_core.outputs import ChatGeneration, LLMResult
1517
from langchain_core.prompts import ChatPromptTemplate
1618
from langchain_core.prompts.prompt import PromptTemplate
1719
from langchain_core.runnables import RunnableMap, RunnableSerializable
@@ -906,6 +908,94 @@ def test_streaming_ttft(logger_memory_logger):
906908
)
907909

908910

911+
def _single_generation_response(usage_metadata: dict, model_name: str) -> LLMResult:
912+
return LLMResult(
913+
generations=[
914+
[
915+
ChatGeneration(
916+
message=AIMessage(
917+
content="Done",
918+
response_metadata={"model_name": model_name},
919+
usage_metadata=cast(dict, usage_metadata),
920+
)
921+
)
922+
]
923+
]
924+
)
925+
926+
927+
def test_folded_cache_tokens_are_not_double_counted():
928+
# langchain-anthropic >= 0.2.3 folds cache read/creation tokens into
929+
# input_tokens, exposing them via input_token_details as a breakdown.
930+
response = _single_generation_response(
931+
{
932+
"input_tokens": 1095,
933+
"output_tokens": 40,
934+
"total_tokens": 1135,
935+
"input_token_details": {
936+
"cache_read": 0,
937+
"cache_creation": 0,
938+
"ephemeral_5m_input_tokens": 1075,
939+
"ephemeral_1h_input_tokens": 0,
940+
},
941+
},
942+
model_name="claude-sonnet-4-5-20250929",
943+
)
944+
945+
assert _get_metrics_from_response(response) == {
946+
"prompt_tokens": 1095,
947+
"completion_tokens": 40,
948+
"total_tokens": 1135,
949+
"tokens": 1135,
950+
"prompt_cached_tokens": 0,
951+
"prompt_cache_creation_5m_tokens": 1075,
952+
"prompt_cache_creation_1h_tokens": 0,
953+
}
954+
955+
956+
def test_openai_cached_tokens_are_not_folded_into_prompt_tokens():
957+
response = _single_generation_response(
958+
{
959+
"input_tokens": 1000,
960+
"output_tokens": 200,
961+
"total_tokens": 1200,
962+
"input_token_details": {"cache_read": 500},
963+
},
964+
model_name="gpt-4o-mini-2024-07-18",
965+
)
966+
967+
assert _get_metrics_from_response(response) == {
968+
"prompt_tokens": 1000,
969+
"completion_tokens": 200,
970+
"total_tokens": 1200,
971+
"tokens": 1200,
972+
"prompt_cached_tokens": 500,
973+
}
974+
975+
976+
def test_separately_reported_cache_tokens_are_folded_into_prompt_tokens():
977+
# Integrations that report uncached input only make cache tokens exceed
978+
# the prompt total; normalize so prompt/total include cache tokens.
979+
response = _single_generation_response(
980+
{
981+
"input_tokens": 20,
982+
"output_tokens": 40,
983+
"total_tokens": 60,
984+
"input_token_details": {"cache_read": 1000, "cache_creation": 500},
985+
},
986+
model_name="claude-3-5-sonnet-20240620",
987+
)
988+
989+
assert _get_metrics_from_response(response) == {
990+
"prompt_tokens": 1520,
991+
"completion_tokens": 40,
992+
"total_tokens": 1560,
993+
"tokens": 1560,
994+
"prompt_cached_tokens": 1000,
995+
"prompt_cache_creation_tokens": 500,
996+
}
997+
998+
909999
@pytest.mark.vcr
9101000
def test_prompt_caching_tokens(logger_memory_logger):
9111001
from langchain_anthropic import ChatAnthropic
@@ -1098,6 +1188,12 @@ def test_prompt_caching_tokens(logger_memory_logger):
10981188
assert first_metrics["prompt_tokens"] >= first_cache_creation_tokens
10991189
assert first_metrics["tokens"] == first_metrics["prompt_tokens"] + first_metrics["completion_tokens"]
11001190

1191+
# langchain-anthropic already folds cache read/creation tokens into
1192+
# usage_metadata input_tokens; the callback must not add them again.
1193+
assert res.usage_metadata is not None
1194+
assert first_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"]
1195+
assert first_metrics["total_tokens"] == res.usage_metadata["total_tokens"]
1196+
11011197
second_metrics = None
11021198
for attempt in range(3):
11031199
res = model.invoke(
@@ -1134,6 +1230,10 @@ def test_prompt_caching_tokens(logger_memory_logger):
11341230
assert second_metrics["prompt_tokens"] >= second_metrics["prompt_cached_tokens"]
11351231
assert second_metrics["tokens"] == second_metrics["prompt_tokens"] + second_metrics["completion_tokens"]
11361232

1233+
assert res.usage_metadata is not None
1234+
assert second_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"]
1235+
assert second_metrics["total_tokens"] == res.usage_metadata["total_tokens"]
1236+
11371237

11381238
@pytest.mark.vcr
11391239
def test_image_input(logger_memory_logger):

0 commit comments

Comments
 (0)