diff --git a/.changeset/spotty-dragons-hug.md b/.changeset/spotty-dragons-hug.md
new file mode 100644
index 00000000000..22a8a53e314
--- /dev/null
+++ b/.changeset/spotty-dragons-hug.md
@@ -0,0 +1,5 @@
+---
+'@aws-amplify/ai-constructs': minor
+---
+
+Add metrics (latencyMs) and usage (inputTokens, outputTokens, totalTokens) to streaming conversation responses from Bedrock Converse API
diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts
index 12988a5977d..80c563d2642 100644
--- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts
+++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts
@@ -201,6 +201,8 @@ void describe('Bedrock converse adapter', () => {
           associatedUserMessageId: event.currentMessageId,
           contentBlockIndex: 1,
           stopReason: 'end_turn',
+          metrics: { latencyMs: 150 },
+          usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
         },
       ]);
     } else {
@@ -712,6 +714,8 @@ void describe('Bedrock converse adapter', () => {
           associatedUserMessageId: event.currentMessageId,
           contentBlockIndex: 0,
           stopReason: 'tool_use',
+          metrics: { latencyMs: 150 },
+          usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
         },
       ]);
     } else {
@@ -1033,14 +1037,44 @@ void describe('Bedrock converse adapter', () => {
       progressCalls[1].arguments[0],
       'Processed 2000 chunks from Bedrock Converse Stream response, requestId=testRequestId',
     );
-    // each block is decomposed into 4 chunks + start and stop of whole message.
-    const expectedNumberOfAllChunks = numberOfBlocks * 4 + 2;
+    // each block is decomposed into 4 chunks + start, stop, and metadata of whole message.
+    const expectedNumberOfAllChunks = numberOfBlocks * 4 + 3;
     assert.strictEqual(
       progressCalls[2].arguments[0],
       `Completed processing ${expectedNumberOfAllChunks.toString()} chunks from Bedrock Converse Stream response, requestId=testRequestId`,
     );
   });

+  void it('includes metrics and usage in the final streaming chunk', async () => {
+    const event: ConversationTurnEvent = {
+      ...commonEvent,
+    };
+
+    const bedrockClient = new BedrockRuntimeClient();
+    const content = [{ text: 'block1' }];
+    const bedrockResponse = mockBedrockResponse(content, true);
+    mock.method(bedrockClient, 'send', () => Promise.resolve(bedrockResponse));
+
+    const adapter = new BedrockConverseAdapter(
+      event,
+      [],
+      bedrockClient,
+      undefined,
+      messageHistoryRetriever,
+    );
+
+    const chunks: Array<StreamingResponseChunk> =
+      await askBedrockWithStreaming(adapter);
+    const lastChunk = chunks[chunks.length - 1];
+    assert.ok(lastChunk.stopReason);
+    assert.deepStrictEqual(lastChunk.metrics, { latencyMs: 150 });
+    assert.deepStrictEqual(lastChunk.usage, {
+      inputTokens: 10,
+      outputTokens: 20,
+      totalTokens: 30,
+    });
+  });
+
   void it('throws if tool is duplicated', () => {
     assert.throws(
       () =>
@@ -1304,6 +1338,12 @@ const mockConverseStreamCommandOutput = (
       stopReason: stopReason,
     },
   });
+  streamItems.push({
+    metadata: {
+      metrics: { latencyMs: 150 },
+      usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
+    },
+  });
   return {
     $metadata: {},
     stream: (async function* (): AsyncGenerator {
diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts
index 1d6f07f83ea..d4e8eac32d5 100644
--- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts
+++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts
@@ -175,6 +175,13 @@ export class BedrockConverseAdapter {
     let blockIndex = 0;
     let lastBlockIndex = 0;
     let stopReason = '';
+    // The following metadata are overwritten on each iteration of the tool-use loop.
+    // Only the final iteration's values are reported, as intermediate iterations
+    // are tool-use round-trips and the final iteration contains the actual model response.
+    let latencyMs = 0;
+    let inputTokens = 0;
+    let outputTokens = 0;
+    let totalTokens = 0;
     // Accumulates client facing content per turn.
     // So that upstream can persist full message at the end of the streaming.
     const accumulatedTurnContent: Array<ContentBlock> = [];
@@ -304,6 +311,14 @@
         }
       } else if (chunk.messageStop) {
         stopReason = chunk.messageStop.stopReason ?? '';
+        this.logger.debug(
+          `Bedrock stop reason received: stopReason=${chunk.messageStop.stopReason ?? ''}`,
+        );
+      } else if (chunk.metadata) {
+        latencyMs = chunk.metadata.metrics?.latencyMs ?? 0;
+        inputTokens = chunk.metadata.usage?.inputTokens ?? 0;
+        outputTokens = chunk.metadata.usage?.outputTokens ?? 0;
+        totalTokens = chunk.metadata.usage?.totalTokens ?? 0;
       }
       processedBedrockChunks++;
       if (processedBedrockChunks % 1000 === 0) {
@@ -330,6 +345,8 @@
           associatedUserMessageId: this.event.currentMessageId,
           contentBlockIndex: lastBlockIndex,
           stopReason: stopReason,
+          metrics: { latencyMs },
+          usage: { inputTokens, outputTokens, totalTokens },
         };
         return;
       }
@@ -359,6 +376,8 @@
       associatedUserMessageId: this.event.currentMessageId,
       contentBlockIndex: lastBlockIndex,
       stopReason: stopReason,
+      metrics: { latencyMs },
+      usage: { inputTokens, outputTokens, totalTokens },
     };
   }

diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts
index cedfbedf88a..5b62357c130 100644
--- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts
+++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts
@@ -300,6 +300,89 @@ void describe('Conversation turn response sender', () => {
     });
   });

+  void it('serializes metrics and usage to JSON when streaming', async () => {
+    const userAgentProvider = new UserAgentProvider(
+      {} as unknown as ConversationTurnEvent,
+    );
+    mock.method(userAgentProvider, 'getUserAgent', () => '');
+    const graphqlRequestExecutor = new GraphqlRequestExecutor(
+      '',
+      '',
+      userAgentProvider,
+    );
+    const executeGraphqlMock = mock.method(
+      graphqlRequestExecutor,
+      'executeGraphql',
+      () => Promise.resolve(),
+    );
+    const sender = new ConversationTurnResponseSender(
+      event,
+      userAgentProvider,
+      graphqlRequestExecutor,
+    );
+    const chunk: StreamingResponseChunk = {
+      accumulatedTurnContent: [{ text: 'testContent' }],
+      associatedUserMessageId: 'testAssociatedUserMessageId',
+      contentBlockIndex: 0,
+      conversationId: 'testConversationId',
+      stopReason: 'end_turn',
+      metrics: { latencyMs: 150 },
+      usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
+    };
+    await sender.sendResponseChunk(chunk);
+
+    assert.strictEqual(executeGraphqlMock.mock.calls.length, 1);
+    const request = executeGraphqlMock.mock.calls[0]
+      .arguments[0] as GraphqlRequest;
+    // metrics and usage should be serialized to JSON strings
+    assert.strictEqual(
+      (request.variables.input as Record<string, unknown>).metrics,
+      JSON.stringify({ latencyMs: 150 }),
+    );
+    assert.strictEqual(
+      (request.variables.input as Record<string, unknown>).usage,
+      JSON.stringify({ inputTokens: 10, outputTokens: 20, totalTokens: 30 }),
+    );
+  });
+
+  void it('does not include metrics and usage for non-stop-reason chunks', async () => {
+    const userAgentProvider = new UserAgentProvider(
+      {} as unknown as ConversationTurnEvent,
+    );
+    mock.method(userAgentProvider, 'getUserAgent', () => '');
+    const graphqlRequestExecutor = new GraphqlRequestExecutor(
+      '',
+      '',
+      userAgentProvider,
+    );
+    const executeGraphqlMock = mock.method(
+      graphqlRequestExecutor,
+      'executeGraphql',
+      () => Promise.resolve(),
+    );
+    const sender = new ConversationTurnResponseSender(
+      event,
+      userAgentProvider,
+      graphqlRequestExecutor,
+    );
+    const chunk: StreamingResponseChunk = {
+      accumulatedTurnContent: [{ text: 'testContent' }],
+      associatedUserMessageId: 'testAssociatedUserMessageId',
+      contentBlockIndex: 0,
+      contentBlockDeltaIndex: 0,
+      conversationId: 'testConversationId',
+      contentBlockText: 'testBlockText',
+    };
+    await sender.sendResponseChunk(chunk);
+
+    assert.strictEqual(executeGraphqlMock.mock.calls.length, 1);
+    const request = executeGraphqlMock.mock.calls[0]
+      .arguments[0] as GraphqlRequest;
+    const input = request.variables.input as Record<string, unknown>;
+    assert.strictEqual(input.metrics, undefined);
+    assert.strictEqual(input.usage, undefined);
+  });
+
   void it('sends errors response back to appsync', async () => {
     const userAgentProvider = new UserAgentProvider(
       {} as unknown as ConversationTurnEvent,
diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts
index 9849f5ecd07..eeed75ca847 100644
--- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts
+++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts
@@ -15,8 +15,16 @@ export type MutationResponseInput = {
   };
 };

+export type SerializedStreamingResponseChunk = Omit<
+  StreamingResponseChunk,
+  'metrics' | 'usage'
+> & {
+  metrics?: string;
+  usage?: string;
+};
+
 export type MutationStreamingResponseInput = {
-  input: StreamingResponseChunk;
+  input: SerializedStreamingResponseChunk;
 };

 export type MutationErrorsResponseInput = {
@@ -133,14 +141,18 @@ export class ConversationTurnResponseSender {
         }
       }
     `;
-    chunk = {
-      ...chunk,
+
+    const { metrics, usage, ...rest } = chunk;
+    const serializedChunk: SerializedStreamingResponseChunk = {
+      ...rest,
       accumulatedTurnContent: this.serializeContent(
         chunk.accumulatedTurnContent,
       ),
+      ...(metrics && { metrics: JSON.stringify(metrics) }),
+      ...(usage && { usage: JSON.stringify(usage) }),
     };
-    const variables: MutationStreamingResponseInput = {
-      input: chunk,
+    const variables = {
+      input: serializedChunk,
     };
     return { query, variables };
   };
diff --git a/packages/ai-constructs/src/conversation/runtime/types.ts b/packages/ai-constructs/src/conversation/runtime/types.ts
index 6cc7da99a52..8e3a6848fc7 100644
--- a/packages/ai-constructs/src/conversation/runtime/types.ts
+++ b/packages/ai-constructs/src/conversation/runtime/types.ts
@@ -133,6 +133,8 @@ export type StreamingResponseChunk = {
       contentBlockDoneAtIndex?: never;
       contentBlockToolUse?: never;
       stopReason?: never;
+      metrics?: never;
+      usage?: never;
     }
   | {
       // end of block. applicable to text blocks
@@ -141,6 +143,8 @@
       contentBlockDeltaIndex?: never;
       contentBlockToolUse?: never;
       stopReason?: never;
+      metrics?: never;
+      usage?: never;
     }
   | {
       // tool use
@@ -149,10 +153,14 @@
       contentBlockText?: never;
       contentBlockDeltaIndex?: never;
       stopReason?: never;
+      metrics?: never;
+      usage?: never;
     }
   | {
       // turn complete
       stopReason: string;
+      metrics: { latencyMs: number };
+      usage: { inputTokens: number; outputTokens: number; totalTokens: number };
       contentBlockDoneAtIndex?: never;
       contentBlockText?: never;
       contentBlockDeltaIndex?: never;
diff --git a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts
index 185780bc0b0..c7a9d22f8a1 100644
--- a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts
+++ b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts
@@ -91,6 +91,16 @@ const schema = a.schema({
     message: a.string(),
   }),

+  MockMetrics: a.customType({
+    latencyMs: a.integer(),
+  }),
+
+  MockUsage: a.customType({
+    inputTokens: a.integer(),
+    outputTokens: a.integer(),
+    totalTokens: a.integer(),
+  }),
+
   ConversationMessageAssistantResponse: a
     .model({
       conversationId: a.id(),
@@ -119,6 +129,8 @@ const schema = a.schema({

       // when message is complete
       stopReason: a.string(),
+      metrics: a.ref('MockMetrics'),
+      usage: a.ref('MockUsage'),

       // error
       errors: a.ref('MockConversationTurnError').array(),