6 changes: 6 additions & 0 deletions langfuse/callback/langchain.py
@@ -1096,12 +1096,18 @@ def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]):
for key, value in input_token_details.items():
    usage_model[f"input_{key}"] = value

    if "input" in usage_model:
        usage_model["input"] -= value

if "output_token_details" in usage_model:
output_token_details = usage_model.pop("output_token_details", {})

for key, value in output_token_details.items():
usage_model[f"output_{key}"] = value

if "output" in usage_model:
usage_model["output"] -= value

return usage_model if usage_model else None
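For context, a minimal sketch of what this normalization achieves, assuming an OpenAI-style usage payload; the dict shape and numbers are illustrative, not the exact LangChain schema:

```python
# Illustrative input: aggregate counts plus nested token details,
# where "input" still includes the cached tokens.
usage_model = {
    "input": 1500,
    "output": 20,
    "total": 1520,
    "input_token_details": {"cache_read": 1024},
}

input_token_details = usage_model.pop("input_token_details", {})
for key, value in input_token_details.items():
    usage_model[f"input_{key}"] = value
    # Subtract each detailed count so "input" holds only uncached tokens
    # and input + input_cache_read + output == total balances.
    if "input" in usage_model:
        usage_model["input"] -= value

assert usage_model == {
    "input": 476,  # 1500 - 1024
    "output": 20,
    "total": 1520,
    "input_cache_read": 1024,
}
```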


46 changes: 46 additions & 0 deletions tests/test_langchain.py
@@ -2318,3 +2318,49 @@ def call_model(state: MessagesState):
assert observation.level == "DEFAULT"

assert hidden_count > 0


def test_cached_token_usage():
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                (
                    "This is a test prompt to reproduce the issue. "
                    "The prompt needs 1024 tokens to enable cache." * 100
                ),
            ),
            ("user", "Reply to this message {test_param}."),
        ]
    )
    chat = ChatOpenAI(model="gpt-4o-mini")
    chain = prompt | chat
    handler = CallbackHandler()
    config = {"callbacks": [handler]}

chain.invoke({"test_param": "in a funny way"}, config)

# invoke again to force cached token usage
chain.invoke({"test_param": "in a funny way"}, config)

handler.flush()

trace = get_api().trace.get(handler.get_trace_id())

generation = next((o for o in trace.observations if o.type == "GENERATION"))
assert generation.usage_details["input_cache_read"] > 0
assert (
generation.usage_details["input"]
+ generation.usage_details["input_cache_read"]
+ generation.usage_details["output"]
== generation.usage_details["total"]
)

assert generation.cost_details["input_cache_read"] > 0
assert (
generation.cost_details["input"]
+ generation.cost_details["input_cache_read"]
+ generation.cost_details["output"]
== generation.cost_details["total"]
)
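
Taken together, the two assertion groups pin down the invariant the parser change establishes: after the subtraction, `input` counts only uncached prompt tokens, so `input + input_cache_read + output == total` holds for both the usage details and the cost details.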