diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index 6b9170c72..90be28323 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -3732,6 +3732,8 @@ export interface components { tool_calls?: components["schemas"]["ChatCompletionMessageFunctionToolCallParam"][]; /** Latency Ms */ latency_ms?: number | null; + /** Usage */ + usage?: components["schemas"]["Usage"] | null; }; /** * ChatCompletionAssistantMessageParamWrapper @@ -3762,6 +3764,8 @@ export interface components { tool_calls?: components["schemas"]["ChatCompletionMessageFunctionToolCallParam"][]; /** Latency Ms */ latency_ms?: number | null; + /** Usage */ + usage?: components["schemas"]["Usage"] | null; }; /** ChatCompletionContentPartImageParam */ ChatCompletionContentPartImageParam: { diff --git a/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py index 5e89eacc7..2e987d20c 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py +++ b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py @@ -20,6 +20,7 @@ ) from kiln_ai.adapters.run_output import RunOutput from kiln_ai.datamodel import Usage +from kiln_ai.datamodel.usage import record_per_call_usage_and_latency if TYPE_CHECKING: from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter @@ -64,6 +65,7 @@ def __init__( self._result: AdapterStreamResult | None = None self._iterated = False self._message_latency: dict[int, int] = {} + self._message_usage: dict[int, Usage] = {} @property def result(self) -> AdapterStreamResult: @@ -134,7 +136,7 @@ async def __aiter__(self) -> AsyncIterator[AdapterStreamEvent]: raise RuntimeError(f"assistant message is not a string: {prior_output}") trace = self._adapter.all_messages_to_trace( - self._messages, self._message_latency + self._messages, self._message_latency, self._message_usage ) self._result = AdapterStreamResult( run_output=RunOutput( @@ -170,10 +172,6 @@ async def _stream_model_turn( call_latency_ms = int((time.monotonic() - start) * 1000) response, response_choice = _validate_response(stream.response) - usage += self._adapter.usage_from_response(response) - usage.total_llm_latency_ms = ( - usage.total_llm_latency_ms or 0 - ) + call_latency_ms content = response_choice.message.content tool_calls = response_choice.message.tool_calls @@ -183,7 +181,14 @@ async def _stream_model_turn( ) self._messages.append(response_choice.message) - self._message_latency[len(self._messages) - 1] = call_latency_ms + usage = record_per_call_usage_and_latency( + call_usage=self._adapter.usage_from_response(response), + call_latency_ms=call_latency_ms, + turn_usage=usage, + message_index=len(self._messages) - 1, + message_latency=self._message_latency, + message_usage=self._message_usage, + ) if tool_calls and len(tool_calls) > 0: # Check for return_on_tool_call BEFORE processing diff --git a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py index d51cafd14..d3fe6ceb9 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py @@ -44,6 +44,7 @@ KilnAgentRunConfigProperties, as_kiln_agent_run_config, ) +from kiln_ai.datamodel.usage import record_per_call_usage_and_latency from kiln_ai.tools.base_tool import ( KilnToolInterface, ToolCallContext, @@ -82,6 +83,14 @@ class ModelTurnResult: usage: Usage interrupted_by_tool_calls: list[ChatCompletionMessageToolCall] | None = None message_latency: dict[int, int] | None = None + message_usage: dict[int, Usage] | None = None + """Per-assistant-message token usage, keyed by index in ``all_messages``. + + Threaded the same way as ``message_latency`` so traces can carry the + usage of every individual inference call — including inner tool-loop + iterations within a single turn that get aggregated into ``usage`` + above. + """ class LiteLlmAdapter(BaseAdapter): @@ -126,9 +135,10 @@ async def _run_model_turn( usage = Usage() messages = list(prior_messages) tool_calls_count = 0 - # LLM call latency in ms, keyed by index in the messages list. + # LLM call latency in ms + usage, keyed by index in the messages list. # Kept separate because we don't own the LiteLLM message objects. message_latency: dict[int, int] = {} + message_usage: dict[int, Usage] = {} while tool_calls_count < MAX_TOOL_CALLS_PER_TURN: # Build completion kwargs for tool calls @@ -147,12 +157,6 @@ async def _run_model_turn( ) call_latency_ms = int((time.monotonic() - start) * 1000) - # count the usage - usage += self.usage_from_response(model_response) - usage.total_llm_latency_ms = ( - usage.total_llm_latency_ms or 0 - ) + call_latency_ms - # Extract content and tool calls if not hasattr(response_choice, "message"): raise ValueError("Response choice has no message") @@ -165,7 +169,16 @@ async def _run_model_turn( # Add message to messages, so it can be used in the next turn messages.append(response_choice.message) - message_latency[len(messages) - 1] = call_latency_ms + # Aggregate per-call usage + latency onto the turn total and + # stamp them onto the per-message dicts for the trace. + usage = record_per_call_usage_and_latency( + call_usage=self.usage_from_response(model_response), + call_latency_ms=call_latency_ms, + turn_usage=usage, + message_index=len(messages) - 1, + message_latency=message_latency, + message_usage=message_usage, + ) # Process tool calls if any if tool_calls and len(tool_calls) > 0: @@ -188,6 +201,7 @@ async def _run_model_turn( usage=usage, interrupted_by_tool_calls=standard_tool_calls, message_latency=message_latency, + message_usage=message_usage, ) # otherwise: process tool calls internally until final output @@ -208,6 +222,7 @@ async def _run_model_turn( model_choice=response_choice, usage=usage, message_latency=message_latency, + message_usage=message_usage, ) # If there were tool calls, increment counter and continue @@ -224,6 +239,7 @@ async def _run_model_turn( model_choice=response_choice, usage=usage, message_latency=message_latency, + message_usage=message_usage, ) # If we get here with no content and no tool calls, break @@ -256,6 +272,7 @@ async def _run( final_choice: Choices | None = None turns = 0 message_latency: dict[int, int] = {} + message_usage: dict[int, Usage] = {} # Same loop for both fresh runs and prior_trace continuation. # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues). @@ -288,6 +305,8 @@ async def _run( usage += turn_result.usage if turn_result.message_latency: message_latency.update(turn_result.message_latency) + if turn_result.message_usage: + message_usage.update(turn_result.message_usage) prior_output = turn_result.assistant_message messages = turn_result.all_messages @@ -295,7 +314,9 @@ async def _run( # Check if we were interrupted by tool calls if turn_result.interrupted_by_tool_calls: - trace = self.all_messages_to_trace(messages, message_latency) + trace = self.all_messages_to_trace( + messages, message_latency, message_usage + ) intermediate_outputs = chat_formatter.intermediate_outputs() output = RunOutput( output=prior_output or "", @@ -319,7 +340,7 @@ async def _run( if not isinstance(prior_output, str): raise RuntimeError(f"assistant message is not a string: {prior_output}") - trace = self.all_messages_to_trace(messages, message_latency) + trace = self.all_messages_to_trace(messages, message_latency, message_usage) output = RunOutput( output=prior_output, intermediate_outputs=intermediate_outputs, @@ -878,6 +899,7 @@ def litellm_message_to_trace_message( self, raw_message: LiteLLMMessage, latency_ms: int | None = None, + usage: Usage | None = None, ) -> ChatCompletionAssistantMessageParamWrapper: """ Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper @@ -919,6 +941,9 @@ def litellm_message_to_trace_message( if latency_ms is not None: message["latency_ms"] = latency_ms + if usage is not None: + message["usage"] = usage + if not message.get("content") and not message.get("tool_calls"): raise ValueError( "Model returned an assistant message, but no content or tool calls. This is not supported." @@ -930,6 +955,7 @@ def all_messages_to_trace( self, messages: list[ChatCompletionMessageIncludingLiteLLM], message_latency: dict[int, int] | None = None, + message_usage: dict[int, Usage] | None = None, ) -> list[ChatCompletionMessageParam]: """ Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. @@ -938,7 +964,10 @@ def all_messages_to_trace( for i, message in enumerate(messages): if isinstance(message, LiteLLMMessage): latency_ms = message_latency.get(i) if message_latency else None - trace.append(self.litellm_message_to_trace_message(message, latency_ms)) + usage = message_usage.get(i) if message_usage else None + trace.append( + self.litellm_message_to_trace_message(message, latency_ms, usage) + ) else: trace.append(message) return trace diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index a45080c19..3139b2aa3 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -2823,3 +2823,198 @@ def test_litellm_message_to_trace_message_no_latency(self, adapter): trace_msg = adapter.litellm_message_to_trace_message(msg) assert "latency_ms" not in trace_msg + + def test_litellm_message_to_trace_message_includes_usage(self, adapter): + """litellm_message_to_trace_message attaches per-call usage when provided.""" + from litellm.types.utils import Message as LiteLLMMessage + + msg = LiteLLMMessage(role="assistant", content="Hello") + per_call_usage = Usage(input_tokens=42, output_tokens=7, total_tokens=49) + + trace_msg = adapter.litellm_message_to_trace_message( + msg, latency_ms=99, usage=per_call_usage + ) + assert trace_msg["latency_ms"] == 99 + assert trace_msg["usage"] is per_call_usage + assert trace_msg["usage"].input_tokens == 42 + assert trace_msg["usage"].output_tokens == 7 + + def test_litellm_message_to_trace_message_no_usage(self, adapter): + """litellm_message_to_trace_message omits usage when not provided.""" + from litellm.types.utils import Message as LiteLLMMessage + + msg = LiteLLMMessage(role="assistant", content="Hello") + + trace_msg = adapter.litellm_message_to_trace_message(msg) + assert "usage" not in trace_msg + + @pytest.mark.asyncio + async def test_run_model_turn_records_per_call_usage_for_each_tool_loop_inference( + self, adapter, provider + ): + """Inner tool-loop inferences each get their own per-message usage entry. + + Locks the fix for kintsugi's chain-summing token undercount: when a + single ``call_model`` invocation runs N inferences (model → tool → model + → tool → model), the saved ``task_run.usage`` only carries the LAST + inference's tokens. Per-message usage on every assistant trace event + lets downstream consumers sum the actual provider-billed totals. + """ + # Two LLM responses with distinct usage shapes, so we can tell them apart. + tool_call_response = ModelResponse( + model="test-model", + choices=[ + { + "message": { + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "some_tool", + "arguments": '{"arg": "val"}', + }, + } + ], + } + } + ], + usage={"prompt_tokens": 100, "completion_tokens": 11, "total_tokens": 111}, + ) + final_response = ModelResponse( + model="test-model", + choices=[{"message": {"content": "Final answer"}}], + usage={"prompt_tokens": 200, "completion_tokens": 22, "total_tokens": 222}, + ) + + monotonic_values = [0.0, 0.05, 0.05, 0.20] # 50ms then 150ms + with patch.object(adapter, "build_completion_kwargs", return_value={}): + with patch.object( + adapter, + "acompletion_checking_response", + side_effect=[ + (tool_call_response, tool_call_response.choices[0]), + (final_response, final_response.choices[0]), + ], + ): + with patch.object( + adapter, + "process_tool_calls", + return_value=( + None, + [ + { + "role": "tool", + "content": "tool result", + "tool_call_id": "call_1", + } + ], + ), + ): + with patch( + "kiln_ai.adapters.model_adapters.litellm_adapter.time.monotonic", + side_effect=monotonic_values, + ): + result = await adapter._run_model_turn( + provider, + [{"role": "user", "content": "Hi"}], + None, + False, + ) + + # message_usage carries one entry per assistant inference, keyed by + # the message's index in all_messages. + assert result.message_usage is not None + assert len(result.message_usage) == 2 + + # Identify the two assistant message indices. + asst_indices = [ + i + for i, m in enumerate(result.all_messages) + if (isinstance(m, dict) and m.get("role") == "assistant") + or getattr(m, "role", None) == "assistant" + ] + assert len(asst_indices) == 2 + + first = result.message_usage[asst_indices[0]] + second = result.message_usage[asst_indices[1]] + + # Per-call usage matches the per-call ModelResponse, NOT the summed total. + assert first.input_tokens == 100 + assert first.output_tokens == 11 + assert first.total_llm_latency_ms == 50 + assert second.input_tokens == 200 + assert second.output_tokens == 22 + assert second.total_llm_latency_ms == 150 + + # Sanity: turn-total usage IS still the sum (existing contract). + assert result.usage.input_tokens == 300 + assert result.usage.output_tokens == 33 + + @pytest.mark.asyncio + async def test_all_messages_to_trace_attaches_per_message_usage( + self, adapter, provider + ): + """End-to-end: per-message usage flows from the inference loop onto the + assistant trace messages.""" + tool_call_response = ModelResponse( + model="test-model", + choices=[ + { + "message": { + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "some_tool", + "arguments": "{}", + }, + } + ], + } + } + ], + usage={"prompt_tokens": 50, "completion_tokens": 5, "total_tokens": 55}, + ) + final_response = ModelResponse( + model="test-model", + choices=[{"message": {"content": "Done"}}], + usage={"prompt_tokens": 75, "completion_tokens": 8, "total_tokens": 83}, + ) + with patch.object(adapter, "build_completion_kwargs", return_value={}): + with patch.object( + adapter, + "acompletion_checking_response", + side_effect=[ + (tool_call_response, tool_call_response.choices[0]), + (final_response, final_response.choices[0]), + ], + ): + with patch.object( + adapter, + "process_tool_calls", + return_value=( + None, + [ + { + "role": "tool", + "content": "ok", + "tool_call_id": "call_1", + } + ], + ), + ): + result = await adapter._run_model_turn( + provider, [{"role": "user", "content": "Hi"}], None, False + ) + + trace = adapter.all_messages_to_trace( + result.all_messages, result.message_latency, result.message_usage + ) + asst_msgs = [m for m in trace if m.get("role") == "assistant"] + assert len(asst_msgs) == 2 + assert asst_msgs[0]["usage"].input_tokens == 50 + assert asst_msgs[1]["usage"].input_tokens == 75 diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index 1d62c9fa9..b59062191 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -1,7 +1,7 @@ import json from typing import TYPE_CHECKING, Dict, List, Union -from pydantic import BaseModel, Field, ValidationInfo, model_validator +from pydantic import Field, ValidationInfo, model_validator from typing_extensions import Self from kiln_ai.datamodel.basemodel import KilnParentedModel, KilnParentModel @@ -9,6 +9,7 @@ from kiln_ai.datamodel.json_schema import validate_schema_with_value_error from kiln_ai.datamodel.strict_mode import strict_mode from kiln_ai.datamodel.task_output import DataSource, TaskOutput +from kiln_ai.datamodel.usage import Usage from kiln_ai.utils.open_ai_types import ( ChatCompletionMessageParam, trace_has_pending_client_tool_calls, @@ -18,79 +19,12 @@ from kiln_ai.datamodel.task import Task -class Usage(BaseModel): - """Token usage and cost information for a task run.""" - - input_tokens: int | None = Field( - default=None, - description="The number of input tokens used in the task run.", - ge=0, - ) - output_tokens: int | None = Field( - default=None, - description="The number of output tokens used in the task run.", - ge=0, - ) - total_tokens: int | None = Field( - default=None, - description="The total number of tokens used in the task run.", - ge=0, - ) - cost: float | None = Field( - default=None, - description="The cost of the task run in US dollars, saved at runtime (prices can change over time).", - ge=0, - ) - cached_tokens: int | None = Field( - default=None, - description="Number of tokens served from prompt cache. None if not reported.", - ge=0, - ) - total_llm_latency_ms: int | None = Field( - default=None, - description="Total time spent waiting on LLM API calls in milliseconds. Sum of per-call latencies, excludes tool execution time.", - ge=0, - ) - - def __add__(self, other: "Usage") -> "Usage": - """Add two Usage objects together, handling None values gracefully. - - None + None = None - None + value = value - value + None = value - value1 + value2 = value1 + value2 - """ - if not isinstance(other, Usage): - raise TypeError(f"Cannot add Usage with {type(other).__name__}") - - def _add_optional_int(a: int | None, b: int | None) -> int | None: - if a is None and b is None: - return None - if a is None: - return b - if b is None: - return a - return a + b - - def _add_optional_float(a: float | None, b: float | None) -> float | None: - if a is None and b is None: - return None - if a is None: - return b - if b is None: - return a - return a + b - - return Usage( - input_tokens=_add_optional_int(self.input_tokens, other.input_tokens), - output_tokens=_add_optional_int(self.output_tokens, other.output_tokens), - total_tokens=_add_optional_int(self.total_tokens, other.total_tokens), - cost=_add_optional_float(self.cost, other.cost), - cached_tokens=_add_optional_int(self.cached_tokens, other.cached_tokens), - total_llm_latency_ms=_add_optional_int( - self.total_llm_latency_ms, other.total_llm_latency_ms - ), - ) +# ``Usage`` is defined in ``kiln_ai.datamodel.usage`` (its own module so it +# can be imported from ``kiln_ai.utils.open_ai_types`` without cycles — +# ``open_ai_types`` annotates the per-message ``usage`` field on assistant +# trace entries with this type). Re-exported here for backwards compat +# with callers that import ``from kiln_ai.datamodel.task_run import Usage``. +__all__ = ["TaskRun", "Usage"] class TaskRun( diff --git a/libs/core/kiln_ai/datamodel/usage.py b/libs/core/kiln_ai/datamodel/usage.py new file mode 100644 index 000000000..7d24a0037 --- /dev/null +++ b/libs/core/kiln_ai/datamodel/usage.py @@ -0,0 +1,113 @@ +"""Token usage / cost / latency model. + +Lives in its own module so ``kiln_ai.utils.open_ai_types`` can import it +for the per-message ``usage`` field on assistant trace messages without +creating a circular dependency on ``kiln_ai.datamodel.task_run`` (which +itself imports from ``open_ai_types``). +""" + +from pydantic import BaseModel, Field + + +def record_per_call_usage_and_latency( + call_usage: "Usage", + call_latency_ms: int, + turn_usage: "Usage", + message_index: int, + message_latency: dict[int, int], + message_usage: dict[int, "Usage"], +) -> "Usage": + """Aggregate one inference's usage + latency onto the turn total and + record the per-message entries on the trace. + + Returns the new turn-level ``Usage`` (since ``Usage + Usage`` produces + a fresh object — caller reassigns). Stamps ``total_llm_latency_ms`` + on the per-call ``Usage`` AFTER the aggregation so latency travels + with usage on the trace without double-counting it on the turn total. + + Shared by ``_run_model_turn`` (litellm_adapter.py) and + ``_stream_model_turn`` (adapter_stream.py). + """ + new_turn_usage = turn_usage + call_usage + new_turn_usage.total_llm_latency_ms = ( + new_turn_usage.total_llm_latency_ms or 0 + ) + call_latency_ms + call_usage.total_llm_latency_ms = call_latency_ms + message_latency[message_index] = call_latency_ms + message_usage[message_index] = call_usage + return new_turn_usage + + +class Usage(BaseModel): + """Token usage and cost information for a task run.""" + + input_tokens: int | None = Field( + default=None, + description="The number of input tokens used in the task run.", + ge=0, + ) + output_tokens: int | None = Field( + default=None, + description="The number of output tokens used in the task run.", + ge=0, + ) + total_tokens: int | None = Field( + default=None, + description="The total number of tokens used in the task run.", + ge=0, + ) + cost: float | None = Field( + default=None, + description="The cost of the task run in US dollars, saved at runtime (prices can change over time).", + ge=0, + ) + cached_tokens: int | None = Field( + default=None, + description="Number of tokens served from prompt cache. None if not reported.", + ge=0, + ) + total_llm_latency_ms: int | None = Field( + default=None, + description="Total time spent waiting on LLM API calls in milliseconds. Sum of per-call latencies, excludes tool execution time.", + ge=0, + ) + + def __add__(self, other: "Usage") -> "Usage": + """Add two Usage objects together, handling None values gracefully. + + None + None = None + None + value = value + value + None = value + value1 + value2 = value1 + value2 + """ + if not isinstance(other, Usage): + raise TypeError(f"Cannot add Usage with {type(other).__name__}") + + def _add_optional_int(a: int | None, b: int | None) -> int | None: + if a is None and b is None: + return None + if a is None: + return b + if b is None: + return a + return a + b + + def _add_optional_float(a: float | None, b: float | None) -> float | None: + if a is None and b is None: + return None + if a is None: + return b + if b is None: + return a + return a + b + + return Usage( + input_tokens=_add_optional_int(self.input_tokens, other.input_tokens), + output_tokens=_add_optional_int(self.output_tokens, other.output_tokens), + total_tokens=_add_optional_int(self.total_tokens, other.total_tokens), + cost=_add_optional_float(self.cost, other.cost), + cached_tokens=_add_optional_int(self.cached_tokens, other.cached_tokens), + total_llm_latency_ms=_add_optional_int( + self.total_llm_latency_ms, other.total_llm_latency_ms + ), + ) diff --git a/libs/core/kiln_ai/utils/open_ai_types.py b/libs/core/kiln_ai/utils/open_ai_types.py index 9eec6caac..05a36a486 100644 --- a/libs/core/kiln_ai/utils/open_ai_types.py +++ b/libs/core/kiln_ai/utils/open_ai_types.py @@ -8,6 +8,7 @@ """ from typing import ( + Annotated, Any, Iterable, List, @@ -30,8 +31,19 @@ ContentArrayOfContentPart, FunctionCall, ) +from pydantic import WithJsonSchema from typing_extensions import Required, TypedDict +# JSON-schema hint for the per-message ``usage`` field. We can't import +# ``Usage`` eagerly here (cycle through ``kiln_ai.datamodel`` which +# loads ``task_run`` which imports from this module). ``WithJsonSchema`` +# lets the OpenAPI generator emit a ``$ref`` to the same ``Usage`` +# component schema that ``TaskRun.usage`` registers, so the frontend +# stays typed even though the Python annotation here is ``Any``. +_UsageOpenApiHint = WithJsonSchema( + {"anyOf": [{"$ref": "#/components/schemas/Usage"}, {"type": "null"}]} +) + class ChatCompletionAssistantMessageParamWrapper(TypedDict, total=False): """ @@ -87,6 +99,29 @@ class ChatCompletionAssistantMessageParamWrapper(TypedDict, total=False): latency_ms: Optional[int] """Time spent waiting on this specific LLM API call in milliseconds.""" + usage: Annotated[Optional[Any], _UsageOpenApiHint] + """Token usage for this specific LLM API call. + + The runtime value is a ``kiln_ai.datamodel.usage.Usage`` instance (or + its dict serialization on a deserialized trace). Typed as ``Any`` to + avoid an import cycle: ``kiln_ai.datamodel`` eagerly loads ``task_run``, + which imports from this module, so any direct annotation of ``Usage`` + here would form a cycle that breaks Pydantic's schema build for + ``TaskRun.trace``. The ``WithJsonSchema`` hint above keeps the + generated OpenAPI / TypeScript schema as ``Usage | null`` despite + the relaxed Python type. + + Captured per-call (not summed) so downstream consumers can sum across + every assistant turn in the trace and recover provider-true totals — + even when multiple inferences happen inside a single + ``return_on_tool_call=False`` turn (the loop where the model calls a + tool, the adapter runs it internally, and the model is re-called + within the same ``call_model`` invocation). Without this field, only + the last inference's usage shows up on the saved snapshot's + ``task_run.usage``, and inner-loop inferences are billed by the + provider but invisible to trace consumers. + """ + class ChatCompletionToolMessageParamWrapper(TypedDict, total=False): content: Required[Union[str, Iterable[ChatCompletionContentPartTextParam]]] @@ -124,6 +159,7 @@ class ChatCompletionToolMessageParamWrapper(TypedDict, total=False): KILN_ONLY_MESSAGE_FIELDS: frozenset[str] = frozenset( { "latency_ms", + "usage", "is_error", "error_message", "kiln_task_tool_data", diff --git a/libs/core/kiln_ai/utils/test_open_ai_types.py b/libs/core/kiln_ai/utils/test_open_ai_types.py index e2189e63f..788810b8f 100644 --- a/libs/core/kiln_ai/utils/test_open_ai_types.py +++ b/libs/core/kiln_ai/utils/test_open_ai_types.py @@ -47,6 +47,10 @@ def test_assistant_message_param_properties_match(): assert "latency_ms" in kiln_properties, "Kiln should have latency_ms" kiln_properties.remove("latency_ms") + # usage is a Kiln-added property for per-call token accounting. Confirm it's there and remove it. + assert "usage" in kiln_properties, "Kiln should have usage" + kiln_properties.remove("usage") + assert openai_properties == kiln_properties, ( f"Property names don't match. " f"OpenAI has: {openai_properties}, " @@ -215,7 +219,7 @@ def test_tool_message_wrapper_can_be_instantiated(): def test_kiln_only_message_fields_set(): assert KILN_ONLY_MESSAGE_FIELDS == frozenset( - {"latency_ms", "is_error", "error_message", "kiln_task_tool_data"} + {"latency_ms", "usage", "is_error", "error_message", "kiln_task_tool_data"} ) @@ -227,6 +231,7 @@ def test_sanitize_messages_strips_kiln_only_fields(): "role": "assistant", "content": "hello", "latency_ms": 200, + "usage": {"input_tokens": 10, "output_tokens": 5}, }, { "role": "tool",