diff --git a/src/extension/conversation/vscode-node/languageModelAccess.ts b/src/extension/conversation/vscode-node/languageModelAccess.ts index c341dd609f..2b41999104 100644 --- a/src/extension/conversation/vscode-node/languageModelAccess.ts +++ b/src/extension/conversation/vscode-node/languageModelAccess.ts @@ -25,6 +25,7 @@ import { IOctoKitService } from '../../../platform/github/common/githubService'; import { ILogService } from '../../../platform/log/common/logService'; import { isAnthropicToolSearchEnabled } from '../../../platform/networking/common/anthropic'; import { FinishedCallback, OpenAiFunctionTool, OptionalChatRequestParams } from '../../../platform/networking/common/fetch'; +import type { APIUsage } from '../../../platform/networking/common/openai'; import { IChatEndpoint, IEndpoint } from '../../../platform/networking/common/networking'; import { IOTelService, type OTelModelOptions } from '../../../platform/otel/common/otelService'; import { retrieveCapturingTokenByCorrelation, runWithCapturingToken } from '../../../platform/requestLogger/node/requestLogger'; @@ -501,7 +502,7 @@ export class CopilotLanguageModelWrapper extends Disposable { super(); } - private async _provideLanguageModelResponse(_endpoint: IChatEndpoint, _messages: Array, _options: vscode.ProvideLanguageModelChatResponseOptions, extensionId: string | undefined, callback: FinishedCallback, token: vscode.CancellationToken): Promise { + private async _provideLanguageModelResponse(_endpoint: IChatEndpoint, _messages: Array, _options: vscode.ProvideLanguageModelChatResponseOptions, extensionId: string | undefined, callback: FinishedCallback, token: vscode.CancellationToken): Promise { if (extensionId === 'core') { extensionId = undefined; } @@ -678,6 +679,8 @@ export class CopilotLanguageModelWrapper extends Disposable { tokenLimit } ); + + return result.usage; } async provideLanguageModelResponse(endpoint: IChatEndpoint, messages: Array, options: vscode.ProvideLanguageModelChatResponseOptions, extensionId: string | undefined, progress: vscode.Progress, token: vscode.CancellationToken): Promise { @@ -718,7 +721,13 @@ export class CopilotLanguageModelWrapper extends Disposable { return undefined; }; - return this._provideLanguageModelResponse(endpoint, messages, options, extensionId, finishCallback, token); + const usage = await this._provideLanguageModelResponse(endpoint, messages, options, extensionId, finishCallback, token); + if (usage) { + progress.report(new vscode.LanguageModelDataPart( + new TextEncoder().encode(JSON.stringify(usage)), + CustomDataPartMimeTypes.Usage + )); + } } async provideTokenCount(endpoint: IEndpoint, message: string | vscode.LanguageModelChatMessage | vscode.LanguageModelChatMessage2): Promise { diff --git a/src/platform/endpoint/common/endpointTypes.ts b/src/platform/endpoint/common/endpointTypes.ts index f1a3a75c4e..084fa8f1dc 100644 --- a/src/platform/endpoint/common/endpointTypes.ts +++ b/src/platform/endpoint/common/endpointTypes.ts @@ -9,6 +9,7 @@ export namespace CustomDataPartMimeTypes { export const ThinkingData = 'thinking'; export const ContextManagement = 'context_management'; export const PhaseData = 'phase_data'; + export const Usage = 'usage'; } export const CacheType = 'ephemeral'; \ No newline at end of file diff --git a/src/platform/endpoint/vscode-node/extChatEndpoint.ts b/src/platform/endpoint/vscode-node/extChatEndpoint.ts index 7f77e8202d..1c9c206a3a 100644 --- a/src/platform/endpoint/vscode-node/extChatEndpoint.ts +++ b/src/platform/endpoint/vscode-node/extChatEndpoint.ts @@ -18,7 +18,7 @@ import { ContextManagementResponse } from '../../networking/common/anthropic'; import { FinishedCallback, OpenAiFunctionTool, OptionalChatRequestParams } from '../../networking/common/fetch'; import { Response } from '../../networking/common/fetcherService'; import { IChatEndpoint, ICreateEndpointBodyOptions, IEndpointBody, IMakeChatRequestOptions } from '../../networking/common/networking'; -import { ChatCompletion } from '../../networking/common/openai'; +import { type APIUsage, ChatCompletion, isApiUsage } from '../../networking/common/openai'; import { IOTelService } from '../../otel/common/otelService'; import { retrieveCapturingTokenByCorrelation, storeCapturingTokenForCorrelation } from '../../requestLogger/node/requestLogger'; import { ITelemetryService } from '../../telemetry/common/telemetry'; @@ -205,6 +205,7 @@ export class ExtensionContributedChatEndpoint implements IChatEndpoint { let text = ''; let numToolsCalled = 0; const requestId = ourRequestId; + let reportedUsage: APIUsage | undefined; // consume stream for await (const chunk of response.stream) { @@ -230,6 +231,16 @@ export class ExtensionContributedChatEndpoint implements IChatEndpoint { } else if (chunk.mimeType === CustomDataPartMimeTypes.ContextManagement) { const contextManagement = JSON.parse(new TextDecoder().decode(chunk.data)) as ContextManagementResponse; await streamRecorder.callback?.(text, 0, { text: '', contextManagement }); + } else if (chunk.mimeType === CustomDataPartMimeTypes.Usage) { + try { + const parsed = JSON.parse(new TextDecoder().decode(chunk.data)); + if (isApiUsage(parsed)) { + // Last-write-wins: if multiple Usage DataParts arrive, keep the last one + reportedUsage = parsed; + } + } catch { + // ignore malformed usage data + } } } else if (chunk instanceof vscode.LanguageModelThinkingPart) { if (streamRecorder.callback) { @@ -250,7 +261,7 @@ export class ExtensionContributedChatEndpoint implements IChatEndpoint { type: ChatFetchResponseType.Success, requestId, serverRequestId: requestId, - usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, prompt_tokens_details: { cached_tokens: 0 } }, + usage: reportedUsage ?? { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, prompt_tokens_details: { cached_tokens: 0 } }, value: text, resolvedModel: this.languageModel.id }; diff --git a/src/platform/endpoint/vscode-node/test/extChatEndpoint.spec.ts b/src/platform/endpoint/vscode-node/test/extChatEndpoint.spec.ts new file mode 100644 index 0000000000..9be3c4d9cf --- /dev/null +++ b/src/platform/endpoint/vscode-node/test/extChatEndpoint.spec.ts @@ -0,0 +1,180 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it, vi } from 'vitest'; +import type { LanguageModelChat, LanguageModelChatResponse } from 'vscode'; +import { LanguageModelDataPart, LanguageModelTextPart } from '../../../../vscodeTypes'; +import { ChatFetchResponseType, ChatLocation } from '../../../chat/common/commonTypes'; +import { CustomDataPartMimeTypes } from '../../common/endpointTypes'; +import { ExtensionContributedChatEndpoint } from '../extChatEndpoint'; + +function createMockStream(chunks: unknown[]): LanguageModelChatResponse { + return { + stream: (async function* () { + for (const chunk of chunks) { + yield chunk; + } + })(), + text: (async function* () { + for (const chunk of chunks) { + if (chunk instanceof LanguageModelTextPart) { + yield chunk.value; + } + } + })(), + } as LanguageModelChatResponse; +} + +function createMockLanguageModel(streamChunks: unknown[]): LanguageModelChat { + return { + id: 'test-model', + name: 'Test Model', + vendor: 'test', + family: 'test-family', + version: '1.0', + maxInputTokens: 128000, + capabilities: {}, + sendRequest: vi.fn().mockResolvedValue(createMockStream(streamChunks)), + countTokens: vi.fn().mockResolvedValue(10), + } as unknown as LanguageModelChat; +} + +function createEndpoint(streamChunks: unknown[]): ExtensionContributedChatEndpoint { + const languageModel = createMockLanguageModel(streamChunks); + const mockInstantiationService = {} as any; + const mockOTelService = { + getActiveTraceContext: vi.fn().mockReturnValue(undefined), + } as any; + return new ExtensionContributedChatEndpoint(languageModel, mockInstantiationService, mockOTelService); +} + +describe('ExtensionContributedChatEndpoint usage reporting', () => { + it('should extract usage from Usage DataPart', async () => { + const usage = { prompt_tokens: 100, completion_tokens: 50, total_tokens: 150, prompt_tokens_details: { cached_tokens: 20 } }; + const endpoint = createEndpoint([ + new LanguageModelTextPart('Hello'), + new LanguageModelDataPart(new TextEncoder().encode(JSON.stringify(usage)), CustomDataPartMimeTypes.Usage), + ]); + + const result = await endpoint.makeChatRequest2({ + debugName: 'test', + messages: [], + finishedCb: undefined, + location: ChatLocation.Panel, + }, { isCancellationRequested: false, onCancellationRequested: vi.fn() } as any); + + expect(result.type).toBe(ChatFetchResponseType.Success); + if (result.type === ChatFetchResponseType.Success) { + expect(result.usage?.prompt_tokens).toBe(100); + expect(result.usage?.completion_tokens).toBe(50); + expect(result.usage?.total_tokens).toBe(150); + expect(result.usage?.prompt_tokens_details?.cached_tokens).toBe(20); + } + }); + + it('should fall back to zero usage when no Usage DataPart is present', async () => { + const endpoint = createEndpoint([ + new LanguageModelTextPart('Hello'), + ]); + + const result = await endpoint.makeChatRequest2({ + debugName: 'test', + messages: [], + finishedCb: undefined, + location: ChatLocation.Panel, + }, { isCancellationRequested: false, onCancellationRequested: vi.fn() } as any); + + expect(result.type).toBe(ChatFetchResponseType.Success); + if (result.type === ChatFetchResponseType.Success) { + expect(result.usage?.prompt_tokens).toBe(0); + expect(result.usage?.completion_tokens).toBe(0); + } + }); + + it('should fall back to zero usage when Usage DataPart contains malformed data', async () => { + const endpoint = createEndpoint([ + new LanguageModelTextPart('Hello'), + new LanguageModelDataPart(new TextEncoder().encode('not-valid-json'), CustomDataPartMimeTypes.Usage), + ]); + + const result = await endpoint.makeChatRequest2({ + debugName: 'test', + messages: [], + finishedCb: undefined, + location: ChatLocation.Panel, + }, { isCancellationRequested: false, onCancellationRequested: vi.fn() } as any); + + expect(result.type).toBe(ChatFetchResponseType.Success); + if (result.type === ChatFetchResponseType.Success) { + expect(result.usage?.prompt_tokens).toBe(0); + expect(result.usage?.completion_tokens).toBe(0); + } + }); + + it('should reject usage with invalid field types', async () => { + const invalidUsage = { prompt_tokens: '100', completion_tokens: 50, total_tokens: 150 }; + const endpoint = createEndpoint([ + new LanguageModelTextPart('Hello'), + new LanguageModelDataPart(new TextEncoder().encode(JSON.stringify(invalidUsage)), CustomDataPartMimeTypes.Usage), + ]); + + const result = await endpoint.makeChatRequest2({ + debugName: 'test', + messages: [], + finishedCb: undefined, + location: ChatLocation.Panel, + }, { isCancellationRequested: false, onCancellationRequested: vi.fn() } as any); + + expect(result.type).toBe(ChatFetchResponseType.Success); + if (result.type === ChatFetchResponseType.Success) { + expect(result.usage?.prompt_tokens).toBe(0); + expect(result.usage?.completion_tokens).toBe(0); + } + }); + + it('should extract usage when Usage DataPart arrives before text', async () => { + const usage = { prompt_tokens: 200, completion_tokens: 80, total_tokens: 280 }; + const endpoint = createEndpoint([ + new LanguageModelDataPart(new TextEncoder().encode(JSON.stringify(usage)), CustomDataPartMimeTypes.Usage), + new LanguageModelTextPart('Hello'), + ]); + + const result = await endpoint.makeChatRequest2({ + debugName: 'test', + messages: [], + finishedCb: undefined, + location: ChatLocation.Panel, + }, { isCancellationRequested: false, onCancellationRequested: vi.fn() } as any); + + expect(result.type).toBe(ChatFetchResponseType.Success); + if (result.type === ChatFetchResponseType.Success) { + expect(result.usage?.prompt_tokens).toBe(200); + expect(result.usage?.completion_tokens).toBe(80); + expect(result.usage?.total_tokens).toBe(280); + } + }); + + it('should report usage when finishedCb is provided', async () => { + const usage = { prompt_tokens: 50, completion_tokens: 25, total_tokens: 75 }; + const finishedCb = vi.fn(); + const endpoint = createEndpoint([ + new LanguageModelTextPart('Hello'), + new LanguageModelDataPart(new TextEncoder().encode(JSON.stringify(usage)), CustomDataPartMimeTypes.Usage), + ]); + + const result = await endpoint.makeChatRequest2({ + debugName: 'test', + messages: [], + finishedCb, + location: ChatLocation.Panel, + }, { isCancellationRequested: false, onCancellationRequested: vi.fn() } as any); + + expect(result.type).toBe(ChatFetchResponseType.Success); + if (result.type === ChatFetchResponseType.Success) { + expect(result.usage?.prompt_tokens).toBe(50); + expect(result.usage?.completion_tokens).toBe(25); + } + }); +});