Skip to content

Commit da8cd83

Browse files
committed
refactor(web): derive token usage post-stream and join steps by position
Collect usage from researchStream.steps and response.messages after the stream completes (covers approval-gated and failed tool calls, off the hot path), nest tool estimates under their step in a single stepTokenUsage array, and join UI steps to entries by stepIndex.
1 parent 7bfaa2a commit da8cd83

8 files changed

Lines changed: 283 additions & 260 deletions

File tree

packages/web/src/ee/features/chat/agent.test.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ const createAssistantMessage = (parts: SBChatMessagePart[]): SBChatMessage => ({
137137
});
138138

139139
const createFakeStreamResult = () => ({
140-
response: Promise.resolve(new Response()),
140+
response: Promise.resolve({ messages: [] }),
141+
steps: Promise.resolve([]),
141142
totalUsage: Promise.resolve({
142143
inputTokens: 1,
143144
outputTokens: 1,

packages/web/src/ee/features/chat/agent.ts

Lines changed: 68 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { SBChatMessage, SBChatMessageMetadata, StepTokenUsageEntry, ToolTokenUsageEntry } from "@/features/chat/types";
2-
import { estimateModelToolOutputTokens, estimateToolOutputTokens } from "@/features/chat/tokenEstimation";
2+
import { estimateModelToolOutputTokens } from "@/features/chat/tokenEstimation";
33
import { getFileSource } from '@/features/git';
44
import { isServiceError } from "@/lib/utils";
55
import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider";
@@ -150,9 +150,6 @@ export const createMessageStream = async ({
150150

151151
const startTime = new Date();
152152

153-
const collectedToolTokenUsage: ToolTokenUsageEntry[] = [];
154-
const collectedStepTokenUsage: StepTokenUsageEntry[] = [];
155-
156153
const researchStream = await createAgentStream({
157154
model,
158155
providerOptions: modelProviderOptions,
@@ -167,12 +164,6 @@ export const createMessageStream = async ({
167164
data: source,
168165
});
169166
},
170-
onToolTokenUsage: (entry) => {
171-
collectedToolTokenUsage.push(entry);
172-
},
173-
onStepTokenUsage: (entry) => {
174-
collectedStepTokenUsage.push(entry);
175-
},
176167
onMcpServerDiscovered: (sanitizedName, faviconUrl) => {
177168
writer.write({
178169
type: 'data-mcp-server',
@@ -200,6 +191,59 @@ export const createMessageStream = async ({
200191
});
201192

202193
const totalUsage = await researchStream.totalUsage;
194+
const steps = await researchStream.steps;
195+
const response = await researchStream.response;
196+
197+
// Tool output estimates are derived from `response.messages` rather
198+
// than per-step `toolResults` because the response messages cover
199+
// tool calls that never run inside a step — approval-gated tools
200+
// execute before the step loop, and thrown tool errors are recorded
201+
// as `tool-error` parts that `toolResults` excludes. Their
202+
// `tool-result` parts also carry the output in model-visible form
203+
// (`toModelOutput` already applied), which is exactly the payload
204+
// whose token footprint we want to estimate.
205+
const toolUsageByToolCallId = new Map<string, ToolTokenUsageEntry>(
206+
response.messages.flatMap((message) =>
207+
message.role !== 'tool' ? [] : message.content.flatMap((part) =>
208+
part.type !== 'tool-result' ? [] : [[part.toolCallId, {
209+
toolCallId: part.toolCallId,
210+
toolName: part.toolName,
211+
estimatedOutputTokens: estimateModelToolOutputTokens(part.output),
212+
}] as const]
213+
)
214+
)
215+
);
216+
217+
// One entry per step, in step order. The UI joins its step groups
218+
// to these entries by array position, so the order and count must
219+
// mirror the stream's steps exactly. Tool calls nest under the
220+
// step they ran in; `content` is matched rather than `toolResults`
221+
// so that thrown tool errors (`tool-error` parts, which
222+
// `toolResults` excludes) are still attributed to their step.
223+
const stepTokenUsage: StepTokenUsageEntry[] = steps.map(({ usage, content }) => ({
224+
inputTokens: usage.inputTokens,
225+
outputTokens: usage.outputTokens,
226+
cacheReadTokens: usage.inputTokenDetails?.cacheReadTokens,
227+
tools: content.flatMap((part) => {
228+
if (part.type !== 'tool-result' && part.type !== 'tool-error') {
229+
return [];
230+
}
231+
const entry = toolUsageByToolCallId.get(part.toolCallId);
232+
if (!entry) {
233+
return [];
234+
}
235+
toolUsageByToolCallId.delete(part.toolCallId);
236+
return [entry];
237+
}),
238+
}));
239+
240+
// Any estimates left unclaimed belong to tool calls that executed
241+
// before the step loop (approval continuations). Their output
242+
// enters the context as input to this phase's first step, so nest
243+
// them under it.
244+
if (toolUsageByToolCallId.size > 0 && stepTokenUsage.length > 0) {
245+
stepTokenUsage[0].tools.unshift(...toolUsageByToolCallId.values());
246+
}
203247

204248
writer.write({
205249
type: 'message-metadata',
@@ -210,22 +254,9 @@ export const createMessageStream = async ({
210254
totalCacheReadTokens: (priorMetadata?.totalCacheReadTokens ?? 0) + (totalUsage.inputTokenDetails?.cacheReadTokens ?? 0),
211255
totalCacheWriteTokens: (priorMetadata?.totalCacheWriteTokens ?? 0) + (totalUsage.inputTokenDetails?.cacheWriteTokens ?? 0),
212256
totalResponseTimeMs: (priorMetadata?.totalResponseTimeMs ?? 0) + (new Date().getTime() - startTime.getTime()),
213-
// Unlike the token totals above, these are concatenated (not
214-
// summed) across approval-continuation phases so tool calls
215-
// and steps from the pre-approval phase are preserved. Step
216-
// indices captured in this phase are relative to this phase's
217-
// stream, so offset them by the prior phase's step count to
218-
// keep them valid against the concatenated step array.
219-
toolTokenUsage: [
220-
...(priorMetadata?.toolTokenUsage ?? []),
221-
...collectedToolTokenUsage.map((entry) => ({
222-
...entry,
223-
stepIndex: entry.stepIndex !== undefined
224-
? entry.stepIndex + (priorMetadata?.stepTokenUsage?.length ?? 0)
225-
: undefined,
226-
})),
227-
],
228-
stepTokenUsage: [...(priorMetadata?.stepTokenUsage ?? []), ...collectedStepTokenUsage],
257+
// Concatenated (not summed) across approval-continuation
258+
// phases so earlier phases' steps are preserved in order.
259+
stepTokenUsage: [...(priorMetadata?.stepTokenUsage ?? []), ...stepTokenUsage],
229260
modelName,
230261
traceId,
231262
...metadata,
@@ -253,8 +284,6 @@ interface AgentOptions {
253284
inputMessages: ModelMessage[];
254285
inputSources: Source[];
255286
onWriteSource: (source: Source) => void;
256-
onToolTokenUsage?: (entry: ToolTokenUsageEntry) => void;
257-
onStepTokenUsage?: (entry: StepTokenUsageEntry) => void;
258287
onMcpServerDiscovered: (sanitizedName: string, faviconUrl: string) => void;
259288
onMcpServerFailed: (serverName: string) => void;
260289
traceId: string;
@@ -273,8 +302,6 @@ const createAgentStream = async ({
273302
selectedRepos,
274303
disabledMcpServerIds,
275304
onWriteSource,
276-
onToolTokenUsage,
277-
onStepTokenUsage,
278305
onMcpServerDiscovered,
279306
onMcpServerFailed,
280307
traceId,
@@ -460,41 +487,21 @@ const createAgentStream = async ({
460487
logger.warn(`Tool call repair failed for "${toolCall.toolName}": ${error.message}`);
461488
return null;
462489
},
463-
onStepFinish: async ({ stepNumber, usage, toolResults }) => {
464-
onStepTokenUsage?.({
465-
inputTokens: usage.inputTokens,
466-
outputTokens: usage.outputTokens,
467-
cacheReadTokens: usage.inputTokenDetails?.cacheReadTokens,
468-
});
469-
470-
for (const { toolCallId, toolName, input, output, dynamic } of toolResults) {
471-
// Token estimation runs for every tool result — including
472-
// dynamic (MCP) tools and error outputs — since they all
473-
// re-enter the model's context on the next step.
474-
//
475-
// The model never sees the raw output object of tools that
476-
// define a `toModelOutput` mapping (e.g. builtin tools send
477-
// only their `output` text, not the UI-only metadata), so
478-
// estimate the mapped payload when one exists. Tools without
479-
// the mapping have their output sent as a JSON object.
480-
const tool = allTools[toolName];
481-
const estimatedOutputTokens = tool?.toModelOutput
482-
? estimateModelToolOutputTokens(await tool.toModelOutput({ toolCallId, input, output }))
483-
: estimateToolOutputTokens(output);
484-
485-
onToolTokenUsage?.({
486-
toolCallId,
487-
toolName,
488-
estimatedOutputTokens,
489-
stepIndex: stepNumber,
490-
});
491-
490+
// Token usage collection deliberately does NOT happen here: the SDK
491+
// awaits this callback before starting the next step, so it must
492+
// stay cheap, and `toolResults` misses tool calls that never run
493+
// inside a step (approval-gated tools execute before the step loop)
494+
// as well as thrown tool errors (recorded as `tool-error` parts).
495+
// Both are instead derived post-stream in `createMessageStream`
496+
// from `steps` and `response.messages`.
497+
onStepFinish: ({ toolResults }) => {
498+
toolResults.forEach(({ output, dynamic }) => {
492499
if (dynamic || isServiceError(output)) {
493-
continue;
500+
return;
494501
}
495502

496503
output.sources?.forEach(onWriteSource);
497-
}
504+
});
498505
},
499506
experimental_telemetry: {
500507
isEnabled: env.SOURCEBOT_TELEMETRY_PII_COLLECTION_ENABLED === 'true',

packages/web/src/ee/features/chat/components/chatThread/chatThreadListItem.tsx

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -91,33 +91,57 @@ const ChatThreadListItemComponent = forwardRef<HTMLDivElement, ChatThreadListIte
9191
// should be visible to the user. By "steps", we mean parts that originated
9292
// from the same LLM invocation. By "visible", we mean parts that have some
9393
// visual representation in the UI (e.g., text, reasoning, tool calls, etc.).
94-
const uiVisibleThinkingSteps = useMemo(() => {
95-
const steps = groupMessageIntoSteps(assistantMessage?.parts ?? []);
96-
97-
// Filter out the answerPart and empty steps
98-
return steps
99-
.map(
100-
(step) => step
101-
// First, filter out any parts that are not text
102-
.filter((part) => {
103-
if (part.type === 'text') {
104-
return !part.text.includes(ANSWER_TAG);
105-
}
106-
107-
return true;
108-
})
109-
.filter((part) => {
110-
// Only include text, reasoning, and tool parts
111-
return (
112-
part.type === 'text' ||
113-
part.type === 'reasoning' ||
114-
part.type.startsWith('tool-') ||
115-
part.type === 'dynamic-tool'
116-
)
117-
})
118-
)
94+
//
95+
// Each step is tagged with its stepIndex — the invocation's position in
96+
// the turn, which indexes into `metadata.stepTokenUsage`. Indices are
97+
// assigned by counting 'step-start' markers (one per invocation) BEFORE
98+
// any filtering, so dropping empty or answer-only steps below cannot
99+
// shift the indices of the steps that remain.
100+
const { uiVisibleThinkingSteps, answerStepIndex } = useMemo(() => {
101+
const groupedParts = groupMessageIntoSteps(assistantMessage?.parts ?? []);
102+
103+
// Parts written before the first step-start (e.g. data parts) don't
104+
// belong to any step; they get stepIndex -1 and never survive the
105+
// visibility filters below.
106+
let stepIndex = -1;
107+
let answerStepIndex: number | undefined = undefined;
108+
109+
const steps = groupedParts
110+
.map((stepParts) => {
111+
if (stepParts[0]?.type === 'step-start') {
112+
stepIndex++;
113+
}
114+
115+
if (stepParts.some((part) => part.type === 'text' && part.text.includes(ANSWER_TAG))) {
116+
answerStepIndex = stepIndex;
117+
}
118+
119+
return {
120+
stepIndex,
121+
parts: stepParts
122+
// First, filter out the answer text
123+
.filter((part) => {
124+
if (part.type === 'text') {
125+
return !part.text.includes(ANSWER_TAG);
126+
}
127+
128+
return true;
129+
})
130+
.filter((part) => {
131+
// Only include text, reasoning, and tool parts
132+
return (
133+
part.type === 'text' ||
134+
part.type === 'reasoning' ||
135+
part.type.startsWith('tool-') ||
136+
part.type === 'dynamic-tool'
137+
)
138+
}),
139+
};
140+
})
119141
// Then, filter out any steps that are empty
120-
.filter(step => step.length > 0);
142+
.filter((step) => step.parts.length > 0);
143+
144+
return { uiVisibleThinkingSteps: steps, answerStepIndex };
121145
}, [assistantMessage?.parts]);
122146

123147
// "thinking" is when the agent is generating output that is not the answer.
@@ -379,6 +403,7 @@ const ChatThreadListItemComponent = forwardRef<HTMLDivElement, ChatThreadListIte
379403
isNetworkActive={isNetworkActive}
380404
isAwaitingToolApproval={isAwaitingToolApproval}
381405
thinkingSteps={uiVisibleThinkingSteps}
406+
answerStepIndex={answerStepIndex}
382407
metadata={assistantMessage?.metadata}
383408
/>
384409

packages/web/src/ee/features/chat/components/chatThread/detailsCard.test.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ describe('DetailsCard', () => {
111111
isTurnInProgress={true}
112112
isNetworkActive={false}
113113
isAwaitingToolApproval={false}
114-
thinkingSteps={[[failedActivationPart]]}
114+
thinkingSteps={[{ stepIndex: 0, parts: [failedActivationPart] }]}
115115
/>
116116
</TooltipProvider>
117117
);

0 commit comments

Comments
 (0)