Skip to content

Commit 6d32c7a

Browse files
committed
feat: add AWS Bedrock provider with SigV4 authentication
- Add SigV4AuthGenerator implementing AuthHeadersGenerator, using AWS SDK v2 signer to produce Authorization, X-Amz-Date, X-Amz-Content-Sha256, and X-Amz-Security-Token headers - Introduce dataEnrichers map on ApiKeyInjectionPlugin to enrich the credentials map with per-request data (body, endpoint) before signing, keeping the AuthHeadersGenerator interface unchanged - Thread ExternalModel.spec.endpoint through reconciler → modelInfoStore → CycleState → enricher → SigV4 signer, and extract the AWS region from the endpoint hostname - Add provider constant aws-bedrock and CycleState key endpoint - Add unit tests for SigV4 signing, region extraction, credential enrichment, session token handling, and missing-endpoint error path
1 parent 74db5d3 commit 6d32c7a

12 files changed

Lines changed: 538 additions & 12 deletions

File tree

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ require (
1616

1717
require (
1818
github.com/Masterminds/semver/v3 v3.4.0 // indirect
19+
github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect
20+
github.com/aws/smithy-go v1.25.1 // indirect
1921
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
2022
github.com/google/pprof v0.0.0-20260402051712-545e8a4df936 // indirect
2123
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect

go.sum

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8
66
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
77
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
88
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
9+
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
10+
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
11+
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
12+
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
913
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
1014
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
1115
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
Copyright 2026 The opendatahub.io Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package auth
18+
19+
import (
20+
"context"
21+
"crypto/sha256"
22+
"fmt"
23+
"net/http"
24+
"strings"
25+
"time"
26+
27+
"github.com/aws/aws-sdk-go-v2/aws"
28+
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
29+
)
30+
31+
const (
32+
awsAccessKeyField = "aws-access-key-id"
33+
awsSecretKeyField = "aws-secret-access-key"
34+
awsSessionTokenField = "aws-session-token"
35+
enrichedBodyField = "_request_body"
36+
enrichedEndpointField = "_endpoint"
37+
enrichedPathField = "_path"
38+
bedrockService = "bedrock"
39+
defaultPath = "/v1/chat/completions"
40+
)
41+
42+
// compile-time interface check
43+
var _ AuthHeadersGenerator = &SigV4AuthGenerator{}
44+
45+
// SigV4AuthGenerator generates AWS Signature Version 4 authentication headers.
46+
// All inputs (credentials, request body, endpoint) come from the credentialsData map,
47+
// where request-specific fields are injected by the credentials enricher.
48+
type SigV4AuthGenerator struct{}
49+
50+
// GenerateAuthHeaders computes a SigV4 signature and returns the required AWS auth headers.
51+
// Expects the following keys in credentialsData:
52+
// - "aws-access-key-id" and "aws-secret-access-key" (from the Kubernetes Secret)
53+
// - "_request_body" and "_endpoint" (injected by the credentials enricher)
54+
// - "aws-session-token" (optional, for temporary credentials)
55+
func (g *SigV4AuthGenerator) GenerateAuthHeaders(credentialsData map[string]string) (map[string]string, error) {
56+
accessKey, ok := credentialsData[awsAccessKeyField]
57+
if !ok || accessKey == "" {
58+
return nil, fmt.Errorf("credentials missing required field %s", awsAccessKeyField)
59+
}
60+
secretKey, ok := credentialsData[awsSecretKeyField]
61+
if !ok || secretKey == "" {
62+
return nil, fmt.Errorf("credentials missing required field %s", awsSecretKeyField)
63+
}
64+
65+
body, ok := credentialsData[enrichedBodyField]
66+
if !ok {
67+
return nil, fmt.Errorf("credentials missing required field %s (enricher not applied?)", enrichedBodyField)
68+
}
69+
endpoint, ok := credentialsData[enrichedEndpointField]
70+
if !ok || endpoint == "" {
71+
return nil, fmt.Errorf("credentials missing required field %s (enricher not applied?)", enrichedEndpointField)
72+
}
73+
74+
region, err := regionFromEndpoint(endpoint)
75+
if err != nil {
76+
return nil, fmt.Errorf("failed to extract region: %w", err)
77+
}
78+
79+
creds := aws.Credentials{
80+
AccessKeyID: accessKey,
81+
SecretAccessKey: secretKey,
82+
SessionToken: credentialsData[awsSessionTokenField],
83+
Source: "ExternalModelSecret",
84+
}
85+
86+
bodyHash := sha256Hex([]byte(body))
87+
88+
path := credentialsData[enrichedPathField]
89+
if path == "" {
90+
path = defaultPath
91+
}
92+
93+
req, err := http.NewRequest(http.MethodPost, "https://"+endpoint+path, strings.NewReader(body))
94+
if err != nil {
95+
return nil, fmt.Errorf("failed to build HTTP request for signing: %w", err)
96+
}
97+
req.Header.Set("Content-Type", "application/json")
98+
99+
signer := v4.NewSigner()
100+
if err := signer.SignHTTP(context.Background(), creds, req, bodyHash, bedrockService, region, time.Now()); err != nil {
101+
return nil, fmt.Errorf("SigV4 signing failed: %w", err)
102+
}
103+
104+
headers := map[string]string{
105+
"Authorization": req.Header.Get("Authorization"),
106+
"X-Amz-Date": req.Header.Get("X-Amz-Date"),
107+
"X-Amz-Content-Sha256": bodyHash,
108+
}
109+
if creds.SessionToken != "" {
110+
headers["X-Amz-Security-Token"] = creds.SessionToken
111+
}
112+
113+
return headers, nil
114+
}
115+
116+
// regionFromEndpoint extracts the AWS region from a Bedrock endpoint hostname.
117+
// e.g. "bedrock-runtime.us-east-1.amazonaws.com" → "us-east-1"
118+
func regionFromEndpoint(endpoint string) (string, error) {
119+
parts := strings.Split(endpoint, ".")
120+
if len(parts) < 4 {
121+
return "", fmt.Errorf("invalid AWS endpoint format: %q (expected service.region.amazonaws.com)", endpoint)
122+
}
123+
return parts[1], nil
124+
}
125+
126+
func sha256Hex(data []byte) string {
127+
h := sha256.Sum256(data)
128+
return fmt.Sprintf("%x", h)
129+
}
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/*
2+
Copyright 2026 The opendatahub.io Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package auth
18+
19+
import (
20+
"crypto/sha256"
21+
"fmt"
22+
"strings"
23+
"testing"
24+
)
25+
26+
func TestSigV4AuthHeadersGenerator(t *testing.T) {
27+
validCredentials := map[string]string{
28+
"aws-access-key-id": "AKIAIOSFODNN7EXAMPLE",
29+
"aws-secret-access-key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
30+
"_request_body": `{"model":"anthropic.claude-v2","prompt":"hello"}`,
31+
"_endpoint": "bedrock-runtime.us-east-1.amazonaws.com",
32+
}
33+
34+
validCredentialsWithToken := map[string]string{
35+
"aws-access-key-id": "AKIAIOSFODNN7EXAMPLE",
36+
"aws-secret-access-key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
37+
"aws-session-token": "FwoGZXIvYXdzEBYaDH7example-session-token",
38+
"_request_body": `{"model":"anthropic.claude-v2","prompt":"hello"}`,
39+
"_endpoint": "bedrock-runtime.us-east-1.amazonaws.com",
40+
}
41+
42+
tests := []struct {
43+
name string
44+
credentials map[string]string
45+
wantHeaders []string
46+
wantAuthPrefix string
47+
wantNoHeader string
48+
wantBodyHash bool
49+
wantErrContains string
50+
}{
51+
{
52+
name: "valid credentials without session token",
53+
credentials: validCredentials,
54+
wantHeaders: []string{"Authorization", "X-Amz-Date", "X-Amz-Content-Sha256"},
55+
wantAuthPrefix: "AWS4-HMAC-SHA256",
56+
wantNoHeader: "X-Amz-Security-Token",
57+
wantBodyHash: true,
58+
},
59+
{
60+
name: "valid credentials with session token",
61+
credentials: validCredentialsWithToken,
62+
wantHeaders: []string{"Authorization", "X-Amz-Date", "X-Amz-Content-Sha256", "X-Amz-Security-Token"},
63+
wantAuthPrefix: "AWS4-HMAC-SHA256",
64+
wantBodyHash: true,
65+
},
66+
{
67+
name: "missing access key returns error",
68+
credentials: map[string]string{
69+
"aws-secret-access-key": "secret",
70+
"_request_body": "{}",
71+
"_endpoint": "bedrock-runtime.us-east-1.amazonaws.com",
72+
},
73+
wantErrContains: "aws-access-key-id",
74+
},
75+
{
76+
name: "missing secret key returns error",
77+
credentials: map[string]string{
78+
"aws-access-key-id": "AKIA...",
79+
"_request_body": "{}",
80+
"_endpoint": "bedrock-runtime.us-east-1.amazonaws.com",
81+
},
82+
wantErrContains: "aws-secret-access-key",
83+
},
84+
{
85+
name: "missing request body returns error",
86+
credentials: map[string]string{
87+
"aws-access-key-id": "AKIA...",
88+
"aws-secret-access-key": "secret",
89+
"_endpoint": "bedrock-runtime.us-east-1.amazonaws.com",
90+
},
91+
wantErrContains: "_request_body",
92+
},
93+
{
94+
name: "missing endpoint returns error",
95+
credentials: map[string]string{
96+
"aws-access-key-id": "AKIA...",
97+
"aws-secret-access-key": "secret",
98+
"_request_body": "{}",
99+
},
100+
wantErrContains: "_endpoint",
101+
},
102+
{
103+
name: "invalid endpoint format returns error",
104+
credentials: map[string]string{
105+
"aws-access-key-id": "AKIA...",
106+
"aws-secret-access-key": "secret",
107+
"_request_body": "{}",
108+
"_endpoint": "localhost",
109+
},
110+
wantErrContains: "invalid AWS endpoint",
111+
},
112+
{
113+
name: "empty access key returns error",
114+
credentials: map[string]string{
115+
"aws-access-key-id": "",
116+
"aws-secret-access-key": "secret",
117+
"_request_body": "{}",
118+
"_endpoint": "bedrock-runtime.us-east-1.amazonaws.com",
119+
},
120+
wantErrContains: "aws-access-key-id",
121+
},
122+
{
123+
name: "empty credentials returns error",
124+
credentials: map[string]string{},
125+
wantErrContains: "aws-access-key-id",
126+
},
127+
}
128+
129+
for _, test := range tests {
130+
t.Run(test.name, func(t *testing.T) {
131+
generator := &SigV4AuthGenerator{}
132+
authHeaders, err := generator.GenerateAuthHeaders(test.credentials)
133+
134+
if test.wantErrContains != "" {
135+
if err == nil {
136+
t.Errorf("expected error containing %q but got nil", test.wantErrContains)
137+
} else if !strings.Contains(err.Error(), test.wantErrContains) {
138+
t.Errorf("expected error containing %q, got: %v", test.wantErrContains, err)
139+
}
140+
return
141+
}
142+
143+
if err != nil {
144+
t.Fatalf("unexpected error: %v", err)
145+
}
146+
147+
for _, header := range test.wantHeaders {
148+
val, ok := authHeaders[header]
149+
if !ok {
150+
t.Errorf("expected header %q not found in result", header)
151+
continue
152+
}
153+
if val == "" {
154+
t.Errorf("header %q is empty", header)
155+
}
156+
}
157+
158+
if test.wantAuthPrefix != "" {
159+
auth := authHeaders["Authorization"]
160+
if !strings.HasPrefix(auth, test.wantAuthPrefix) {
161+
t.Errorf("Authorization header should start with %q, got: %q", test.wantAuthPrefix, auth)
162+
}
163+
}
164+
165+
if test.wantNoHeader != "" {
166+
if _, ok := authHeaders[test.wantNoHeader]; ok {
167+
t.Errorf("header %q should not be present", test.wantNoHeader)
168+
}
169+
}
170+
171+
if test.wantBodyHash {
172+
body := test.credentials["_request_body"]
173+
expectedHash := fmt.Sprintf("%x", sha256.Sum256([]byte(body)))
174+
if got := authHeaders["X-Amz-Content-Sha256"]; got != expectedHash {
175+
t.Errorf("content hash mismatch: want %q, got %q", expectedHash, got)
176+
}
177+
}
178+
})
179+
}
180+
}
181+
182+
func TestRegionFromEndpoint(t *testing.T) {
183+
tests := []struct {
184+
name string
185+
endpoint string
186+
wantRegion string
187+
wantErr bool
188+
}{
189+
{
190+
name: "standard bedrock endpoint",
191+
endpoint: "bedrock-runtime.us-east-1.amazonaws.com",
192+
wantRegion: "us-east-1",
193+
},
194+
{
195+
name: "eu-west-1 region",
196+
endpoint: "bedrock-runtime.eu-west-1.amazonaws.com",
197+
wantRegion: "eu-west-1",
198+
},
199+
{
200+
name: "ap-southeast-1 region",
201+
endpoint: "bedrock-runtime.ap-southeast-1.amazonaws.com",
202+
wantRegion: "ap-southeast-1",
203+
},
204+
{
205+
name: "too few parts",
206+
endpoint: "localhost",
207+
wantErr: true,
208+
},
209+
{
210+
name: "only three parts",
211+
endpoint: "bedrock.us-east-1.com",
212+
wantErr: true,
213+
},
214+
}
215+
216+
for _, test := range tests {
217+
t.Run(test.name, func(t *testing.T) {
218+
region, err := regionFromEndpoint(test.endpoint)
219+
220+
if test.wantErr {
221+
if err == nil {
222+
t.Errorf("expected error but got nil")
223+
}
224+
return
225+
}
226+
227+
if err != nil {
228+
t.Fatalf("unexpected error: %v", err)
229+
}
230+
231+
if region != test.wantRegion {
232+
t.Errorf("region mismatch: want %q, got %q", test.wantRegion, region)
233+
}
234+
})
235+
}
236+
}

0 commit comments

Comments
 (0)