diff --git a/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap b/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap index a6cf9e90c4..0cb0d45528 100644 --- a/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap +++ b/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap @@ -214,6 +214,8 @@ export function request(ctx) { associatedUserMessageId, accumulatedTurnContent, errors, + metrics, + usage, } = ctx.args.input; const { owner } = ctx.args; @@ -230,7 +232,7 @@ export function request(ctx) { const { createdAt, updatedAt } = ctx.stash.defaultValues; const assistantResponseId = \`\${associatedUserMessageId}#response\`; - const expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; + let expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; const expressionValues = util.dynamodb.toMapValues({ ':typename': 'ConversationMessagePirateChat', @@ -255,6 +257,18 @@ export function request(ctx) { '#updatedAt': 'updatedAt', }; + if (metrics) { + expression += ', #metrics = :metrics'; + expressionValues[':metrics'] = metrics; + expressionNames['#metrics'] = 'metrics'; + } + + if (usage) { + expression += ', #usage = :usage'; + expressionValues[':usage'] = usage; + expressionNames['#usage'] = 'usage'; + } + return { operation: 'UpdateItem', key: util.dynamodb.toMapValues({ id: assistantResponseId }), @@ -300,7 +314,8 @@ function generateRandomPadding() { const base = 'abcdefghijklmnopqrstuvwxyz0123456789'; const rand = Math.floor(Math.random() * 36); return base.slice(0, rand); -}" +} +" `; exports[`ConversationTransformer valid schemas should transform conversation route with inference configuration: AssistantResponseStreamMutation init slot function code 1`] = ` @@ -676,7 +691,7 @@ export function request(ctx) { const { graphqlApiEndpoint } = ctx.stash; const userAgent = createUserAgent(request); - const selectionSet = 'associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p'; + const selectionSet = 'associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p metrics { latencyMs } usage { inputTokens outputTokens totalTokens }'; const streamingResponseMutation = { name: 'createAssistantResponseStreamPirateChat', @@ -1078,6 +1093,8 @@ export function request(ctx) { associatedUserMessageId, accumulatedTurnContent, errors, + metrics, + usage, } = ctx.args.input; const { owner } = ctx.args; @@ -1094,7 +1111,7 @@ export function request(ctx) { const { createdAt, updatedAt } = ctx.stash.defaultValues; const assistantResponseId = \`\${associatedUserMessageId}#response\`; - const expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; + let expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; const expressionValues = util.dynamodb.toMapValues({ ':typename': 'ConversationMessagePirateChat', @@ -1119,6 +1136,18 @@ export function request(ctx) { '#updatedAt': 'updatedAt', }; + if (metrics) { + expression += ', #metrics = :metrics'; + expressionValues[':metrics'] = metrics; + expressionNames['#metrics'] = 'metrics'; + } + + if (usage) { + expression += ', #usage = :usage'; + expressionValues[':usage'] = usage; + expressionNames['#usage'] = 'usage'; + } + return { operation: 'UpdateItem', key: util.dynamodb.toMapValues({ id: assistantResponseId }), @@ -1164,7 +1193,8 @@ function generateRandomPadding() { const base = 'abcdefghijklmnopqrstuvwxyz0123456789'; const rand = Math.floor(Math.random() * 36); return base.slice(0, rand); -}" +} +" `; exports[`ConversationTransformer valid schemas should transform conversation route with model query tool including relationships: AssistantResponseStreamMutation init slot function code 1`] = ` @@ -1540,7 +1570,7 @@ export function request(ctx) { const { graphqlApiEndpoint } = ctx.stash; const userAgent = createUserAgent(request); - const selectionSet = 'associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p'; + const selectionSet = 'associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p metrics { latencyMs } usage { inputTokens outputTokens totalTokens }'; const streamingResponseMutation = { name: 'createAssistantResponseStreamPirateChat', @@ -1942,6 +1972,8 @@ export function request(ctx) { associatedUserMessageId, accumulatedTurnContent, errors, + metrics, + usage, } = ctx.args.input; const { owner } = ctx.args; @@ -1958,7 +1990,7 @@ export function request(ctx) { const { createdAt, updatedAt } = ctx.stash.defaultValues; const assistantResponseId = \`\${associatedUserMessageId}#response\`; - const expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; + let expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; const expressionValues = util.dynamodb.toMapValues({ ':typename': 'ConversationMessagePirateChat', @@ -1983,6 +2015,18 @@ export function request(ctx) { '#updatedAt': 'updatedAt', }; + if (metrics) { + expression += ', #metrics = :metrics'; + expressionValues[':metrics'] = metrics; + expressionNames['#metrics'] = 'metrics'; + } + + if (usage) { + expression += ', #usage = :usage'; + expressionValues[':usage'] = usage; + expressionNames['#usage'] = 'usage'; + } + return { operation: 'UpdateItem', key: util.dynamodb.toMapValues({ id: assistantResponseId }), @@ -2028,7 +2072,8 @@ function generateRandomPadding() { const base = 'abcdefghijklmnopqrstuvwxyz0123456789'; const rand = Math.floor(Math.random() * 36); return base.slice(0, rand); -}" +} +" `; exports[`ConversationTransformer valid schemas should transform conversation route with model query tool: AssistantResponseStreamMutation init slot function code 1`] = ` @@ -2404,7 +2449,7 @@ export function request(ctx) { const { graphqlApiEndpoint } = ctx.stash; const userAgent = createUserAgent(request); - const selectionSet = 'associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p'; + const selectionSet = 'associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p metrics { latencyMs } usage { inputTokens outputTokens totalTokens }'; const streamingResponseMutation = { name: 'createAssistantResponseStreamPirateChat', @@ -2806,6 +2851,8 @@ export function request(ctx) { associatedUserMessageId, accumulatedTurnContent, errors, + metrics, + usage, } = ctx.args.input; const { owner } = ctx.args; @@ -2822,7 +2869,7 @@ export function request(ctx) { const { createdAt, updatedAt } = ctx.stash.defaultValues; const assistantResponseId = \`\${associatedUserMessageId}#response\`; - const expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; + let expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; const expressionValues = util.dynamodb.toMapValues({ ':typename': 'ConversationMessagePirateChat', @@ -2847,6 +2894,18 @@ export function request(ctx) { '#updatedAt': 'updatedAt', }; + if (metrics) { + expression += ', #metrics = :metrics'; + expressionValues[':metrics'] = metrics; + expressionNames['#metrics'] = 'metrics'; + } + + if (usage) { + expression += ', #usage = :usage'; + expressionValues[':usage'] = usage; + expressionNames['#usage'] = 'usage'; + } + return { operation: 'UpdateItem', key: util.dynamodb.toMapValues({ id: assistantResponseId }), @@ -2892,7 +2951,8 @@ function generateRandomPadding() { const base = 'abcdefghijklmnopqrstuvwxyz0123456789'; const rand = Math.floor(Math.random() * 36); return base.slice(0, rand); -}" +} +" `; exports[`ConversationTransformer valid schemas should transform conversation route with query tools: AssistantResponseStreamMutation init slot function code 1`] = ` @@ -3268,7 +3328,7 @@ export function request(ctx) { const { graphqlApiEndpoint } = ctx.stash; const userAgent = createUserAgent(request); - const selectionSet = 'associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p'; + const selectionSet = 'associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p metrics { latencyMs } usage { inputTokens outputTokens totalTokens }'; const streamingResponseMutation = { name: 'createAssistantResponseStreamPirateChat', diff --git a/packages/amplify-graphql-conversation-transformer/src/graphql-types/message-model.ts b/packages/amplify-graphql-conversation-transformer/src/graphql-types/message-model.ts index 1cecbede95..3ebd0cc8f5 100644 --- a/packages/amplify-graphql-conversation-transformer/src/graphql-types/message-model.ts +++ b/packages/amplify-graphql-conversation-transformer/src/graphql-types/message-model.ts @@ -141,6 +141,8 @@ export const createAssistantResponseStreamingMutationInput = (messageModelName: makeInputValueDefinition('accumulatedTurnContent', makeListType(makeNamedType('AmplifyAIContentBlockInput'))), makeInputValueDefinition('errors', makeListType(makeNamedType('AmplifyAIConversationTurnErrorInput'))), makeInputValueDefinition('p', makeNamedType('String')), + makeInputValueDefinition('metrics', makeNamedType('AWSJSON')), + makeInputValueDefinition('usage', makeNamedType('AWSJSON')), ], }; }; @@ -156,6 +158,24 @@ export const createConversationTurnErrorInput = (): InputObjectTypeDefinitionNod }; }; +export const constructMetricsType = (): ObjectTypeDefinitionNode => ({ + kind: 'ObjectTypeDefinition', + name: { kind: 'Name', value: 'AmplifyAIMetrics' }, + fields: [makeField('latencyMs', [], makeNamedType('Int'))], + directives: [], +}); + +export const constructUsageType = (): ObjectTypeDefinitionNode => ({ + kind: 'ObjectTypeDefinition', + name: { kind: 'Name', value: 'AmplifyAIUsage' }, + fields: [ + makeField('inputTokens', [], makeNamedType('Int')), + makeField('outputTokens', [], makeNamedType('Int')), + makeField('totalTokens', [], makeNamedType('Int')), + ], + directives: [], +}); + export const createAssistantStreamingMutationField = (fieldName: string, inputTypeName: string): FieldDefinitionNode => { const args = [makeInputValueDefinition('input', makeNonNullType(makeNamedType(inputTypeName)))]; const cognitoAuthDirective = makeDirective('aws_cognito_user_pools', []); @@ -282,11 +302,13 @@ const constructConversationMessageModel = ( const context = makeField('aiContext', [], makeNamedType('AWSJSON')); const uiComponents = makeField('toolConfiguration', [], makeNamedType('AmplifyAIToolConfiguration')); const associatedUserMessageId = makeField('associatedUserMessageId', [], makeNamedType('ID')); + const metrics = makeField('metrics', [], makeNamedType('AmplifyAIMetrics')); + const usage = makeField('usage', [], makeNamedType('AmplifyAIUsage')); const object = { ...blankObject(modelName), interfaces: [conversationMessageInterface], - fields: [id, conversationId, conversationField, role, content, context, uiComponents, associatedUserMessageId], + fields: [id, conversationId, conversationField, role, content, context, uiComponents, associatedUserMessageId, metrics, usage], directives: typeDirectives, }; @@ -315,6 +337,8 @@ export const constructStreamResponseType = (): ObjectTypeDefinitionNode => { makeField('contentBlockDoneAtIndex', [], makeNamedType('Int')), makeField('stopReason', [], makeNamedType('String')), + makeField('metrics', [], makeNamedType('AmplifyAIMetrics')), + makeField('usage', [], makeNamedType('AmplifyAIUsage')), ], }; }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/send-message-pipeline-definition.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/send-message-pipeline-definition.ts index d06d55d0bc..8df8fca683 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/send-message-pipeline-definition.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/send-message-pipeline-definition.ts @@ -134,4 +134,4 @@ function templateGenerator(slotName: string) { return createS3AssetMappingTemplateGenerator('Mutation', slotName, fieldName); } -const streamingSelectionSet = `associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p`; +const streamingSelectionSet = `associatedUserMessageId contentBlockDeltaIndex contentBlockDoneAtIndex contentBlockIndex contentBlockText contentBlockToolUse { toolUseId name input } conversationId id stopReason owner errors { errorType message } p metrics { latencyMs } usage { inputTokens outputTokens totalTokens }`; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/templates/assistant-streaming-mutation-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/templates/assistant-streaming-mutation-resolver-fn.template.js index c08825ba88..5c0f924af4 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/templates/assistant-streaming-mutation-resolver-fn.template.js +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/templates/assistant-streaming-mutation-resolver-fn.template.js @@ -12,6 +12,8 @@ export function request(ctx) { associatedUserMessageId, accumulatedTurnContent, errors, + metrics, + usage, } = ctx.args.input; const { owner } = ctx.args; @@ -28,7 +30,7 @@ export function request(ctx) { const { createdAt, updatedAt } = ctx.stash.defaultValues; const assistantResponseId = `${associatedUserMessageId}#response`; - const expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; + let expression = 'SET #typename = :typename, #conversationId = :conversationId, #associatedUserMessageId = :associatedUserMessageId, #role = :role, #content = :content, #owner = :owner, #createdAt = if_not_exists(#createdAt, :createdAt), #updatedAt = :updatedAt'; const expressionValues = util.dynamodb.toMapValues({ ':typename': '[[CONVERSATION_MESSAGE_TYPE_NAME]]', @@ -53,6 +55,18 @@ export function request(ctx) { '#updatedAt': 'updatedAt', }; + if (metrics) { + expression += ', #metrics = :metrics'; + expressionValues[':metrics'] = metrics; + expressionNames['#metrics'] = 'metrics'; + } + + if (usage) { + expression += ', #usage = :usage'; + expressionValues[':usage'] = usage; + expressionNames['#usage'] = 'usage'; + } + return { operation: 'UpdateItem', key: util.dynamodb.toMapValues({ id: assistantResponseId }), @@ -98,4 +112,4 @@ function generateRandomPadding() { const base = 'abcdefghijklmnopqrstuvwxyz0123456789'; const rand = Math.floor(Math.random() * 36); return base.slice(0, rand); -} \ No newline at end of file +} diff --git a/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-prepare-handler.ts b/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-prepare-handler.ts index 217a1a7d21..09c0a086ac 100644 --- a/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-prepare-handler.ts +++ b/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-prepare-handler.ts @@ -2,7 +2,12 @@ import { ModelTransformer } from '@aws-amplify/graphql-model-transformer'; import { BelongsToTransformer, HasManyTransformer } from '@aws-amplify/graphql-relational-transformer'; import { DDB_AMPLIFY_MANAGED_DATASOURCE_STRATEGY, InvalidTransformerError } from '@aws-amplify/graphql-transformer-core'; import { ConversationDirectiveConfiguration } from '../conversation-directive-configuration'; -import { constructStreamResponseType, createConversationTurnErrorInput } from '../graphql-types/message-model'; +import { + constructStreamResponseType, + createConversationTurnErrorInput, + constructMetricsType, + constructUsageType, +} from '../graphql-types/message-model'; import { TransformerAuthProvider, TransformerPrepareStepContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; /** @@ -50,6 +55,8 @@ export class ConversationPrepareHandler { // add once per schema const conversationTurnErrorInput = createConversationTurnErrorInput(); ctx.output.addInput(conversationTurnErrorInput); + ctx.output.addObject(constructMetricsType()); + ctx.output.addObject(constructUsageType()); for (const directive of directives) { this.prepareResourcesForDirective(directive, ctx);