Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions app/web_ui/src/lib/api_schema.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3732,6 +3732,8 @@ export interface components {
tool_calls?: components["schemas"]["ChatCompletionMessageFunctionToolCallParam"][];
/** Latency Ms */
latency_ms?: number | null;
/** Usage */
usage?: components["schemas"]["Usage"] | null;
};
/**
* ChatCompletionAssistantMessageParamWrapper
Expand Down Expand Up @@ -3762,6 +3764,8 @@ export interface components {
tool_calls?: components["schemas"]["ChatCompletionMessageFunctionToolCallParam"][];
/** Latency Ms */
latency_ms?: number | null;
/** Usage */
usage?: components["schemas"]["Usage"] | null;
};
/** ChatCompletionContentPartImageParam */
ChatCompletionContentPartImageParam: {
Expand Down
17 changes: 11 additions & 6 deletions libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from kiln_ai.adapters.run_output import RunOutput
from kiln_ai.datamodel import Usage
from kiln_ai.datamodel.usage import record_per_call_usage_and_latency

if TYPE_CHECKING:
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
self._result: AdapterStreamResult | None = None
self._iterated = False
self._message_latency: dict[int, int] = {}
self._message_usage: dict[int, Usage] = {}

@property
def result(self) -> AdapterStreamResult:
Expand Down Expand Up @@ -134,7 +136,7 @@ async def __aiter__(self) -> AsyncIterator[AdapterStreamEvent]:
raise RuntimeError(f"assistant message is not a string: {prior_output}")

trace = self._adapter.all_messages_to_trace(
self._messages, self._message_latency
self._messages, self._message_latency, self._message_usage
)
self._result = AdapterStreamResult(
run_output=RunOutput(
Expand Down Expand Up @@ -170,10 +172,6 @@ async def _stream_model_turn(
call_latency_ms = int((time.monotonic() - start) * 1000)

response, response_choice = _validate_response(stream.response)
usage += self._adapter.usage_from_response(response)
usage.total_llm_latency_ms = (
usage.total_llm_latency_ms or 0
) + call_latency_ms

content = response_choice.message.content
tool_calls = response_choice.message.tool_calls
Expand All @@ -183,7 +181,14 @@ async def _stream_model_turn(
)

self._messages.append(response_choice.message)
self._message_latency[len(self._messages) - 1] = call_latency_ms
usage = record_per_call_usage_and_latency(
call_usage=self._adapter.usage_from_response(response),
call_latency_ms=call_latency_ms,
turn_usage=usage,
message_index=len(self._messages) - 1,
message_latency=self._message_latency,
message_usage=self._message_usage,
)

if tool_calls and len(tool_calls) > 0:
# Check for return_on_tool_call BEFORE processing
Expand Down
51 changes: 40 additions & 11 deletions libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
KilnAgentRunConfigProperties,
as_kiln_agent_run_config,
)
from kiln_ai.datamodel.usage import record_per_call_usage_and_latency
from kiln_ai.tools.base_tool import (
KilnToolInterface,
ToolCallContext,
Expand Down Expand Up @@ -82,6 +83,14 @@ class ModelTurnResult:
usage: Usage
interrupted_by_tool_calls: list[ChatCompletionMessageToolCall] | None = None
message_latency: dict[int, int] | None = None
message_usage: dict[int, Usage] | None = None
"""Per-assistant-message token usage, keyed by index in ``all_messages``.

Threaded the same way as ``message_latency`` so traces can carry the
usage of every individual inference call — including inner tool-loop
iterations within a single turn that get aggregated into ``usage``
above.
"""


class LiteLlmAdapter(BaseAdapter):
Expand Down Expand Up @@ -126,9 +135,10 @@ async def _run_model_turn(
usage = Usage()
messages = list(prior_messages)
tool_calls_count = 0
# LLM call latency in ms, keyed by index in the messages list.
# LLM call latency in ms + usage, keyed by index in the messages list.
# Kept separate because we don't own the LiteLLM message objects.
message_latency: dict[int, int] = {}
message_usage: dict[int, Usage] = {}

while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
# Build completion kwargs for tool calls
Expand All @@ -147,12 +157,6 @@ async def _run_model_turn(
)
call_latency_ms = int((time.monotonic() - start) * 1000)

# count the usage
usage += self.usage_from_response(model_response)
usage.total_llm_latency_ms = (
usage.total_llm_latency_ms or 0
) + call_latency_ms

# Extract content and tool calls
if not hasattr(response_choice, "message"):
raise ValueError("Response choice has no message")
Expand All @@ -165,7 +169,16 @@ async def _run_model_turn(

# Add message to messages, so it can be used in the next turn
messages.append(response_choice.message)
message_latency[len(messages) - 1] = call_latency_ms
# Aggregate per-call usage + latency onto the turn total and
# stamp them onto the per-message dicts for the trace.
usage = record_per_call_usage_and_latency(
call_usage=self.usage_from_response(model_response),
call_latency_ms=call_latency_ms,
turn_usage=usage,
message_index=len(messages) - 1,
message_latency=message_latency,
message_usage=message_usage,
)

# Process tool calls if any
if tool_calls and len(tool_calls) > 0:
Expand All @@ -188,6 +201,7 @@ async def _run_model_turn(
usage=usage,
interrupted_by_tool_calls=standard_tool_calls,
message_latency=message_latency,
message_usage=message_usage,
)

# otherwise: process tool calls internally until final output
Expand All @@ -208,6 +222,7 @@ async def _run_model_turn(
model_choice=response_choice,
usage=usage,
message_latency=message_latency,
message_usage=message_usage,
)

# If there were tool calls, increment counter and continue
Expand All @@ -224,6 +239,7 @@ async def _run_model_turn(
model_choice=response_choice,
usage=usage,
message_latency=message_latency,
message_usage=message_usage,
)

# If we get here with no content and no tool calls, break
Expand Down Expand Up @@ -256,6 +272,7 @@ async def _run(
final_choice: Choices | None = None
turns = 0
message_latency: dict[int, int] = {}
message_usage: dict[int, Usage] = {}

# Same loop for both fresh runs and prior_trace continuation.
# _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues).
Expand Down Expand Up @@ -288,14 +305,18 @@ async def _run(
usage += turn_result.usage
if turn_result.message_latency:
message_latency.update(turn_result.message_latency)
if turn_result.message_usage:
message_usage.update(turn_result.message_usage)

prior_output = turn_result.assistant_message
messages = turn_result.all_messages
final_choice = turn_result.model_choice

# Check if we were interrupted by tool calls
if turn_result.interrupted_by_tool_calls:
trace = self.all_messages_to_trace(messages, message_latency)
trace = self.all_messages_to_trace(
messages, message_latency, message_usage
)
intermediate_outputs = chat_formatter.intermediate_outputs()
output = RunOutput(
output=prior_output or "",
Expand All @@ -319,7 +340,7 @@ async def _run(
if not isinstance(prior_output, str):
raise RuntimeError(f"assistant message is not a string: {prior_output}")

trace = self.all_messages_to_trace(messages, message_latency)
trace = self.all_messages_to_trace(messages, message_latency, message_usage)
output = RunOutput(
output=prior_output,
intermediate_outputs=intermediate_outputs,
Expand Down Expand Up @@ -878,6 +899,7 @@ def litellm_message_to_trace_message(
self,
raw_message: LiteLLMMessage,
latency_ms: int | None = None,
usage: Usage | None = None,
) -> ChatCompletionAssistantMessageParamWrapper:
"""
Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
Expand Down Expand Up @@ -919,6 +941,9 @@ def litellm_message_to_trace_message(
if latency_ms is not None:
message["latency_ms"] = latency_ms

if usage is not None:
message["usage"] = usage

if not message.get("content") and not message.get("tool_calls"):
raise ValueError(
"Model returned an assistant message, but no content or tool calls. This is not supported."
Expand All @@ -930,6 +955,7 @@ def all_messages_to_trace(
self,
messages: list[ChatCompletionMessageIncludingLiteLLM],
message_latency: dict[int, int] | None = None,
message_usage: dict[int, Usage] | None = None,
) -> list[ChatCompletionMessageParam]:
"""
Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
Expand All @@ -938,7 +964,10 @@ def all_messages_to_trace(
for i, message in enumerate(messages):
if isinstance(message, LiteLLMMessage):
latency_ms = message_latency.get(i) if message_latency else None
trace.append(self.litellm_message_to_trace_message(message, latency_ms))
usage = message_usage.get(i) if message_usage else None
trace.append(
self.litellm_message_to_trace_message(message, latency_ms, usage)
)
else:
trace.append(message)
return trace
Loading
Loading