5 changes: 5 additions & 0 deletions .changeset/spotty-dragons-hug.md
@@ -0,0 +1,5 @@
---
'@aws-amplify/ai-constructs': minor
---

Add metrics (latencyMs) and usage (inputTokens, outputTokens, totalTokens) to streaming conversation responses from Bedrock Converse API
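The final chunk of a streamed turn now carries both objects alongside the stop reason. A minimal sketch of the new shape, with illustrative values taken from the unit tests below (field names follow `StreamingResponseChunk` in `types.ts`):

```ts
// Final chunk of a streamed conversation turn (illustrative values only):
const finalChunk = {
  conversationId: 'conversation-id',
  associatedUserMessageId: 'user-message-id',
  contentBlockIndex: 1,
  accumulatedTurnContent: [{ text: 'full assistant response' }],
  stopReason: 'end_turn',
  metrics: { latencyMs: 150 },
  usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
};
```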
@@ -201,6 +201,8 @@ void describe('Bedrock converse adapter', () => {
associatedUserMessageId: event.currentMessageId,
contentBlockIndex: 1,
stopReason: 'end_turn',
metrics: { latencyMs: 150 },
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
},
]);
} else {
@@ -712,6 +714,8 @@ void describe('Bedrock converse adapter', () => {
associatedUserMessageId: event.currentMessageId,
contentBlockIndex: 0,
stopReason: 'tool_use',
metrics: { latencyMs: 150 },
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
},
]);
} else {
@@ -1033,14 +1037,44 @@ void describe('Bedrock converse adapter', () => {
progressCalls[1].arguments[0],
'Processed 2000 chunks from Bedrock Converse Stream response, requestId=testRequestId',
);
-      // each block is decomposed into 4 chunks + start and stop of whole message.
-      const expectedNumberOfAllChunks = numberOfBlocks * 4 + 2;
+      // each block is decomposed into 4 chunks + start, stop, and metadata of whole message.
+      const expectedNumberOfAllChunks = numberOfBlocks * 4 + 3;
assert.strictEqual(
progressCalls[2].arguments[0],
`Completed processing ${expectedNumberOfAllChunks.toString()} chunks from Bedrock Converse Stream response, requestId=testRequestId`,
);
});
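(As a concrete reading of the `* 4 + 3` arithmetic: if `numberOfBlocks` were 500 — a hypothetical value, the constant is outside this hunk — the stream would contain 500 × 4 = 2000 block-level chunks plus the `messageStart`, `messageStop`, and new `metadata` events, 2003 in total, matching the 'Processed 2000 chunks' progress call asserted above.)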

void it('includes metrics and usage in the final streaming chunk', async () => {
const event: ConversationTurnEvent = {
...commonEvent,
};

const bedrockClient = new BedrockRuntimeClient();
const content = [{ text: 'block1' }];
const bedrockResponse = mockBedrockResponse(content, true);
mock.method(bedrockClient, 'send', () => Promise.resolve(bedrockResponse));

const adapter = new BedrockConverseAdapter(
event,
[],
bedrockClient,
undefined,
messageHistoryRetriever,
);

const chunks: Array<StreamingResponseChunk> =
await askBedrockWithStreaming(adapter);
const lastChunk = chunks[chunks.length - 1];
assert.ok(lastChunk.stopReason);
assert.deepStrictEqual(lastChunk.metrics, { latencyMs: 150 });
assert.deepStrictEqual(lastChunk.usage, {
inputTokens: 10,
outputTokens: 20,
totalTokens: 30,
});
});

void it('throws if tool is duplicated', () => {
assert.throws(
() =>
@@ -1304,6 +1338,12 @@ const mockConverseStreamCommandOutput = (
stopReason: stopReason,
},
});
streamItems.push({
metadata: {
metrics: { latencyMs: 150 },
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
},
});
return {
$metadata: {},
stream: (async function* (): AsyncGenerator<ConverseStreamOutput> {
@@ -175,6 +175,13 @@ export class BedrockConverseAdapter {
let blockIndex = 0;
let lastBlockIndex = 0;
let stopReason = '';
// The following metadata values are overwritten on each iteration of the tool-use loop.
// Only the final iteration's values are reported, as intermediate iterations
// are tool-use round-trips and the final iteration contains the actual model response.
let latencyMs = 0;
let inputTokens = 0;
let outputTokens = 0;
let totalTokens = 0;
// Accumulates client facing content per turn.
// So that upstream can persist full message at the end of the streaming.
const accumulatedTurnContent: Array<bedrock.ContentBlock> = [];
@@ -304,6 +311,14 @@
}
} else if (chunk.messageStop) {
stopReason = chunk.messageStop.stopReason ?? '';
this.logger.debug(
`Bedrock stop reason received: stopReason=${chunk.messageStop.stopReason ?? ''}`,
);
} else if (chunk.metadata) {
latencyMs = chunk.metadata.metrics?.latencyMs ?? 0;
inputTokens = chunk.metadata.usage?.inputTokens ?? 0;
outputTokens = chunk.metadata.usage?.outputTokens ?? 0;
totalTokens = chunk.metadata.usage?.totalTokens ?? 0;
}
processedBedrockChunks++;
if (processedBedrockChunks % 1000 === 0) {
@@ -330,6 +345,8 @@
associatedUserMessageId: this.event.currentMessageId,
contentBlockIndex: lastBlockIndex,
stopReason: stopReason,
metrics: { latencyMs },
usage: { inputTokens, outputTokens, totalTokens },
};
return;
}
@@ -359,6 +376,8 @@
associatedUserMessageId: this.event.currentMessageId,
contentBlockIndex: lastBlockIndex,
stopReason: stopReason,
metrics: { latencyMs },
usage: { inputTokens, outputTokens, totalTokens },
};
}

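For context, the `chunk.metadata` branch above consumes the final metadata event of the Converse stream. A minimal sketch of that event as `@aws-sdk/client-bedrock-runtime` models it (values illustrative, mirroring the test mocks):

```ts
import { ConverseStreamOutput } from '@aws-sdk/client-bedrock-runtime';

// Emitted once per model invocation, after messageStop. In a tool-use loop
// each round trip yields its own metadata event, which is why the adapter
// overwrites latencyMs and the token counts on every iteration.
const metadataEvent: ConverseStreamOutput = {
  metadata: {
    metrics: { latencyMs: 150 },
    usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
  },
};
```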
@@ -300,6 +300,89 @@ void describe('Conversation turn response sender', () => {
});
});

void it('serializes metrics and usage to JSON when streaming', async () => {
const userAgentProvider = new UserAgentProvider(
{} as unknown as ConversationTurnEvent,
);
mock.method(userAgentProvider, 'getUserAgent', () => '');
const graphqlRequestExecutor = new GraphqlRequestExecutor(
'',
'',
userAgentProvider,
);
const executeGraphqlMock = mock.method(
graphqlRequestExecutor,
'executeGraphql',
() => Promise.resolve(),
);
const sender = new ConversationTurnResponseSender(
event,
userAgentProvider,
graphqlRequestExecutor,
);
const chunk: StreamingResponseChunk = {
accumulatedTurnContent: [{ text: 'testContent' }],
associatedUserMessageId: 'testAssociatedUserMessageId',
contentBlockIndex: 0,
conversationId: 'testConversationId',
stopReason: 'end_turn',
metrics: { latencyMs: 150 },
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
};
await sender.sendResponseChunk(chunk);

assert.strictEqual(executeGraphqlMock.mock.calls.length, 1);
const request = executeGraphqlMock.mock.calls[0]
.arguments[0] as GraphqlRequest<MutationStreamingResponseInput>;
// metrics and usage should be serialized to JSON strings
assert.strictEqual(
(request.variables.input as Record<string, unknown>).metrics,
JSON.stringify({ latencyMs: 150 }),
);
assert.strictEqual(
(request.variables.input as Record<string, unknown>).usage,
JSON.stringify({ inputTokens: 10, outputTokens: 20, totalTokens: 30 }),
);
});

void it('does not include metrics and usage for non-stop-reason chunks', async () => {
const userAgentProvider = new UserAgentProvider(
{} as unknown as ConversationTurnEvent,
);
mock.method(userAgentProvider, 'getUserAgent', () => '');
const graphqlRequestExecutor = new GraphqlRequestExecutor(
'',
'',
userAgentProvider,
);
const executeGraphqlMock = mock.method(
graphqlRequestExecutor,
'executeGraphql',
() => Promise.resolve(),
);
const sender = new ConversationTurnResponseSender(
event,
userAgentProvider,
graphqlRequestExecutor,
);
const chunk: StreamingResponseChunk = {
accumulatedTurnContent: [{ text: 'testContent' }],
associatedUserMessageId: 'testAssociatedUserMessageId',
contentBlockIndex: 0,
contentBlockDeltaIndex: 0,
conversationId: 'testConversationId',
contentBlockText: 'testBlockText',
};
await sender.sendResponseChunk(chunk);

assert.strictEqual(executeGraphqlMock.mock.calls.length, 1);
const request = executeGraphqlMock.mock.calls[0]
.arguments[0] as GraphqlRequest<MutationStreamingResponseInput>;
const input = request.variables.input as Record<string, unknown>;
assert.strictEqual(input.metrics, undefined);
assert.strictEqual(input.usage, undefined);
});

void it('sends errors response back to appsync', async () => {
const userAgentProvider = new UserAgentProvider(
{} as unknown as ConversationTurnEvent,
@@ -15,8 +15,16 @@ export type MutationResponseInput = {
};
};

export type SerializedStreamingResponseChunk = Omit<
StreamingResponseChunk,
'metrics' | 'usage'
> & {
metrics?: string;
usage?: string;
};

export type MutationStreamingResponseInput = {
-  input: StreamingResponseChunk;
+  input: SerializedStreamingResponseChunk;
};

export type MutationErrorsResponseInput = {
@@ -133,14 +141,18 @@ export class ConversationTurnResponseSender {
}
}
`;
-      chunk = {
-        ...chunk,
+      const { metrics, usage, ...rest } = chunk;
+      const serializedChunk: SerializedStreamingResponseChunk = {
+        ...rest,
         accumulatedTurnContent: this.serializeContent(
           chunk.accumulatedTurnContent,
         ),
+        ...(metrics && { metrics: JSON.stringify(metrics) }),
+        ...(usage && { usage: JSON.stringify(usage) }),
       };
-      const variables: MutationStreamingResponseInput = {
-        input: chunk,
+      const variables = {
+        input: serializedChunk,
};
return { query, variables };
};
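Given the adapter's final chunk, the variables built above would look roughly as follows — a sketch with values mirroring the unit test; `accumulatedTurnContent` serialization is elided:

```ts
const variables = {
  input: {
    conversationId: 'testConversationId',
    associatedUserMessageId: 'testAssociatedUserMessageId',
    contentBlockIndex: 0,
    // accumulatedTurnContent goes through serializeContent (not shown here)
    stopReason: 'end_turn',
    metrics: JSON.stringify({ latencyMs: 150 }), // '{"latencyMs":150}'
    usage: JSON.stringify({ inputTokens: 10, outputTokens: 20, totalTokens: 30 }),
  },
};
```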
8 changes: 8 additions & 0 deletions packages/ai-constructs/src/conversation/runtime/types.ts
@@ -133,6 +133,8 @@ export type StreamingResponseChunk = {
contentBlockDoneAtIndex?: never;
contentBlockToolUse?: never;
stopReason?: never;
metrics?: never;
usage?: never;
}
| {
// end of block. applicable to text blocks
@@ -141,6 +143,8 @@
contentBlockDeltaIndex?: never;
contentBlockToolUse?: never;
stopReason?: never;
metrics?: never;
usage?: never;
}
| {
// tool use
@@ -149,10 +153,14 @@
contentBlockText?: never;
contentBlockDeltaIndex?: never;
stopReason?: never;
metrics?: never;
usage?: never;
}
| {
// turn complete
stopReason: string;
metrics: { latencyMs: number };
usage: { inputTokens: number; outputTokens: number; totalTokens: number };
contentBlockDoneAtIndex?: never;
contentBlockText?: never;
contentBlockDeltaIndex?: never;
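Because `metrics` and `usage` are typed `never` on every variant except the turn-complete one, consumers can narrow the union on `stopReason` alone. A minimal sketch (the import path is an assumption):

```ts
import type { StreamingResponseChunk } from './types.js'; // path assumed

const logFinalChunkStats = (chunk: StreamingResponseChunk): void => {
  if (chunk.stopReason !== undefined) {
    // Turn-complete variant: metrics and usage are required here.
    console.log(
      `stop=${chunk.stopReason}, latency=${chunk.metrics.latencyMs}ms, tokens=${chunk.usage.totalTokens}`,
    );
  }
};
```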
@@ -91,6 +91,16 @@ const schema = a.schema({
message: a.string(),
}),

MockMetrics: a.customType({
latencyMs: a.integer(),
}),

MockUsage: a.customType({
inputTokens: a.integer(),
outputTokens: a.integer(),
totalTokens: a.integer(),
}),

ConversationMessageAssistantResponse: a
.model({
conversationId: a.id(),
@@ -119,6 +129,8 @@

// when message is complete
stopReason: a.string(),
metrics: a.ref('MockMetrics'),
usage: a.ref('MockUsage'),

// error
errors: a.ref('MockConversationTurnError').array(),