Skip to content

Commit bf9dfaf

Browse files
AbhiPrasadbhaveshklaviyoclaude
authored
fix(langchain): stop double-counting anthropic cache tokens in prompt totals (#510)
supercedes #504 see https://github.com/braintrustdata/braintrust-spec/blob/main/docs/features/prompt-cache.md --------- Co-authored-by: Bhavesh Pareek <bhavesh.pareek@klaviyo.com> Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
1 parent 5afb9e2 commit bf9dfaf

2 files changed

Lines changed: 18 additions & 25 deletions

File tree

py/src/braintrust/integrations/langchain/callbacks.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -617,24 +617,6 @@ def _get_model_name_from_response(response: LLMResult) -> str | None:
617617
return model_name
618618

619619

620-
def _cache_tokens_are_separate_from_input_tokens(input_token_details: dict[str, Any]) -> bool:
621-
# LangChain provider packages use different cache-token conventions:
622-
# - OpenAI-style responses report cache reads as a subset of input_tokens.
623-
# - Anthropic-style responses report cache reads/creation separately from input_tokens.
624-
#
625-
# Avoid provider-name checks here so any LangChain integration using the same
626-
# "separate cache tokens" schema gets normalized, while providers that only
627-
# expose cache_read as input-token detail do not get double-counted.
628-
return any(
629-
key in input_token_details
630-
for key in (
631-
"cache_creation",
632-
"ephemeral_5m_input_tokens",
633-
"ephemeral_1h_input_tokens",
634-
)
635-
)
636-
637-
638620
def _get_metrics_from_response(response: LLMResult):
639621
metrics = {}
640622

@@ -685,15 +667,16 @@ def _get_metrics_from_response(response: LLMResult):
685667
completion_tokens = metrics.get("completion_tokens")
686668
total_tokens = metrics.get("total_tokens")
687669
if prompt_tokens is not None and completion_tokens is not None:
688-
if (
689-
cache_tokens
690-
and total_tokens == prompt_tokens + completion_tokens
691-
and _cache_tokens_are_separate_from_input_tokens(input_token_details)
692-
):
670+
# LangChain's UsageMetadata contract makes input_token_details a
671+
# breakdown of input_tokens, so cache tokens already count toward
672+
# the prompt total (langchain-anthropic >= 0.2.3, langchain-aws,
673+
# langchain-openai all comply). Cache tokens exceeding the prompt
674+
# total means the integration reported uncached input only — fold
675+
# cache tokens back in so prompt/total stay internally consistent.
676+
if cache_tokens > prompt_tokens and total_tokens == prompt_tokens + completion_tokens:
693677
prompt_tokens += cache_tokens
694678
metrics["prompt_tokens"] = prompt_tokens
695-
if total_tokens is not None:
696-
metrics["total_tokens"] = total_tokens + cache_tokens
679+
metrics["total_tokens"] = total_tokens + cache_tokens
697680
metrics["tokens"] = prompt_tokens + completion_tokens
698681

699682
if not metrics or not any(metrics.values()):

py/src/braintrust/integrations/langchain/test_callbacks.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,12 @@ def test_prompt_caching_tokens(logger_memory_logger):
10981098
assert first_metrics["prompt_tokens"] >= first_cache_creation_tokens
10991099
assert first_metrics["tokens"] == first_metrics["prompt_tokens"] + first_metrics["completion_tokens"]
11001100

1101+
# langchain-anthropic already folds cache read/creation tokens into
1102+
# usage_metadata input_tokens; the callback must not add them again.
1103+
assert res.usage_metadata is not None
1104+
assert first_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"]
1105+
assert first_metrics["total_tokens"] == res.usage_metadata["total_tokens"]
1106+
11011107
second_metrics = None
11021108
for attempt in range(3):
11031109
res = model.invoke(
@@ -1134,6 +1140,10 @@ def test_prompt_caching_tokens(logger_memory_logger):
11341140
assert second_metrics["prompt_tokens"] >= second_metrics["prompt_cached_tokens"]
11351141
assert second_metrics["tokens"] == second_metrics["prompt_tokens"] + second_metrics["completion_tokens"]
11361142

1143+
assert res.usage_metadata is not None
1144+
assert second_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"]
1145+
assert second_metrics["total_tokens"] == res.usage_metadata["total_tokens"]
1146+
11371147

11381148
@pytest.mark.vcr
11391149
def test_image_input(logger_memory_logger):

0 commit comments

Comments
 (0)