Skip to content

Commit d026d1b

Browse files
authored
Merge pull request #1379 from Kiln-AI/leonard/kil-606-fix-usage-cost-summing-in-chat-multi-turn-conversations
feat: include usage in each message
2 parents 309d759 + aadb21c commit d026d1b

16 files changed

Lines changed: 1492 additions & 97 deletions

app/web_ui/src/lib/api_schema.d.ts

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3732,6 +3732,7 @@ export interface components {
37323732
tool_calls?: components["schemas"]["ChatCompletionMessageFunctionToolCallParam"][];
37333733
/** Latency Ms */
37343734
latency_ms?: number | null;
3735+
usage?: components["schemas"]["MessageUsage"] | null;
37353736
};
37363737
/**
37373738
* ChatCompletionAssistantMessageParamWrapper
@@ -3762,6 +3763,7 @@ export interface components {
37623763
tool_calls?: components["schemas"]["ChatCompletionMessageFunctionToolCallParam"][];
37633764
/** Latency Ms */
37643765
latency_ms?: number | null;
3766+
usage?: components["schemas"]["MessageUsage"] | null;
37653767
};
37663768
/** ChatCompletionContentPartImageParam */
37673769
ChatCompletionContentPartImageParam: {
@@ -7226,6 +7228,47 @@ export interface components {
72267228
*/
72277229
mean_total_llm_latency_ms?: number | null;
72287230
};
7231+
/**
7232+
* MessageUsage
7233+
* @description Token usage and cost for a single LLM call or a multi-message sum.
7234+
*
7235+
* Carries only the fields that are meaningfully aggregatable across
7236+
* messages: token counts and cost. Per-call latency lives on the
7237+
* individual message's ``latency_ms`` field; aggregating it across the
7238+
* full trace would mix latencies from different points in time, so
7239+
* ``MessageUsage`` does NOT carry ``total_llm_latency_ms``.
7240+
*
7241+
* The :class:`Usage` subclass adds ``total_llm_latency_ms`` for the
7242+
* in-flight per-run accumulator that tracks how long this run spent
7243+
* waiting on LLM calls.
7244+
*/
7245+
MessageUsage: {
7246+
/**
7247+
* Input Tokens
7248+
* @description The number of input tokens used.
7249+
*/
7250+
input_tokens?: number | null;
7251+
/**
7252+
* Output Tokens
7253+
* @description The number of output tokens used.
7254+
*/
7255+
output_tokens?: number | null;
7256+
/**
7257+
* Total Tokens
7258+
* @description The total number of tokens used.
7259+
*/
7260+
total_tokens?: number | null;
7261+
/**
7262+
* Cost
7263+
* @description The cost in US dollars, saved at runtime (prices can change over time).
7264+
*/
7265+
cost?: number | null;
7266+
/**
7267+
* Cached Tokens
7268+
* @description Number of tokens served from prompt cache. None if not reported.
7269+
*/
7270+
cached_tokens?: number | null;
7271+
};
72297272
/** ModelDetails */
72307273
ModelDetails: {
72317274
/** Id */
@@ -9599,6 +9642,8 @@ export interface components {
95999642
tags: string[];
96009643
/** @description Usage information for the task run. This includes the number of input tokens, output tokens, and total tokens used. */
96019644
usage?: components["schemas"]["Usage"] | null;
9645+
/** @description Sum of per-message token usage and cost across the entire trace, including any seeded prior trace. None on records created before this field existed. For a fresh (non-seeded) run, the token / cost fields equal those of `usage`. */
9646+
cumulative_usage?: components["schemas"]["MessageUsage"] | null;
96029647
/**
96039648
* Trace
96049649
* @description The trace of the task run in OpenAI format. This is the list of messages that were sent to/from the model.
@@ -9676,6 +9721,8 @@ export interface components {
96769721
tags: string[];
96779722
/** @description Usage information for the task run. This includes the number of input tokens, output tokens, and total tokens used. */
96789723
usage?: components["schemas"]["Usage"] | null;
9724+
/** @description Sum of per-message token usage and cost across the entire trace, including any seeded prior trace. None on records created before this field existed. For a fresh (non-seeded) run, the token / cost fields equal those of `usage`. */
9725+
cumulative_usage?: components["schemas"]["MessageUsage"] | null;
96799726
/**
96809727
* Trace
96819728
* @description The trace of the task run in OpenAI format. This is the list of messages that were sent to/from the model.
@@ -10184,27 +10231,33 @@ export interface components {
1018410231
};
1018510232
/**
1018610233
* Usage
10187-
* @description Token usage and cost information for a task run.
10234+
* @description Token usage, cost, and aggregate LLM latency for a per-run accumulator.
10235+
*
10236+
* Extends :class:`MessageUsage` with ``total_llm_latency_ms``, which is
10237+
* only meaningful while a single run is in flight (its model calls run
10238+
* sequentially in real time). For per-message records and full-trace
10239+
* sums use :class:`MessageUsage` — summing latency there would mix values
10240+
* from different points in time, so the field doesn't apply.
1018810241
*/
1018910242
Usage: {
1019010243
/**
1019110244
* Input Tokens
10192-
* @description The number of input tokens used in the task run.
10245+
* @description The number of input tokens used.
1019310246
*/
1019410247
input_tokens?: number | null;
1019510248
/**
1019610249
* Output Tokens
10197-
* @description The number of output tokens used in the task run.
10250+
* @description The number of output tokens used.
1019810251
*/
1019910252
output_tokens?: number | null;
1020010253
/**
1020110254
* Total Tokens
10202-
* @description The total number of tokens used in the task run.
10255+
* @description The total number of tokens used.
1020310256
*/
1020410257
total_tokens?: number | null;
1020510258
/**
1020610259
* Cost
10207-
* @description The cost of the task run in US dollars, saved at runtime (prices can change over time).
10260+
* @description The cost in US dollars, saved at runtime (prices can change over time).
1020810261
*/
1020910262
cost?: number | null;
1021010263
/**

libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ class StreamingCompletion:
3030
def __init__(self, *args: Any, **kwargs: Any) -> None:
3131
kwargs = dict(kwargs)
3232
kwargs.pop("stream", None)
33+
# LiteLLM's streaming responses don't include a usage block by
34+
# default — ``stream_options={"include_usage": True}`` is required
35+
# for the final assembled ModelResponse to carry token counts (and
36+
# downstream cost). Force it on; merge with caller-provided
37+
# ``stream_options`` without clobbering unrelated keys, but always
38+
# override ``include_usage`` since usage tracking is mandatory.
39+
caller_stream_options = kwargs.get("stream_options") or {}
40+
kwargs["stream_options"] = {
41+
**caller_stream_options,
42+
"include_usage": True,
43+
}
3344
self._args = args
3445
self._kwargs = kwargs
3546
self._response: Optional[Union[ModelResponse, TextCompletionResponse]] = None

libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,70 @@ async def test_empty_stream(self, mock_acompletion, mock_chunk_builder):
137137

138138
assert received == []
139139
assert stream.response is None
140+
141+
async def test_stream_options_include_usage_added_by_default(
142+
self, mock_acompletion, mock_chunk_builder
143+
):
144+
"""Default ``stream_options`` must request usage so the assembled response carries it."""
145+
mock_acompletion.return_value = _async_iter([])
146+
mock_chunk_builder.return_value = None
147+
148+
stream = StreamingCompletion(model="test", messages=[])
149+
async for _ in stream:
150+
pass
151+
152+
_, call_kwargs = mock_acompletion.call_args
153+
assert call_kwargs["stream_options"] == {"include_usage": True}
154+
155+
async def test_stream_options_include_usage_merged_with_caller_options(
156+
self, mock_acompletion, mock_chunk_builder
157+
):
158+
"""Caller-provided ``stream_options`` keys must survive; ``include_usage`` is forced on."""
159+
mock_acompletion.return_value = _async_iter([])
160+
mock_chunk_builder.return_value = None
161+
162+
stream = StreamingCompletion(
163+
model="test",
164+
messages=[],
165+
stream_options={"some_other_flag": True},
166+
)
167+
async for _ in stream:
168+
pass
169+
170+
_, call_kwargs = mock_acompletion.call_args
171+
assert call_kwargs["stream_options"] == {
172+
"some_other_flag": True,
173+
"include_usage": True,
174+
}
175+
176+
async def test_caller_provided_include_usage_false_is_overridden(
177+
self, mock_acompletion, mock_chunk_builder
178+
):
179+
"""Streaming usage tracking is mandatory; an explicit ``False`` is overridden."""
180+
mock_acompletion.return_value = _async_iter([])
181+
mock_chunk_builder.return_value = None
182+
183+
stream = StreamingCompletion(
184+
model="test",
185+
messages=[],
186+
stream_options={"include_usage": False},
187+
)
188+
async for _ in stream:
189+
pass
190+
191+
_, call_kwargs = mock_acompletion.call_args
192+
assert call_kwargs["stream_options"]["include_usage"] is True
193+
194+
async def test_stream_options_none_treated_as_empty(
195+
self, mock_acompletion, mock_chunk_builder
196+
):
197+
"""Passing ``stream_options=None`` must not crash — treated as empty."""
198+
mock_acompletion.return_value = _async_iter([])
199+
mock_chunk_builder.return_value = None
200+
201+
stream = StreamingCompletion(model="test", messages=[], stream_options=None)
202+
async for _ in stream:
203+
pass
204+
205+
_, call_kwargs = mock_acompletion.call_args
206+
assert call_kwargs["stream_options"] == {"include_usage": True}

libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
ToolCallEventType,
2020
)
2121
from kiln_ai.adapters.run_output import RunOutput
22-
from kiln_ai.datamodel import Usage
22+
from kiln_ai.datamodel import MessageUsage, Usage
2323

2424
if TYPE_CHECKING:
2525
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
@@ -63,7 +63,12 @@ def __init__(
6363
self._top_logprobs = top_logprobs
6464
self._result: AdapterStreamResult | None = None
6565
self._iterated = False
66+
# Per-LLM-call latency / usage, keyed by index in the messages list.
67+
# Mirrors the non-streaming adapter — we don't own the LiteLLM
68+
# message objects, so we accumulate side-channel state and attach
69+
# it during ``all_messages_to_trace`` at finalization.
6670
self._message_latency: dict[int, int] = {}
71+
self._message_usage: dict[int, MessageUsage] = {}
6772

6873
@property
6974
def result(self) -> AdapterStreamResult:
@@ -134,7 +139,7 @@ async def __aiter__(self) -> AsyncIterator[AdapterStreamEvent]:
134139
raise RuntimeError(f"assistant message is not a string: {prior_output}")
135140

136141
trace = self._adapter.all_messages_to_trace(
137-
self._messages, self._message_latency
142+
self._messages, self._message_latency, self._message_usage
138143
)
139144
self._result = AdapterStreamResult(
140145
run_output=RunOutput(
@@ -170,7 +175,8 @@ async def _stream_model_turn(
170175
call_latency_ms = int((time.monotonic() - start) * 1000)
171176

172177
response, response_choice = _validate_response(stream.response)
173-
usage += self._adapter.usage_from_response(response)
178+
call_usage = self._adapter.usage_from_response(response)
179+
usage += call_usage
174180
usage.total_llm_latency_ms = (
175181
usage.total_llm_latency_ms or 0
176182
) + call_latency_ms
@@ -184,6 +190,7 @@ async def _stream_model_turn(
184190

185191
self._messages.append(response_choice.message)
186192
self._message_latency[len(self._messages) - 1] = call_latency_ms
193+
self._message_usage[len(self._messages) - 1] = call_usage
187194

188195
if tool_calls and len(tool_calls) > 0:
189196
# Check for return_on_tool_call BEFORE processing

libs/core/kiln_ai/adapters/model_adapters/base_adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from kiln_ai.datamodel import (
3939
DataSource,
4040
DataSourceType,
41+
MessageUsage,
4142
Task,
4243
TaskOutput,
4344
TaskRun,
@@ -683,6 +684,7 @@ def generate_run(
683684
tags=self.base_adapter_config.default_tags or [],
684685
usage=usage,
685686
trace=trace,
687+
cumulative_usage=MessageUsage.from_trace(trace),
686688
)
687689

688690
def _properties_for_task_output(self) -> Dict[str, str | int | float]:

0 commit comments

Comments
 (0)