diff --git a/app/desktop/studio_server/eval_api.py b/app/desktop/studio_server/eval_api.py index e047caa55..4dc323296 100644 --- a/app/desktop/studio_server/eval_api.py +++ b/app/desktop/studio_server/eval_api.py @@ -236,6 +236,10 @@ class MeanUsage(BaseModel): mean_cost: float | None = Field( default=None, description="Average cost per run in USD." ) + mean_total_llm_latency_ms: float | None = Field( + default=None, + description="Average total LLM latency per run in milliseconds.", + ) class EvalRunResult(BaseModel): @@ -1422,10 +1426,12 @@ async def get_run_config_eval_scores( total_output_tokens = 0.0 total_total_tokens = 0.0 total_cost = 0.0 + total_llm_latency_ms_sum = 0.0 input_tokens_count = 0 output_tokens_count = 0 total_tokens_count = 0 cost_count = 0 + latency_ms_count = 0 total_eval_runs = 0 for eval in evals: @@ -1506,6 +1512,9 @@ async def get_run_config_eval_scores( if usage.cost is not None: total_cost += usage.cost cost_count += 1 + if usage.total_llm_latency_ms is not None: + total_llm_latency_ms_sum += usage.total_llm_latency_ms + latency_ms_count += 1 incomplete = False for output_score in eval.output_scores: @@ -1576,6 +1585,9 @@ async def get_run_config_eval_scores( if total_tokens_count >= threshold else None, mean_cost=total_cost / cost_count if cost_count >= threshold else None, + mean_total_llm_latency_ms=total_llm_latency_ms_sum / latency_ms_count + if latency_ms_count >= threshold + else None, ) return RunConfigEvalScoresSummary( diff --git a/app/desktop/studio_server/test_eval_api.py b/app/desktop/studio_server/test_eval_api.py index 7c7e83b04..cfb56e2fa 100644 --- a/app/desktop/studio_server/test_eval_api.py +++ b/app/desktop/studio_server/test_eval_api.py @@ -2042,6 +2042,7 @@ async def test_get_run_config_eval_scores_with_usage( output_tokens=50, total_tokens=150, cost=0.005, + total_llm_latency_ms=500, ), parent=mock_task, ) @@ -2063,6 +2064,7 @@ async def test_get_run_config_eval_scores_with_usage( output_tokens=100, total_tokens=300, cost=0.010, + total_llm_latency_ms=1000, ), parent=mock_task, ) @@ -2193,6 +2195,168 @@ async def test_get_run_config_eval_scores_with_usage( assert mean_usage["mean_output_tokens"] == 75.0 assert mean_usage["mean_total_tokens"] == 225.0 assert mean_usage["mean_cost"] == 0.0075 + # Expected mean latency: (500+1000)/2 = 750.0 (2 of 3 runs have latency, 66.7% > 50%) + assert mean_usage["mean_total_llm_latency_ms"] == 750.0 + + +@pytest.mark.asyncio +async def test_get_run_config_eval_scores_latency_below_threshold( + client, mock_task_from_id, mock_task, mock_eval, mock_eval_config, mock_run_config +): + """Test that mean_total_llm_latency_ms is None when fewer than 50% of runs have latency data""" + mock_task_from_id.return_value = mock_task + + # Create 3 TaskRuns, only 1 with latency data (1/3 = 33% < 50% threshold) + task_run_1 = TaskRun( + input="test input 1", + input_source=DataSource( + type=DataSourceType.synthetic, + properties={ + "model_name": "gpt-4", + "model_provider": "openai", + "adapter_name": "langchain_adapter", + }, + ), + output=TaskOutput(output="test output 1"), + usage=Usage( + input_tokens=100, + output_tokens=50, + total_tokens=150, + cost=0.005, + total_llm_latency_ms=500, + ), + parent=mock_task, + ) + task_run_1.save_to_file() + + task_run_2 = TaskRun( + input="test input 2", + input_source=DataSource( + type=DataSourceType.synthetic, + properties={ + "model_name": "gpt-4", + "model_provider": "openai", + "adapter_name": "langchain_adapter", + }, + ), + output=TaskOutput(output="test output 2"), + 
usage=Usage( + input_tokens=200, + output_tokens=100, + total_tokens=300, + cost=0.010, + ), + parent=mock_task, + ) + task_run_2.save_to_file() + + task_run_3 = TaskRun( + input="test input 3", + input_source=DataSource( + type=DataSourceType.synthetic, + properties={ + "model_name": "gpt-4", + "model_provider": "openai", + "adapter_name": "langchain_adapter", + }, + ), + output=TaskOutput(output="test output 3"), + usage=Usage( + input_tokens=150, + output_tokens=75, + total_tokens=225, + cost=0.008, + ), + parent=mock_task, + ) + task_run_3.save_to_file() + + eval_run_1 = EvalRun( + task_run_config_id=mock_run_config.id, + scores={"score1": 4.0, "overall_rating": 4.0}, + input="test input 1", + output="test output 1", + dataset_id=task_run_1.id, + task_run_usage=task_run_1.usage, + parent=mock_eval_config, + ) + eval_run_1.save_to_file() + + eval_run_2 = EvalRun( + task_run_config_id=mock_run_config.id, + scores={"score1": 4.5, "overall_rating": 4.5}, + input="test input 2", + output="test output 2", + dataset_id=task_run_2.id, + task_run_usage=task_run_2.usage, + parent=mock_eval_config, + ) + eval_run_2.save_to_file() + + eval_run_3 = EvalRun( + task_run_config_id=mock_run_config.id, + scores={"score1": 3.5, "overall_rating": 3.5}, + input="test input 3", + output="test output 3", + dataset_id=task_run_3.id, + task_run_usage=task_run_3.usage, + parent=mock_eval_config, + ) + eval_run_3.save_to_file() + + mock_task_for_api = MagicMock() + mock_task_for_api.runs.return_value = [task_run_1, task_run_2, task_run_3] + mock_task_for_api.evals.return_value = [mock_eval] + + mock_eval_config_for_api = MagicMock() + mock_eval_config_for_api.runs.return_value = [eval_run_1, eval_run_2, eval_run_3] + mock_eval_config_for_api.id = mock_eval_config.id + + mock_eval_for_api = MagicMock() + mock_eval_for_api.configs.return_value = [mock_eval_config_for_api] + mock_eval_for_api.id = mock_eval.id + mock_eval_for_api.eval_set_filter_id = mock_eval.eval_set_filter_id + mock_eval_for_api.output_scores = mock_eval.output_scores + + mock_eval.current_config_id = mock_eval_config.id + + with ( + patch( + "app.desktop.studio_server.eval_api.task_from_id" + ) as mock_task_from_id_patch, + patch( + "app.desktop.studio_server.eval_api.eval_from_id" + ) as mock_eval_from_id_patch, + patch( + "app.desktop.studio_server.eval_api.task_run_config_from_id" + ) as mock_task_run_config_from_id_patch, + ): + mock_task_from_id_patch.return_value = mock_task_for_api + mock_eval_from_id_patch.return_value = mock_eval_for_api + mock_task_run_config_from_id_patch.return_value = mock_run_config + + with patch( + "app.desktop.studio_server.eval_api.dataset_ids_in_filter" + ) as mock_dataset_ids_in_filter: + mock_dataset_ids_in_filter.return_value = { + task_run_1.id, + task_run_2.id, + task_run_3.id, + } + + response = client.get( + f"/api/projects/project1/tasks/task1/run_configs/{mock_run_config.id}/eval_scores" + ) + + assert response.status_code == 200 + data = response.json() + mean_usage = data["mean_usage"] + assert mean_usage is not None + + # Cost/tokens should be present (3/3 = 100% > 50%) + assert mean_usage["mean_cost"] is not None + # Latency should be None (only 1/3 = 33% < 50% threshold) + assert mean_usage["mean_total_llm_latency_ms"] is None def test_get_eval_configs_score_summary_no_filter_id( diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index 71902acb1..0f1d677a2 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -3677,6 +3677,8 
@@ export interface components { refusal?: string | null; /** Tool Calls */ tool_calls?: components["schemas"]["ChatCompletionMessageFunctionToolCallParam"][]; + /** Latency Ms */ + latency_ms?: number | null; }; /** * ChatCompletionAssistantMessageParamWrapper @@ -3705,6 +3707,8 @@ export interface components { refusal?: string | null; /** Tool Calls */ tool_calls?: components["schemas"]["ChatCompletionMessageFunctionToolCallParam"][]; + /** Latency Ms */ + latency_ms?: number | null; }; /** ChatCompletionContentPartImageParam */ ChatCompletionContentPartImageParam: { @@ -7036,6 +7040,11 @@ export interface components { * @description Average cost per run in USD. */ mean_cost?: number | null; + /** + * Mean Total Llm Latency Ms + * @description Average total LLM latency per run in milliseconds. + */ + mean_total_llm_latency_ms?: number | null; }; /** ModelDetails */ ModelDetails: { @@ -9980,6 +9989,11 @@ export interface components { * @description Number of tokens served from prompt cache. None if not reported. */ cached_tokens?: number | null; + /** + * Total Llm Latency Ms + * @description Total time spent waiting on LLM API calls in milliseconds. Sum of per-call latencies, excludes tool execution time. + */ + total_llm_latency_ms?: number | null; }; /** * UserModelEntry diff --git a/app/web_ui/src/lib/components/compare_chart.svelte b/app/web_ui/src/lib/components/compare_chart.svelte index de358aeeb..869a1bcb8 100644 --- a/app/web_ui/src/lib/components/compare_chart.svelte +++ b/app/web_ui/src/lib/components/compare_chart.svelte @@ -13,6 +13,7 @@ getRunConfigPromptDisplayName, } from "$lib/utils/run_config_formatters" import ChartNoData from "./chart_no_data.svelte" + import { formatLatency } from "$lib/utils/formatters" // Type for comparison features (same as parent page) type ComparisonFeature = { @@ -122,6 +123,9 @@ if (dataKey.includes("mean_cost")) { return `$${value.toFixed(6)}` } + if (dataKey.includes("latency")) { + return formatLatency(value) + } if (dataKey.includes("tokens")) { return value.toFixed(0) } diff --git a/app/web_ui/src/lib/components/compare_radar_chart.svelte b/app/web_ui/src/lib/components/compare_radar_chart.svelte index 44a813db9..a2f08c6dd 100644 --- a/app/web_ui/src/lib/components/compare_radar_chart.svelte +++ b/app/web_ui/src/lib/components/compare_radar_chart.svelte @@ -34,15 +34,16 @@ // Chart instance let chartInstance: echarts.ECharts | null = null - // Cost key that should be included in radar chart (inverted, lower is better) + // Keys that should be included in radar chart where lower is better const COST_KEY = "cost::mean_cost" + const LATENCY_KEY = "cost::mean_total_llm_latency_ms" - // Check if a key is the cost metric (where lower is better) - function isCostMetric(key: string): boolean { - return key === COST_KEY + // Check if a key is a lower-is-better metric + function isLowerIsBetterMetric(key: string): boolean { + return key === COST_KEY || key === LATENCY_KEY } - export function costToScore( + export function metricToScore( cost: number, costs: number[], { @@ -81,15 +82,19 @@ ...comparisonFeatures .filter((f) => f.eval_id !== "kiln_cost_section") .flatMap((f) => f.items.map((item) => item.key)), - COST_KEY, // Add cost at the end + COST_KEY, + LATENCY_KEY, ] // Get labels for radar indicators function getKeyLabel(dataKey: string): string { - // Special handling for cost metric + // Special handling for lower-is-better metrics if (dataKey === COST_KEY) { return "Cost Efficiency" } + if (dataKey === LATENCY_KEY) { + return "Speed" + } 
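Reviewer note on the aggregation added in eval_api.py earlier in this diff: mean_total_llm_latency_ms follows the same "at least half the runs must report a value" rule as the existing token and cost means, so sparsely populated latency data never produces a misleading average. The hunk does not show how `threshold` is derived, so the 50% cutoff below is inferred from the test comments; `mean_if_majority` is an illustrative helper, not part of the API.

```python
def mean_if_majority(values: list[float], total_runs: int) -> float | None:
    """Mean of the reported values, or None when fewer than half the runs reported one."""
    threshold = total_runs / 2  # assumed 50% cutoff, matching the `>= threshold` checks in eval_api.py
    if total_runs == 0 or len(values) < threshold:
        return None
    return sum(values) / len(values)


# Mirrors the two latency tests: 2 of 3 runs reporting is enough, 1 of 3 is not.
assert mean_if_majority([500.0, 1000.0], total_runs=3) == 750.0
assert mean_if_majority([500.0], total_runs=3) is None
```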
for (const feature of comparisonFeatures) { const item = feature.items.find((i) => i.key === dataKey) if (item) return item.label @@ -132,7 +137,10 @@ } // Build full tooltip HTML for a run config (reused by chart tooltip and legend tooltip) - function buildRunConfigTooltip(name: string, allCosts: number[]): string { + function buildRunConfigTooltip( + name: string, + lowerIsBetterValues: Record<string, number[]>, + ): string { const config = run_configs.find((c) => getSeriesDisplayName(c) === name) let html = `
${name}
` @@ -159,9 +167,22 @@ const rawValue = config?.id ? getModelValueRaw(config.id, key) : null if (rawValue === null) { html += `
${label}: N/A
` - } else if (isCostMetric(key)) { - const displayValue = costToScore(rawValue, allCosts) + } else if (key === COST_KEY) { + const displayValue = metricToScore( + rawValue, + lowerIsBetterValues[key] || [], + ) html += `
${label}: ${displayValue.toFixed(1)} (Mean Cost: $${rawValue.toFixed(6)})
` + } else if (key === LATENCY_KEY) { + const displayValue = metricToScore( + rawValue, + lowerIsBetterValues[key] || [], + ) + const formatted = + rawValue < 1000 + ? `${Math.round(rawValue)}ms` + : `${(rawValue / 1000).toFixed(1)}s` + html += `
${label}: ${displayValue.toFixed(1)} (Mean Latency: ${formatted})
` } else { html += `
${label}: ${rawValue.toFixed(3)}
` } @@ -174,15 +195,15 @@ indicators: { name: string; max: number }[] series: { value: number[]; name: string }[] legend: string[] - allCosts: number[] + lowerIsBetterValues: Record<string, number[]> } { const indicators: { name: string; max: number }[] = [] const series: { value: number[]; name: string }[] = [] const legend: string[] = [] - const allCosts: number[] = [] + const lowerIsBetterValues: Record<string, number[]> = {} if (dataKeys.length === 0 || selectedRunConfigIds.length === 0) { - return { indicators, series, legend, allCosts } + return { indicators, series, legend, lowerIsBetterValues } } // Calculate max values for each data key across all selected run configs @@ -195,19 +216,20 @@ if (value !== null && value > max) { max = value } - if (value !== null && isCostMetric(key)) { - allCosts.push(value) + if (value !== null && isLowerIsBetterMetric(key)) { + if (!lowerIsBetterValues[key]) lowerIsBetterValues[key] = [] + lowerIsBetterValues[key].push(value) } } // Add 10% padding to max for better visualization maxValues[key] = max > 0 ? max * 1.1 : 1 } - // Build indicators with actual max values (except cost which uses 0-100 scale) + // Build indicators with actual max values (lower-is-better metrics use 0-100 scale) for (const key of dataKeys) { indicators.push({ name: getKeyLabel(key), - max: isCostMetric(key) ? 100 : maxValues[key], + max: isLowerIsBetterMetric(key) ? 100 : maxValues[key], }) } @@ -224,8 +246,8 @@ let displayValue: number if (rawValue === null) { displayValue = 0 - } else if (isCostMetric(key)) { - displayValue = costToScore(rawValue, allCosts) + } else if (isLowerIsBetterMetric(key)) { + displayValue = metricToScore(rawValue, lowerIsBetterValues[key] || []) } else { displayValue = rawValue } @@ -241,7 +263,7 @@ } } - return { indicators, series, legend, allCosts } + return { indicators, series, legend, lowerIsBetterValues } } // Check if there's data to display (reactive, depends on dataKeys and selectedRunConfigIds) @@ -261,7 +283,8 @@ return } - const { indicators, series, legend, allCosts } = generateChartData() + const { indicators, series, legend, lowerIsBetterValues } = + generateChartData() const legendFormatter = buildLegendFormatter() @@ -270,7 +293,7 @@ tooltip: { trigger: "item", formatter: (params: { name: string }) => - buildRunConfigTooltip(params.name, allCosts), + buildRunConfigTooltip(params.name, lowerIsBetterValues), }, legend: { data: legend, @@ -282,7 +305,7 @@ tooltip: { show: true, formatter: (params: { name: string }) => - buildRunConfigTooltip(params.name, allCosts), + buildRunConfigTooltip(params.name, lowerIsBetterValues), }, textStyle: { lineHeight: 16, diff --git a/app/web_ui/src/lib/utils/formatters.ts b/app/web_ui/src/lib/utils/formatters.ts index a812483fc..d06585c5d 100644 --- a/app/web_ui/src/lib/utils/formatters.ts +++ b/app/web_ui/src/lib/utils/formatters.ts @@ -326,3 +326,9 @@ export function formatEvalConfigName( ] return eval_config.name + " — " + parts.join(", ") } + +export function formatLatency(ms: number): string { + const roundedMs = Math.round(ms) + if (roundedMs < 1000) return `${roundedMs}ms` + return `${(ms / 1000).toFixed(1)}s` +} diff --git a/app/web_ui/src/routes/(app)/run/run.svelte b/app/web_ui/src/routes/(app)/run/run.svelte index dbabf48eb..1b649ff46 100644 --- a/app/web_ui/src/routes/(app)/run/run.svelte +++ b/app/web_ui/src/routes/(app)/run/run.svelte @@ -11,7 +11,7 @@ import { client } from "$lib/api_client" import Output from "$lib/ui/output.svelte" import { KilnError, createKilnError } from "$lib/utils/error_handlers" - import { 
formatDate } from "$lib/utils/formatters" + import { formatDate, formatLatency } from "$lib/utils/formatters" import { bounceOut } from "svelte/easing" import { fly } from "svelte/transition" import { onMount } from "svelte" @@ -59,12 +59,13 @@ async function calculate_subtask_usage( trace: Trace | null | undefined, visited: Set = new Set(), - ): Promise<{ cost: number; tokens: number }> { - if (!trace) return { cost: 0, tokens: 0 } + ): Promise<{ cost: number; tokens: number; latency_ms: number }> { + if (!trace) return { cost: 0, tokens: 0, latency_ms: 0 } const references = extract_subtask_references(trace) let total_cost = 0 let total_tokens = 0 + let total_llm_latency_ms = 0 for (const ref of references) { const key = `${ref.project_id}:${ref.task_id}:${ref.run_id}` @@ -88,12 +89,14 @@ if (!response.error && response.data) { total_cost += response.data.usage?.cost ?? 0 total_tokens += response.data.usage?.total_tokens ?? 0 + total_llm_latency_ms += response.data.usage?.total_llm_latency_ms ?? 0 const subtask_usage = await calculate_subtask_usage( response.data.trace, visited, ) total_cost += subtask_usage.cost total_tokens += subtask_usage.tokens + total_llm_latency_ms += subtask_usage.latency_ms } } catch (error) { console.warn( @@ -108,11 +111,16 @@ } } - return { cost: total_cost, tokens: total_tokens } + return { + cost: total_cost, + tokens: total_tokens, + latency_ms: total_llm_latency_ms, + } } let subtask_cost: number | null = null let subtask_tokens: number | null = null + let subtask_latency_ms: number | null = null let subtask_usage_loading = false // Counter to prevent race conditions: when run changes rapidly, multiple async requests // may be in flight. We only update state if this request is still the latest one. @@ -126,6 +134,7 @@ if (request_id === subtask_usage_request_id) { subtask_cost = null subtask_tokens = null + subtask_latency_ms = null subtask_usage_loading = false } return @@ -137,11 +146,13 @@ if (request_id === subtask_usage_request_id) { subtask_cost = usage.cost subtask_tokens = usage.tokens + subtask_latency_ms = usage.latency_ms } } catch { if (request_id === subtask_usage_request_id) { subtask_cost = null subtask_tokens = null + subtask_latency_ms = null } } finally { if (request_id === subtask_usage_request_id) { @@ -520,11 +531,13 @@ subtask_cost: number | null, subtask_usage_loading: boolean, subtask_tokens: number | null, + subtask_latency_ms: number | null, ) { let properties = [] const run_cost = run?.usage?.cost ?? 0 const run_tokens = run?.usage?.total_tokens ?? 0 + const run_latency = run?.usage?.total_llm_latency_ms ?? 0 if (subtask_usage_loading) { properties.push({ @@ -580,6 +593,33 @@ }) } + if (subtask_usage_loading) { + properties.push({ + name: "Total Latency", + value: "Loading...", + }) + } else { + const total_latency = run_latency + (subtask_latency_ms ?? 
0) + if (total_latency > 0) { + properties.push({ + name: "Total Latency", + value: formatLatency(total_latency), + }) + } + } + + if (subtask_usage_loading) { + properties.push({ + name: "Subtasks Latency", + value: "Loading...", + }) + } else if (subtask_latency_ms && subtask_latency_ms > 0) { + properties.push({ + name: "Subtasks Latency", + value: formatLatency(subtask_latency_ms), + }) + } + return properties } @@ -589,6 +629,7 @@ subtask_cost, subtask_usage_loading, subtask_tokens, + subtask_latency_ms, ) // Feedback diff --git a/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/compare/+page.svelte b/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/compare/+page.svelte index 0c70ff62d..3c78dd6fd 100644 --- a/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/compare/+page.svelte +++ b/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/compare/+page.svelte @@ -5,6 +5,7 @@ import { goto } from "$app/navigation" import { client } from "$lib/api_client" import { createKilnError, KilnError } from "$lib/utils/error_handlers" + import { formatLatency } from "$lib/utils/formatters" import type { Task, TaskRunConfig, Eval } from "$lib/types" import type { components } from "$lib/api_schema" import CompareChart from "$lib/components/compare_chart.svelte" @@ -394,10 +395,11 @@ { label: "Output Tokens", key: "cost::mean_output_tokens" }, { label: "Total Tokens", key: "cost::mean_total_tokens" }, { label: "Cost (USD)", key: "cost::mean_cost" }, + { label: "Latency", key: "cost::mean_total_llm_latency_ms" }, ] features.push({ - category: "Average Usage & Cost", + category: "Average Usage, Cost & Latency", items: costItems, has_default_eval_config: undefined, eval_id: "kiln_cost_section", @@ -463,6 +465,8 @@ return meanUsage.mean_total_tokens ?? null case "mean_cost": return meanUsage.mean_cost ?? null + case "mean_total_llm_latency_ms": + return meanUsage.mean_total_llm_latency_ms ?? null } return null } @@ -489,10 +493,12 @@ const [category, scoreKey] = dataKey.split("::") - // Format cost with currency symbol, tokens as whole numbers, others as decimals + // Format cost with currency symbol, latency with ms/s, tokens as whole numbers, others as decimals if (category === "cost") { if (scoreKey === "mean_cost") { return `$${value.toFixed(7)}` + } else if (scoreKey === "mean_total_llm_latency_ms") { + return formatLatency(value) } else { return value.toFixed(1) } @@ -592,25 +598,10 @@ goto(`/dataset/${project_id}/${task_id}/add_data?${params.toString()}`) } - function getPercentageDifference( - baseValue: string, - compareValue: string, + function getPercentageDifferenceRaw( + base: number | null, + compare: number | null, ): string { - // Return empty if either value is unavailable - if (baseValue === "—" || compareValue === "—") return "" - - // Parse numeric values, handling currency formatting - const parseValue = (val: string): number | null => { - if (val === "—") return null - // Remove currency symbol and parse - const cleaned = val.replace(/^\$/, "") - const parsed = parseFloat(cleaned) - return isNaN(parsed) ? 
null : parsed - } - - const base = parseValue(baseValue) - const compare = parseValue(compareValue) - if (base === null || compare === null) return "" // Handle division by zero @@ -980,17 +971,17 @@ {getModelValue(selectedModels[i], item.key)} {#if i > 0 && selectedModels[0] !== null} - {@const baseValue = getModelValue( + {@const baseRaw = getModelValueRaw( selectedModels[0], item.key, )} - {@const currentValue = getModelValue( + {@const currentRaw = getModelValueRaw( selectedModels[i], item.key, )} - {@const percentDiff = getPercentageDifference( - baseValue, - currentValue, + {@const percentDiff = getPercentageDifferenceRaw( + baseRaw, + currentRaw, )} {#if percentDiff} diff --git a/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py index 62a78aafc..fda2939d8 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py +++ b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py @@ -3,6 +3,7 @@ import copy import json import logging +import time from dataclasses import dataclass from typing import TYPE_CHECKING, Any, AsyncIterator @@ -66,6 +67,7 @@ def __init__( self._top_logprobs = top_logprobs self._result: AdapterStreamResult | None = None self._iterated = False + self._message_latency: dict[int, int] = {} @property def result(self) -> AdapterStreamResult: @@ -138,7 +140,9 @@ async def __aiter__(self) -> AsyncIterator[AdapterStreamEvent]: if not isinstance(prior_output, str): raise RuntimeError(f"assistant message is not a string: {prior_output}") - trace = self._adapter.all_messages_to_trace(self._messages) + trace = self._adapter.all_messages_to_trace( + self._messages, self._message_latency + ) self._result = AdapterStreamResult( run_output=RunOutput( output=prior_output, @@ -167,11 +171,16 @@ async def _stream_model_turn( ) stream = StreamingCompletion(**completion_kwargs) + start = time.monotonic() async for chunk in stream: yield chunk + call_latency_ms = int((time.monotonic() - start) * 1000) response, response_choice = _validate_response(stream.response) usage += self._adapter.usage_from_response(response) + usage.total_llm_latency_ms = ( + usage.total_llm_latency_ms or 0 + ) + call_latency_ms content = response_choice.message.content tool_calls = response_choice.message.tool_calls @@ -181,6 +190,7 @@ async def _stream_model_turn( ) self._messages.append(response_choice.message) + self._message_latency[len(self._messages) - 1] = call_latency_ms if tool_calls and len(tool_calls) > 0: # Check for return_on_tool_call BEFORE processing diff --git a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py index 549fe4cce..ccdb71cad 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py @@ -2,6 +2,7 @@ import copy import json import logging +import time from dataclasses import dataclass from typing import Any, Dict, List, Tuple @@ -79,6 +80,7 @@ class ModelTurnResult: model_choice: Choices | None usage: Usage interrupted_by_tool_calls: list[ChatCompletionMessageToolCall] | None = None + message_latency: dict[int, int] | None = None class LiteLlmAdapter(BaseAdapter): @@ -123,6 +125,9 @@ async def _run_model_turn( usage = Usage() messages = list(prior_messages) tool_calls_count = 0 + # LLM call latency in ms, keyed by index in the messages list. + # Kept separate because we don't own the LiteLLM message objects. 
+ message_latency: dict[int, int] = {} while tool_calls_count < MAX_TOOL_CALLS_PER_TURN: # Build completion kwargs for tool calls @@ -134,13 +139,18 @@ async def _run_model_turn( skip_response_format, ) - # Make the completion call + # Make the completion call (timed) + start = time.monotonic() model_response, response_choice = await self.acompletion_checking_response( **completion_kwargs ) + call_latency_ms = int((time.monotonic() - start) * 1000) # count the usage usage += self.usage_from_response(model_response) + usage.total_llm_latency_ms = ( + usage.total_llm_latency_ms or 0 + ) + call_latency_ms # Extract content and tool calls if not hasattr(response_choice, "message"): @@ -154,6 +164,7 @@ async def _run_model_turn( # Add message to messages, so it can be used in the next turn messages.append(response_choice.message) + message_latency[len(messages) - 1] = call_latency_ms # Process tool calls if any if tool_calls and len(tool_calls) > 0: @@ -175,6 +186,7 @@ async def _run_model_turn( model_choice=response_choice, usage=usage, interrupted_by_tool_calls=standard_tool_calls, + message_latency=message_latency, ) # otherwise: process tool calls internally until final output @@ -194,6 +206,7 @@ async def _run_model_turn( model_response=model_response, model_choice=response_choice, usage=usage, + message_latency=message_latency, ) # If there were tool calls, increment counter and continue @@ -209,6 +222,7 @@ async def _run_model_turn( model_response=model_response, model_choice=response_choice, usage=usage, + message_latency=message_latency, ) # If we get here with no content and no tool calls, break @@ -240,6 +254,7 @@ async def _run( prior_output: str | None = None final_choice: Choices | None = None turns = 0 + message_latency: dict[int, int] = {} # Same loop for both fresh runs and prior_trace continuation. # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues). 
@@ -274,6 +289,8 @@ async def _run( ) usage += turn_result.usage + if turn_result.message_latency: + message_latency.update(turn_result.message_latency) prior_output = turn_result.assistant_message messages = turn_result.all_messages @@ -281,7 +298,7 @@ async def _run( # Check if we were interrupted by tool calls if turn_result.interrupted_by_tool_calls: - trace = self.all_messages_to_trace(messages) + trace = self.all_messages_to_trace(messages, message_latency) intermediate_outputs = chat_formatter.intermediate_outputs() output = RunOutput( output=prior_output or "", @@ -305,7 +322,7 @@ async def _run( if not isinstance(prior_output, str): raise RuntimeError(f"assistant message is not a string: {prior_output}") - trace = self.all_messages_to_trace(messages) + trace = self.all_messages_to_trace(messages, message_latency) output = RunOutput( output=prior_output, intermediate_outputs=intermediate_outputs, @@ -859,7 +876,9 @@ async def run_tool_and_format( return assistant_output_from_toolcall, tool_call_response_messages def litellm_message_to_trace_message( - self, raw_message: LiteLLMMessage + self, + raw_message: LiteLLMMessage, + latency_ms: int | None = None, ) -> ChatCompletionAssistantMessageParamWrapper: """ Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper @@ -898,6 +917,9 @@ def litellm_message_to_trace_message( if len(open_ai_tool_calls) > 0: message["tool_calls"] = open_ai_tool_calls + if latency_ms is not None: + message["latency_ms"] = latency_ms + if not message.get("content") and not message.get("tool_calls"): raise ValueError( "Model returned an assistant message, but no content or tool calls. This is not supported." @@ -906,15 +928,18 @@ def litellm_message_to_trace_message( return message def all_messages_to_trace( - self, messages: list[ChatCompletionMessageIncludingLiteLLM] + self, + messages: list[ChatCompletionMessageIncludingLiteLLM], + message_latency: dict[int, int] | None = None, ) -> list[ChatCompletionMessageParam]: """ Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 
""" trace: list[ChatCompletionMessageParam] = [] - for message in messages: + for i, message in enumerate(messages): if isinstance(message, LiteLLMMessage): - trace.append(self.litellm_message_to_trace_message(message)) + latency_ms = message_latency.get(i) if message_latency else None + trace.append(self.litellm_message_to_trace_message(message, latency_ms)) else: trace.append(message) return trace diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index 10878c44d..e741a6a3e 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -2429,3 +2429,165 @@ async def description(self) -> str: assert success_msg["content"] == "all good" assert success_msg.get("is_error") is None assert success_msg.get("error_message") is None + + +class TestLatencyTracking: + """Tests for LLM call latency tracking in _run_model_turn and trace messages.""" + + @pytest.fixture + def adapter(self, config, mock_task): + return LiteLlmAdapter(config=config, kiln_task=mock_task) + + @pytest.fixture + def simple_response(self): + return ModelResponse( + model="test-model", + choices=[{"message": {"content": "Hello!", "tool_calls": None}}], + ) + + @pytest.fixture + def provider(self): + from kiln_ai.adapters.ml_model_list import KilnModelProvider + + return KilnModelProvider( + name=ModelProviderName.openrouter, model_id="test-model" + ) + + @pytest.mark.asyncio + async def test_run_model_turn_sets_latency_on_usage( + self, adapter, simple_response, provider + ): + """_run_model_turn() should set total_llm_latency_ms on usage.""" + monotonic_values = [0.0, 0.5] # 500ms call + with patch.object(adapter, "build_completion_kwargs", return_value={}): + with patch.object( + adapter, + "acompletion_checking_response", + return_value=(simple_response, simple_response.choices[0]), + ): + with patch( + "kiln_ai.adapters.model_adapters.litellm_adapter.time.monotonic", + side_effect=monotonic_values, + ): + result = await adapter._run_model_turn( + provider, [{"role": "user", "content": "Hi"}], None, False + ) + + assert result.usage.total_llm_latency_ms == 500 + + @pytest.mark.asyncio + async def test_run_model_turn_sets_latency_on_message_latency_dict( + self, adapter, simple_response, provider + ): + """_run_model_turn() should record latency in the message_latency dict.""" + monotonic_values = [0.0, 0.25] # 250ms + with patch.object(adapter, "build_completion_kwargs", return_value={}): + with patch.object( + adapter, + "acompletion_checking_response", + return_value=(simple_response, simple_response.choices[0]), + ): + with patch( + "kiln_ai.adapters.model_adapters.litellm_adapter.time.monotonic", + side_effect=monotonic_values, + ): + result = await adapter._run_model_turn( + provider, [{"role": "user", "content": "Hi"}], None, False + ) + + assert result.message_latency is not None + # Latency is keyed by message index in the messages list + last_index = len(result.all_messages) - 1 + assert result.message_latency[last_index] == 250 + + @pytest.mark.asyncio + async def test_run_model_turn_accumulates_latency_across_tool_calls( + self, adapter, provider + ): + """Latency should accumulate across multiple LLM calls in a tool-call loop.""" + # First LLM call: model requests a regular tool (not task_response) + tool_call_response = ModelResponse( + model="test-model", + choices=[ + { + "message": { + "content": None, + "tool_calls": [ + { 
+ "id": "call_1", + "type": "function", + "function": { + "name": "some_tool", + "arguments": '{"arg": "val"}', + }, + } + ], + } + } + ], + ) + # Second LLM call: model returns final content + final_response = ModelResponse( + model="test-model", + choices=[ + { + "message": { + "content": "Final answer", + } + } + ], + ) + + # monotonic: start1, end1 (200ms), start2, end2 (300ms) + monotonic_values = [0.0, 0.2, 0.2, 0.5] + with patch.object(adapter, "build_completion_kwargs", return_value={}): + with patch.object( + adapter, + "acompletion_checking_response", + side_effect=[ + (tool_call_response, tool_call_response.choices[0]), + (final_response, final_response.choices[0]), + ], + ): + with patch.object( + adapter, + "process_tool_calls", + return_value=( + None, + [ + { + "role": "tool", + "content": "tool result", + "tool_call_id": "call_1", + } + ], + ), + ): + with patch( + "kiln_ai.adapters.model_adapters.litellm_adapter.time.monotonic", + side_effect=monotonic_values, + ): + result = await adapter._run_model_turn( + provider, [{"role": "user", "content": "Hi"}], None, False + ) + + # 200ms + 300ms = 500ms total + assert result.usage.total_llm_latency_ms == 500 + + def test_litellm_message_to_trace_message_includes_latency(self, adapter): + """litellm_message_to_trace_message should include latency_ms when provided.""" + from litellm.types.utils import Message as LiteLLMMessage + + msg = LiteLLMMessage(role="assistant", content="Hello") + + trace_msg = adapter.litellm_message_to_trace_message(msg, latency_ms=123) + assert trace_msg["latency_ms"] == 123 + + def test_litellm_message_to_trace_message_no_latency(self, adapter): + """litellm_message_to_trace_message should omit latency_ms when not provided.""" + from litellm.types.utils import Message as LiteLLMMessage + + msg = LiteLLMMessage(role="assistant", content="Hello") + + trace_msg = adapter.litellm_message_to_trace_message(msg) + assert "latency_ms" not in trace_msg diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index 2198696b4..1d62c9fa9 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -46,6 +46,11 @@ class Usage(BaseModel): description="Number of tokens served from prompt cache. None if not reported.", ge=0, ) + total_llm_latency_ms: int | None = Field( + default=None, + description="Total time spent waiting on LLM API calls in milliseconds. Sum of per-call latencies, excludes tool execution time.", + ge=0, + ) def __add__(self, other: "Usage") -> "Usage": """Add two Usage objects together, handling None values gracefully. 
@@ -82,6 +87,9 @@ def _add_optional_float(a: float | None, b: float | None) -> float | None: total_tokens=_add_optional_int(self.total_tokens, other.total_tokens), cost=_add_optional_float(self.cost, other.cost), cached_tokens=_add_optional_int(self.cached_tokens, other.cached_tokens), + total_llm_latency_ms=_add_optional_int( + self.total_llm_latency_ms, other.total_llm_latency_ms + ), ) diff --git a/libs/core/kiln_ai/datamodel/test_example_models.py b/libs/core/kiln_ai/datamodel/test_example_models.py index d93985255..85c2e08b8 100644 --- a/libs/core/kiln_ai/datamodel/test_example_models.py +++ b/libs/core/kiln_ai/datamodel/test_example_models.py @@ -972,3 +972,46 @@ def test_usage_addition_immutability(): assert result.output_tokens == 125 assert result.total_tokens == 425 assert result.cost == 0.015 + + +def test_usage_total_llm_latency_ms_field(): + """Test total_llm_latency_ms field validation.""" + usage = Usage(total_llm_latency_ms=500) + assert usage.total_llm_latency_ms == 500 + + usage_none = Usage() + assert usage_none.total_llm_latency_ms is None + + usage_zero = Usage(total_llm_latency_ms=0) + assert usage_zero.total_llm_latency_ms == 0 + + with pytest.raises(ValidationError): + Usage(total_llm_latency_ms=-1) + + +@pytest.mark.parametrize( + "latency1,latency2,expected", + [ + (None, None, None), + (None, 300, 300), + (200, None, 200), + (200, 300, 500), + (0, 100, 100), + ], +) +def test_usage_addition_with_latency(latency1, latency2, expected): + """Test Usage.__add__ handles total_llm_latency_ms correctly.""" + u1 = Usage(total_llm_latency_ms=latency1) + u2 = Usage(total_llm_latency_ms=latency2) + result = u1 + u2 + assert result.total_llm_latency_ms == expected + + +def test_usage_backwards_compat_without_latency(): + """Test that old Usage JSON without total_llm_latency_ms deserializes as None.""" + old_json = ( + '{"input_tokens": 100, "output_tokens": 50, "total_tokens": 150, "cost": 0.01}' + ) + usage = Usage.model_validate_json(old_json) + assert usage.total_llm_latency_ms is None + assert usage.input_tokens == 100 diff --git a/libs/core/kiln_ai/utils/open_ai_types.py b/libs/core/kiln_ai/utils/open_ai_types.py index 0b3b1039c..1f76bd345 100644 --- a/libs/core/kiln_ai/utils/open_ai_types.py +++ b/libs/core/kiln_ai/utils/open_ai_types.py @@ -83,6 +83,9 @@ class ChatCompletionAssistantMessageParamWrapper(TypedDict, total=False): tool_calls: List[ChatCompletionMessageToolCallParam] """The tool calls generated by the model, such as function calls.""" + latency_ms: Optional[int] + """Time spent waiting on this specific LLM API call in milliseconds.""" + class ChatCompletionToolMessageParamWrapper(TypedDict, total=False): content: Required[Union[str, Iterable[ChatCompletionContentPartTextParam]]] diff --git a/libs/core/kiln_ai/utils/test_open_ai_types.py b/libs/core/kiln_ai/utils/test_open_ai_types.py index c1580e103..bfed2bcbc 100644 --- a/libs/core/kiln_ai/utils/test_open_ai_types.py +++ b/libs/core/kiln_ai/utils/test_open_ai_types.py @@ -41,6 +41,10 @@ def test_assistant_message_param_properties_match(): assert "reasoning_content" in kiln_properties, "Kiln should have reasoning_content" kiln_properties.remove("reasoning_content") + # latency_ms is a Kiln-added property for LLM call timing. Confirm it's there and remove it. + assert "latency_ms" in kiln_properties, "Kiln should have latency_ms" + kiln_properties.remove("latency_ms") + assert openai_properties == kiln_properties, ( f"Property names don't match. " f"OpenAI has: {openai_properties}, "
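Taken together: each assistant message in a trace can now carry the latency of the LLM call that produced it (`latency_ms` on `ChatCompletionAssistantMessageParamWrapper`), and the run-level `Usage.total_llm_latency_ms` accumulates those same per-call timings, with tool execution time excluded. A small illustrative sketch of that relationship using plain dicts (the real code uses the TypedDicts in `kiln_ai.utils.open_ai_types`; the trace below is made up for illustration):

```python
trace = [
    {"role": "user", "content": "What is 2 + 2?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "add", "arguments": '{"a": 2, "b": 2}'},
            }
        ],
        "latency_ms": 200,  # time waiting on the first LLM call
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "4"},  # tool time is not counted
    {"role": "assistant", "content": "2 + 2 = 4", "latency_ms": 300},  # second LLM call
]

total_llm_latency_ms = sum(
    m["latency_ms"]
    for m in trace
    if m.get("role") == "assistant" and m.get("latency_ms") is not None
)
assert total_llm_latency_ms == 500  # what Usage.total_llm_latency_ms would report for a run like this
```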