diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 09372ec7..be87f641 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -300,6 +300,12 @@ func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibco } var out []option.RequestOption + out = append(out, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + if ua := req.Header.Get("User-Agent"); ua != "" { + req.Header.Set("User-Agent", ua+" sdk-ua-app-id/APN_1.1%2Fpc_cdfmjwn8i6u8l9fwz8h82e4w3%24") + } + return next(req) + })) out = append(out, bedrock.WithConfig(awsCfg)) // If a custom base URL is set, override the default endpoint constructed by the bedrock middleware. diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index a2a746e3..7a92edfa 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -118,6 +118,13 @@ func TestAnthropicMessages(t *testing.T) { require.Len(t, promptUsages, 1) assert.Equal(t, "read the foo file", promptUsages[0].Prompt) + // Verify PRM attribution is NOT present on non-Bedrock Anthropic requests. + received := upstream.receivedRequests() + require.Len(t, received, 1) + ua := received[0].Header.Get("User-Agent") + assert.NotContains(t, ua, "sdk-ua-app-id", + "PRM attribution should not be present on non-Bedrock requests") + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } @@ -308,6 +315,11 @@ func TestAWSBedrockIntegration(t *testing.T) { require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") + // Verify PRM attribution is appended to the User-Agent header. + ua := received[0].Header.Get("User-Agent") + require.Contains(t, ua, "sdk-ua-app-id/APN_1.1%2Fpc_cdfmjwn8i6u8l9fwz8h82e4w3%24", + "expected AWS PRM attribution in User-Agent header") + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) require.Equal(t, interceptions[0].Model, bedrockCfg.Model)