Skip to content

Commit 577eba4

Browse files
Fix vault gateway auth propagation (#21865)
1 parent 05eeb36 commit 577eba4

File tree

8 files changed: +291 additions, -69 deletions (lines changed)

core/capabilities/vault/allow_list_based_auth.go

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,10 @@ import (
1515
)
1616

1717
const (
18-
allowListBasedAuthRetryCount = 3
18+
// The workflow registry syncer polls every 12s by default. Keep the
19+
// retry window comfortably above that so newly allowlisted requests
20+
// can propagate to every node before auth gives up.
21+
allowListBasedAuthRetryCount = 10
1922
allowListBasedAuthRetryInterval = 3 * time.Second
2023
)
2124

core/capabilities/vault/allow_list_based_auth_test.go

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -180,8 +180,8 @@ func testAuthForRequests(t *testing.T, allowlistedRequest, notAllowlistedRequest
180180
authResult, err := auth.AuthorizeRequest(t.Context(), allowlistedRequest)
181181
require.NoError(t, err)
182182
require.Equal(t, owner.Hex(), authResult.AuthorizedOwner())
183-
require.Equal(t, expiry, authResult.GetExpiresAt())
184-
require.NotEmpty(t, authResult.GetDigest())
183+
require.Equal(t, expiry, authResult.ExpiresAt())
184+
require.NotEmpty(t, authResult.Digest())
185185

186186
// Same request is still authorized here; replay protection lives in the generic Authorizer.
187187
authResult, err = auth.AuthorizeRequest(t.Context(), allowlistedRequest)
@@ -237,7 +237,7 @@ func TestAllowListBasedAuth_RetriesUntilRequestIsAllowlisted(t *testing.T) {
237237
authResult, err := auth.AuthorizeRequest(t.Context(), req)
238238
require.NoError(t, err)
239239
require.Equal(t, owner.Hex(), authResult.AuthorizedOwner())
240-
require.Equal(t, expiry, authResult.GetExpiresAt())
240+
require.Equal(t, expiry, authResult.ExpiresAt())
241241
}
242242

243243
func TestAllowListBasedAuth_FailsAfterAllowlistReadRetries(t *testing.T) {

core/capabilities/vault/authorizer.go

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,22 @@ func NewAuthResult(orgID, workflowOwner, digest string, expiresAt int64) *AuthRe
2929
}
3030
}
3131

32+
// OrgID returns the authorized org ID, if present.
33+
func (a *AuthResult) OrgID() string {
34+
if a == nil {
35+
return ""
36+
}
37+
return a.orgID
38+
}
39+
40+
// WorkflowOwner returns the authorized workflow owner, if present.
41+
func (a *AuthResult) WorkflowOwner() string {
42+
if a == nil {
43+
return ""
44+
}
45+
return a.workflowOwner
46+
}
47+
3248
// AuthorizedOwner returns the canonical owner to use for request scoping.
3349
func (a *AuthResult) AuthorizedOwner() string {
3450
if a == nil {
@@ -40,34 +56,23 @@ func (a *AuthResult) AuthorizedOwner() string {
4056
return a.workflowOwner
4157
}
4258

43-
// GetDigest returns the request digest used for replay protection.
44-
func (a *AuthResult) GetDigest() string {
59+
// Digest returns the request digest used for replay protection.
60+
func (a *AuthResult) Digest() string {
4561
if a == nil {
4662
return ""
4763
}
4864
return a.digest
4965
}
5066

51-
// GetExpiresAt returns the unix timestamp (UTC) after which this
67+
// ExpiresAt returns the unix timestamp (UTC) after which this
5268
// authorization is no longer valid.
53-
func (a *AuthResult) GetExpiresAt() int64 {
69+
func (a *AuthResult) ExpiresAt() int64 {
5470
if a == nil {
5571
return 0
5672
}
5773
return a.expiresAt
5874
}
5975

60-
// GetUntrustedWorkflowOwner returns the workflow owner only for JWTBasedAuth results.
61-
func (a *AuthResult) GetUntrustedWorkflowOwner() (string, error) {
62-
if a == nil {
63-
return "", errors.New("auth result is nil")
64-
}
65-
if a.orgID == "" {
66-
return "", errors.New("untrusted workflow owner only applies to JWTBasedAuth results")
67-
}
68-
return a.workflowOwner, nil
69-
}
70-
7176
// Authorizer selects the applicable auth mechanism for a Vault request.
7277
type Authorizer interface {
7378
AuthorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (*AuthResult, error)
@@ -99,11 +104,11 @@ func (a *authorizer) AuthorizeRequest(ctx context.Context, req jsonrpc.Request[j
99104
a.lggr.Errorw("auth mechanism returned nil auth result", "method", req.Method, "requestID", req.ID, "hasAuth", req.Auth != "")
100105
return nil, err
101106
}
102-
if err := a.replayGuard.CheckAndRecord(authResult.GetDigest(), authResult.GetExpiresAt()); err != nil {
103-
a.lggr.Debugw("replay guard rejected request", "method", req.Method, "requestID", req.ID, "owner", authResult.AuthorizedOwner(), "digest", authResult.GetDigest(), "expiresAt", authResult.GetExpiresAt(), "hasAuth", req.Auth != "", "error", err)
107+
if err := a.replayGuard.CheckAndRecord(authResult.Digest(), authResult.ExpiresAt()); err != nil {
108+
a.lggr.Debugw("replay guard rejected request", "method", req.Method, "requestID", req.ID, "owner", authResult.AuthorizedOwner(), "digest", authResult.Digest(), "expiresAt", authResult.ExpiresAt(), "hasAuth", req.Auth != "", "error", err)
104109
return nil, err
105110
}
106-
a.lggr.Debugw("request authorized", "method", req.Method, "requestID", req.ID, "owner", authResult.AuthorizedOwner(), "digest", authResult.GetDigest(), "expiresAt", authResult.GetExpiresAt(), "hasAuth", req.Auth != "")
111+
a.lggr.Debugw("request authorized", "method", req.Method, "requestID", req.ID, "owner", authResult.AuthorizedOwner(), "digest", authResult.Digest(), "expiresAt", authResult.ExpiresAt(), "hasAuth", req.Auth != "")
107112
return authResult, nil
108113
}
109114

core/capabilities/vault/authorizer_test.go

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -66,11 +66,9 @@ func TestAuthorizer_UsesJWTWhenGateEnabled(t *testing.T) {
6666

6767
authResult, err := a.AuthorizeRequest(t.Context(), req)
6868
require.NoError(t, err)
69+
require.Equal(t, "org-1", authResult.OrgID())
70+
require.Equal(t, "0xworkflow", authResult.WorkflowOwner())
6971
require.Equal(t, "org-1", authResult.AuthorizedOwner())
70-
71-
untrustedWorkflowOwner, err := authResult.GetUntrustedWorkflowOwner()
72-
require.NoError(t, err)
73-
require.Equal(t, "0xworkflow", untrustedWorkflowOwner)
7472
}
7573

7674
func TestAuthorizer_DelegatesDigestVerificationToJWTAuth(t *testing.T) {
@@ -91,6 +89,8 @@ func TestAuthorizer_DelegatesDigestVerificationToJWTAuth(t *testing.T) {
9189

9290
authResult, err := a.AuthorizeRequest(t.Context(), req)
9391
require.NoError(t, err)
92+
require.Equal(t, "org-1", authResult.OrgID())
93+
require.Empty(t, authResult.WorkflowOwner())
9494
require.Equal(t, "org-1", authResult.AuthorizedOwner())
9595
}
9696

@@ -133,6 +133,8 @@ func TestAuthorizer_RejectsAllowListBasedAuthReplay(t *testing.T) {
133133

134134
authResult, err := a.AuthorizeRequest(t.Context(), req)
135135
require.NoError(t, err)
136+
require.Empty(t, authResult.OrgID())
137+
require.Equal(t, "0xabc", authResult.WorkflowOwner())
136138
require.Equal(t, "0xabc", authResult.AuthorizedOwner())
137139

138140
authResult, err = a.AuthorizeRequest(t.Context(), req)

core/capabilities/vault/gw_handler.go

Lines changed: 73 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -162,33 +162,33 @@ func (h *GatewayHandler) HandleGatewayMessage(ctx context.Context, gatewayID str
162162
var response *jsonrpc.Response[json.RawMessage]
163163
switch req.Method {
164164
case vaulttypes.MethodSecretsCreate:
165-
owner, authErr := h.authorizeAndPrefixRequest(ctx, req)
165+
authResult, authErr := h.authorizeAndPrefixRequest(ctx, req)
166166
if authErr != nil {
167167
response = h.errorResponse(ctx, gatewayID, req, api.HandlerError, authErr)
168168
break
169169
}
170-
response = h.handleSecretsCreate(ctx, gatewayID, req, owner)
170+
response = h.handleSecretsCreate(ctx, gatewayID, req, authResult)
171171
case vaulttypes.MethodSecretsUpdate:
172-
owner, authErr := h.authorizeAndPrefixRequest(ctx, req)
172+
authResult, authErr := h.authorizeAndPrefixRequest(ctx, req)
173173
if authErr != nil {
174174
response = h.errorResponse(ctx, gatewayID, req, api.HandlerError, authErr)
175175
break
176176
}
177-
response = h.handleSecretsUpdate(ctx, gatewayID, req, owner)
177+
response = h.handleSecretsUpdate(ctx, gatewayID, req, authResult)
178178
case vaulttypes.MethodSecretsDelete:
179-
owner, authErr := h.authorizeAndPrefixRequest(ctx, req)
179+
authResult, authErr := h.authorizeAndPrefixRequest(ctx, req)
180180
if authErr != nil {
181181
response = h.errorResponse(ctx, gatewayID, req, api.HandlerError, authErr)
182182
break
183183
}
184-
response = h.handleSecretsDelete(ctx, gatewayID, req, owner)
184+
response = h.handleSecretsDelete(ctx, gatewayID, req, authResult)
185185
case vaulttypes.MethodSecretsList:
186-
owner, authErr := h.authorizeAndPrefixRequest(ctx, req)
186+
authResult, authErr := h.authorizeAndPrefixRequest(ctx, req)
187187
if authErr != nil {
188188
response = h.errorResponse(ctx, gatewayID, req, api.HandlerError, authErr)
189189
break
190190
}
191-
response = h.handleSecretsList(ctx, gatewayID, req, owner)
191+
response = h.handleSecretsList(ctx, gatewayID, req, authResult)
192192
case vaulttypes.MethodPublicKeyGet:
193193
response = h.handlePublicKeyGet(ctx, gatewayID, req)
194194
default:
@@ -207,11 +207,11 @@ func (h *GatewayHandler) HandleGatewayMessage(ctx context.Context, gatewayID str
207207
return nil
208208
}
209209

210-
func (h *GatewayHandler) authorizeAndPrefixRequest(ctx context.Context, req *jsonrpc.Request[json.RawMessage]) (string, error) {
210+
func (h *GatewayHandler) authorizeAndPrefixRequest(ctx context.Context, req *jsonrpc.Request[json.RawMessage]) (*AuthResult, error) {
211211
if h.authorizer == nil {
212212
err := errors.New("authorizer is nil")
213213
h.lggr.Errorw("failed to authorize gateway request", "method", req.Method, "requestID", req.ID, "error", err)
214-
return "", err
214+
return nil, err
215215
}
216216

217217
originalRequestID := req.ID
@@ -225,26 +225,26 @@ func (h *GatewayHandler) authorizeAndPrefixRequest(ctx context.Context, req *jso
225225
authReq.ID = originalRequestID
226226
if err := stripPrefixedRequestIDFromParams(&authReq, originalRequestID); err != nil {
227227
h.lggr.Errorw("failed to normalize gateway request for authorization", "method", req.Method, "requestID", originalRequestID, "error", err)
228-
return "", err
228+
return nil, err
229229
}
230230

231231
h.lggr.Debugw("authorizing gateway request", "method", req.Method, "requestID", originalRequestID)
232232
authResult, err := h.authorizer.AuthorizeRequest(ctx, authReq)
233233
if err != nil {
234234
authErr := fmt.Errorf("request not authorized: %w", err)
235235
h.lggr.Errorw("gateway request authorization failed", "method", req.Method, "requestID", originalRequestID, "hasAuth", req.Auth != "", "incomingOwner", incomingOwner, "error", authErr)
236-
return "", authErr
236+
return nil, authErr
237237
}
238238
authorizedOwner := authResult.AuthorizedOwner()
239239
if incomingOwner != "" && normalizeOwner(incomingOwner) != normalizeOwner(authorizedOwner) {
240240
prefixErr := fmt.Errorf("request owner prefix %q does not match authorized owner %q", incomingOwner, authorizedOwner)
241241
h.lggr.Errorw("gateway request owner prefix mismatch", "method", req.Method, "requestID", originalRequestID, "incomingOwner", incomingOwner, "authorizedOwner", authorizedOwner, "error", prefixErr)
242-
return "", prefixErr
242+
return nil, prefixErr
243243
}
244244

245245
req.ID = authorizedOwner + vaulttypes.RequestIDSeparator + originalRequestID
246-
h.lggr.Debugw("authorized gateway request", "method", req.Method, "requestID", req.ID, "owner", authorizedOwner)
247-
return authorizedOwner, nil
246+
h.lggr.Debugw("authorized gateway request", "method", req.Method, "requestID", req.ID, "owner", authorizedOwner, "orgID", authResult.OrgID(), "workflowOwner", authResult.WorkflowOwner())
247+
return authResult, nil
248248
}
249249

250250
func stripPrefixedRequestIDFromParams(req *jsonrpc.Request[json.RawMessage], originalRequestID string) error {
@@ -296,18 +296,50 @@ func rewriteRequestParams(req *jsonrpc.Request[json.RawMessage], payload any) er
296296
return nil
297297
}
298298

299-
func (h *GatewayHandler) handleSecretsCreate(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage], owner string) *jsonrpc.Response[json.RawMessage] {
299+
func setAuthorizedIdentityFields(req any, authResult *AuthResult) error {
300+
if authResult == nil {
301+
return errors.New("auth result is nil")
302+
}
303+
304+
// Critical: the Vault capability trusts OrgId and WorkflowOwner to be set by
305+
// the Vault node only after authorization and request validation succeed. We
306+
// must overwrite any JSON-provided values here; otherwise a malicious request
307+
// could smuggle mismatched identity fields into the capability call.
308+
switch r := req.(type) {
309+
case *vaultcommon.CreateSecretsRequest:
310+
r.OrgId = authResult.OrgID()
311+
r.WorkflowOwner = authResult.WorkflowOwner()
312+
case *vaultcommon.UpdateSecretsRequest:
313+
r.OrgId = authResult.OrgID()
314+
r.WorkflowOwner = authResult.WorkflowOwner()
315+
case *vaultcommon.DeleteSecretsRequest:
316+
r.OrgId = authResult.OrgID()
317+
r.WorkflowOwner = authResult.WorkflowOwner()
318+
case *vaultcommon.ListSecretIdentifiersRequest:
319+
r.OrgId = authResult.OrgID()
320+
r.WorkflowOwner = authResult.WorkflowOwner()
321+
default:
322+
return fmt.Errorf("unsupported vault request type %T", req)
323+
}
324+
325+
return nil
326+
}
327+
328+
func (h *GatewayHandler) handleSecretsCreate(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage], authResult *AuthResult) *jsonrpc.Response[json.RawMessage] {
300329
vaultCapRequest := vaultcommon.CreateSecretsRequest{}
301330
if err := json.Unmarshal(*req.Params, &vaultCapRequest); err != nil {
302331
return h.errorResponse(ctx, gatewayID, req, api.UserMessageParseError, err)
303332
}
304333

334+
authorizedOwner := authResult.AuthorizedOwner()
305335
vaultCapRequest.RequestId = req.ID
306-
vaultCapRequest.WorkflowOwner = owner
336+
if err := setAuthorizedIdentityFields(&vaultCapRequest, authResult); err != nil {
337+
return h.errorResponse(ctx, gatewayID, req, api.FatalError, err)
338+
}
307339
for idx, encryptedSecret := range vaultCapRequest.EncryptedSecrets {
308-
if encryptedSecret != nil && encryptedSecret.Id != nil && normalizeOwner(encryptedSecret.Id.Owner) != normalizeOwner(owner) {
309-
h.lggr.Debugw("create secrets request owner mismatch", "requestID", req.ID, "secretOwner", encryptedSecret.Id.Owner, "authorizedOwner", owner, "index", idx)
310-
return h.errorResponse(ctx, gatewayID, req, api.FatalError, fmt.Errorf("secret ID owner %q does not match authorized owner %q at index %d", encryptedSecret.Id.Owner, owner, idx))
340+
if encryptedSecret != nil && encryptedSecret.Id != nil && normalizeOwner(encryptedSecret.Id.Owner) != normalizeOwner(authorizedOwner) {
341+
h.lggr.Debugw("create secrets request owner mismatch", "requestID", req.ID, "secretOwner", encryptedSecret.Id.Owner, "authorizedOwner", authorizedOwner, "index", idx)
342+
return h.errorResponse(ctx, gatewayID, req, api.FatalError, fmt.Errorf("secret ID owner %q does not match authorized owner %q at index %d", encryptedSecret.Id.Owner, authorizedOwner, idx))
311343
}
312344
}
313345

@@ -324,17 +356,20 @@ func (h *GatewayHandler) handleSecretsCreate(ctx context.Context, gatewayID stri
324356
return jsonResponse
325357
}
326358

327-
func (h *GatewayHandler) handleSecretsUpdate(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage], owner string) *jsonrpc.Response[json.RawMessage] {
359+
func (h *GatewayHandler) handleSecretsUpdate(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage], authResult *AuthResult) *jsonrpc.Response[json.RawMessage] {
328360
vaultCapRequest := vaultcommon.UpdateSecretsRequest{}
329361
if err := json.Unmarshal(*req.Params, &vaultCapRequest); err != nil {
330362
return h.errorResponse(ctx, gatewayID, req, api.UserMessageParseError, err)
331363
}
364+
authorizedOwner := authResult.AuthorizedOwner()
332365
vaultCapRequest.RequestId = req.ID
333-
vaultCapRequest.WorkflowOwner = owner
366+
if err := setAuthorizedIdentityFields(&vaultCapRequest, authResult); err != nil {
367+
return h.errorResponse(ctx, gatewayID, req, api.FatalError, err)
368+
}
334369
for idx, encryptedSecret := range vaultCapRequest.EncryptedSecrets {
335-
if encryptedSecret != nil && encryptedSecret.Id != nil && normalizeOwner(encryptedSecret.Id.Owner) != normalizeOwner(owner) {
336-
h.lggr.Debugw("update secrets request owner mismatch", "requestID", req.ID, "secretOwner", encryptedSecret.Id.Owner, "authorizedOwner", owner, "index", idx)
337-
return h.errorResponse(ctx, gatewayID, req, api.FatalError, fmt.Errorf("secret ID owner %q does not match authorized owner %q at index %d", encryptedSecret.Id.Owner, owner, idx))
370+
if encryptedSecret != nil && encryptedSecret.Id != nil && normalizeOwner(encryptedSecret.Id.Owner) != normalizeOwner(authorizedOwner) {
371+
h.lggr.Debugw("update secrets request owner mismatch", "requestID", req.ID, "secretOwner", encryptedSecret.Id.Owner, "authorizedOwner", authorizedOwner, "index", idx)
372+
return h.errorResponse(ctx, gatewayID, req, api.FatalError, fmt.Errorf("secret ID owner %q does not match authorized owner %q at index %d", encryptedSecret.Id.Owner, authorizedOwner, idx))
338373
}
339374
}
340375

@@ -351,17 +386,20 @@ func (h *GatewayHandler) handleSecretsUpdate(ctx context.Context, gatewayID stri
351386
return jsonResponse
352387
}
353388

354-
func (h *GatewayHandler) handleSecretsDelete(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage], owner string) *jsonrpc.Response[json.RawMessage] {
389+
func (h *GatewayHandler) handleSecretsDelete(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage], authResult *AuthResult) *jsonrpc.Response[json.RawMessage] {
355390
r := &vaultcommon.DeleteSecretsRequest{}
356391
if err := json.Unmarshal(*req.Params, r); err != nil {
357392
return h.errorResponse(ctx, gatewayID, req, api.UserMessageParseError, err)
358393
}
394+
authorizedOwner := authResult.AuthorizedOwner()
359395
r.RequestId = req.ID
360-
r.WorkflowOwner = owner
396+
if err := setAuthorizedIdentityFields(r, authResult); err != nil {
397+
return h.errorResponse(ctx, gatewayID, req, api.FatalError, err)
398+
}
361399
for idx, secretID := range r.Ids {
362-
if secretID != nil && normalizeOwner(secretID.Owner) != normalizeOwner(owner) {
363-
h.lggr.Debugw("delete secrets request owner mismatch", "requestID", req.ID, "secretOwner", secretID.Owner, "authorizedOwner", owner, "index", idx)
364-
return h.errorResponse(ctx, gatewayID, req, api.FatalError, fmt.Errorf("secret ID owner %q does not match authorized owner %q at index %d", secretID.Owner, owner, idx))
400+
if secretID != nil && normalizeOwner(secretID.Owner) != normalizeOwner(authorizedOwner) {
401+
h.lggr.Debugw("delete secrets request owner mismatch", "requestID", req.ID, "secretOwner", secretID.Owner, "authorizedOwner", authorizedOwner, "index", idx)
402+
return h.errorResponse(ctx, gatewayID, req, api.FatalError, fmt.Errorf("secret ID owner %q does not match authorized owner %q at index %d", secretID.Owner, authorizedOwner, idx))
365403
}
366404
}
367405

@@ -384,13 +422,16 @@ func (h *GatewayHandler) handleSecretsDelete(ctx context.Context, gatewayID stri
384422
}
385423
}
386424

387-
func (h *GatewayHandler) handleSecretsList(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage], owner string) *jsonrpc.Response[json.RawMessage] {
425+
func (h *GatewayHandler) handleSecretsList(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage], authResult *AuthResult) *jsonrpc.Response[json.RawMessage] {
388426
r := &vaultcommon.ListSecretIdentifiersRequest{}
389427
if err := json.Unmarshal(*req.Params, r); err != nil {
390428
return h.errorResponse(ctx, gatewayID, req, api.UserMessageParseError, err)
391429
}
392430
r.RequestId = req.ID
393-
r.Owner = owner
431+
r.Owner = authResult.AuthorizedOwner()
432+
if err := setAuthorizedIdentityFields(r, authResult); err != nil {
433+
return h.errorResponse(ctx, gatewayID, req, api.FatalError, err)
434+
}
394435

395436
h.lggr.Debugf("Processing authorized and normalized list secrets request [%s]", r.String())
396437
resp, err := h.secretsService.ListSecretIdentifiers(ctx, r)

0 commit comments

Comments
 (0)