11import { SBChatMessage , SBChatMessageMetadata , StepTokenUsageEntry , ToolTokenUsageEntry } from "@/features/chat/types" ;
2- import { estimateModelToolOutputTokens , estimateToolOutputTokens } from "@/features/chat/tokenEstimation" ;
2+ import { estimateModelToolOutputTokens } from "@/features/chat/tokenEstimation" ;
33import { getFileSource } from '@/features/git' ;
44import { isServiceError } from "@/lib/utils" ;
55import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider" ;
@@ -150,9 +150,6 @@ export const createMessageStream = async ({
150150
151151 const startTime = new Date ( ) ;
152152
153- const collectedToolTokenUsage : ToolTokenUsageEntry [ ] = [ ] ;
154- const collectedStepTokenUsage : StepTokenUsageEntry [ ] = [ ] ;
155-
156153 const researchStream = await createAgentStream ( {
157154 model,
158155 providerOptions : modelProviderOptions ,
@@ -167,12 +164,6 @@ export const createMessageStream = async ({
167164 data : source ,
168165 } ) ;
169166 } ,
170- onToolTokenUsage : ( entry ) => {
171- collectedToolTokenUsage . push ( entry ) ;
172- } ,
173- onStepTokenUsage : ( entry ) => {
174- collectedStepTokenUsage . push ( entry ) ;
175- } ,
176167 onMcpServerDiscovered : ( sanitizedName , faviconUrl ) => {
177168 writer . write ( {
178169 type : 'data-mcp-server' ,
@@ -200,6 +191,59 @@ export const createMessageStream = async ({
200191 } ) ;
201192
202193 const totalUsage = await researchStream . totalUsage ;
194+ const steps = await researchStream . steps ;
195+ const response = await researchStream . response ;
196+
197+ // Tool output estimates are derived from `response.messages` rather
198+ // than per-step `toolResults` because the response messages cover
199+ // tool calls that never run inside a step — approval-gated tools
200+ // execute before the step loop, and thrown tool errors are recorded
201+ // as `tool-error` parts that `toolResults` excludes. The
202+ // `tool-result` parts in `response.messages` also carry the output
203+ // in model-visible form (`toModelOutput` already applied), which is
204+ // exactly the payload whose token footprint we want to estimate.
205+ const toolUsageByToolCallId = new Map < string , ToolTokenUsageEntry > (
206+ response . messages . flatMap ( ( message ) =>
207+ message . role !== 'tool' ? [ ] : message . content . flatMap ( ( part ) =>
208+ part . type !== 'tool-result' ? [ ] : [ [ part . toolCallId , {
209+ toolCallId : part . toolCallId ,
210+ toolName : part . toolName ,
211+ estimatedOutputTokens : estimateModelToolOutputTokens ( part . output ) ,
212+ } ] as const ]
213+ )
214+ )
215+ ) ;
216+
217+ // One entry per step, in step order. The UI joins its step groups
218+ // to these entries by array position, so the order and count must
219+ // mirror the stream's steps exactly. Tool calls nest under the
220+ // step they ran in; `content` is matched rather than `toolResults`
221+ // so that thrown tool errors (`tool-error` parts, which
222+ // `toolResults` excludes) are still attributed to their step.
223+ const stepTokenUsage : StepTokenUsageEntry [ ] = steps . map ( ( { usage, content } ) => ( {
224+ inputTokens : usage . inputTokens ,
225+ outputTokens : usage . outputTokens ,
226+ cacheReadTokens : usage . inputTokenDetails ?. cacheReadTokens ,
227+ tools : content . flatMap ( ( part ) => {
228+ if ( part . type !== 'tool-result' && part . type !== 'tool-error' ) {
229+ return [ ] ;
230+ }
231+ const entry = toolUsageByToolCallId . get ( part . toolCallId ) ;
232+ if ( ! entry ) {
233+ return [ ] ;
234+ }
235+ toolUsageByToolCallId . delete ( part . toolCallId ) ;
236+ return [ entry ] ;
237+ } ) ,
238+ } ) ) ;
239+
240+ // Any estimates left unclaimed belong to tool calls that executed
241+ // before the step loop (approval continuations). Their output
242+ // enters the context as input to this phase's first step, so nest
243+ // them under it.
244+ if ( toolUsageByToolCallId . size > 0 && stepTokenUsage . length > 0 ) {
245+ stepTokenUsage [ 0 ] . tools . unshift ( ...toolUsageByToolCallId . values ( ) ) ;
246+ }
203247
204248 writer . write ( {
205249 type : 'message-metadata' ,
@@ -210,22 +254,9 @@ export const createMessageStream = async ({
210254 totalCacheReadTokens : ( priorMetadata ?. totalCacheReadTokens ?? 0 ) + ( totalUsage . inputTokenDetails ?. cacheReadTokens ?? 0 ) ,
211255 totalCacheWriteTokens : ( priorMetadata ?. totalCacheWriteTokens ?? 0 ) + ( totalUsage . inputTokenDetails ?. cacheWriteTokens ?? 0 ) ,
212256 totalResponseTimeMs : ( priorMetadata ?. totalResponseTimeMs ?? 0 ) + ( new Date ( ) . getTime ( ) - startTime . getTime ( ) ) ,
213- // Unlike the token totals above, these are concatenated (not
214- // summed) across approval-continuation phases so tool calls
215- // and steps from the pre-approval phase are preserved. Step
216- // indices captured in this phase are relative to this phase's
217- // stream, so offset them by the prior phase's step count to
218- // keep them valid against the concatenated step array.
219- toolTokenUsage : [
220- ...( priorMetadata ?. toolTokenUsage ?? [ ] ) ,
221- ...collectedToolTokenUsage . map ( ( entry ) => ( {
222- ...entry ,
223- stepIndex : entry . stepIndex !== undefined
224- ? entry . stepIndex + ( priorMetadata ?. stepTokenUsage ?. length ?? 0 )
225- : undefined ,
226- } ) ) ,
227- ] ,
228- stepTokenUsage : [ ...( priorMetadata ?. stepTokenUsage ?? [ ] ) , ...collectedStepTokenUsage ] ,
257+ // Concatenated (not summed) across approval-continuation
258+ // phases so earlier phases' steps are preserved in order.
259+ stepTokenUsage : [ ...( priorMetadata ?. stepTokenUsage ?? [ ] ) , ...stepTokenUsage ] ,
229260 modelName,
230261 traceId,
231262 ...metadata ,
@@ -253,8 +284,6 @@ interface AgentOptions {
253284 inputMessages : ModelMessage [ ] ;
254285 inputSources : Source [ ] ;
255286 onWriteSource : ( source : Source ) => void ;
256- onToolTokenUsage ?: ( entry : ToolTokenUsageEntry ) => void ;
257- onStepTokenUsage ?: ( entry : StepTokenUsageEntry ) => void ;
258287 onMcpServerDiscovered : ( sanitizedName : string , faviconUrl : string ) => void ;
259288 onMcpServerFailed : ( serverName : string ) => void ;
260289 traceId : string ;
@@ -273,8 +302,6 @@ const createAgentStream = async ({
273302 selectedRepos,
274303 disabledMcpServerIds,
275304 onWriteSource,
276- onToolTokenUsage,
277- onStepTokenUsage,
278305 onMcpServerDiscovered,
279306 onMcpServerFailed,
280307 traceId,
@@ -460,41 +487,21 @@ const createAgentStream = async ({
460487 logger . warn ( `Tool call repair failed for "${ toolCall . toolName } ": ${ error . message } ` ) ;
461488 return null ;
462489 } ,
463- onStepFinish : async ( { stepNumber, usage, toolResults } ) => {
464- onStepTokenUsage ?.( {
465- inputTokens : usage . inputTokens ,
466- outputTokens : usage . outputTokens ,
467- cacheReadTokens : usage . inputTokenDetails ?. cacheReadTokens ,
468- } ) ;
469-
470- for ( const { toolCallId, toolName, input, output, dynamic } of toolResults ) {
471- // Token estimation runs for every tool result — including
472- // dynamic (MCP) tools and error outputs — since they all
473- // re-enter the model's context on the next step.
474- //
475- // The model never sees the raw output object of tools that
476- // define a `toModelOutput` mapping (e.g. builtin tools send
477- // only their `output` text, not the UI-only metadata), so
478- // estimate the mapped payload when one exists. Tools without
479- // the mapping have their output sent as a JSON object.
480- const tool = allTools [ toolName ] ;
481- const estimatedOutputTokens = tool ?. toModelOutput
482- ? estimateModelToolOutputTokens ( await tool . toModelOutput ( { toolCallId, input, output } ) )
483- : estimateToolOutputTokens ( output ) ;
484-
485- onToolTokenUsage ?.( {
486- toolCallId,
487- toolName,
488- estimatedOutputTokens,
489- stepIndex : stepNumber ,
490- } ) ;
491-
490+ // Token usage collection deliberately does NOT happen here: the SDK
491+ // awaits this callback before starting the next step, so it must
492+ // stay cheap, and `toolResults` misses tool calls that never run
493+ // inside a step (approval-gated tools execute before the step loop)
494+ // as well as thrown tool errors (recorded as `tool-error` parts).
495+ // Step and tool token usage are instead derived post-stream in
496+ // `createMessageStream` from `steps` and `response.messages`.
497+ onStepFinish : ( { toolResults } ) => {
498+ toolResults . forEach ( ( { output, dynamic } ) => {
492499 if ( dynamic || isServiceError ( output ) ) {
493- continue ;
500+ return ;
494501 }
495502
496503 output . sources ?. forEach ( onWriteSource ) ;
497- }
504+ } ) ;
498505 } ,
499506 experimental_telemetry : {
500507 isEnabled : env . SOURCEBOT_TELEMETRY_PII_COLLECTION_ENABLED === 'true' ,
0 commit comments