Skip to content

Commit 72f0383

Browse files
vault: refresh jwks in background
1 parent 1ee853c commit 72f0383

4 files changed

Lines changed: 115 additions & 8 deletions

File tree

core/capabilities/vault/gw_handler.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ type GatewayHandler struct {
8080
secretsService vaulttypes.SecretsService
8181
gatewayConnector gatewayConnector
8282
authorizer Authorizer
83+
jwtAuthService services.Service
8384
lggr logger.Logger
8485
metrics *metrics
8586
}
@@ -97,11 +98,12 @@ func NewGatewayHandler(secretsService vaulttypes.SecretsService, connector gatew
9798
return nil, fmt.Errorf("failed to create JWTBasedAuth: %w", err)
9899
}
99100
cfg.authorizer = NewAuthorizer(allowListBasedAuth, jwtBasedAuth, lggr)
101+
return newGatewayHandlerWithAuthorizer(secretsService, connector, cfg.authorizer, jwtBasedAuth, lggr)
100102
}
101-
return newGatewayHandlerWithAuthorizer(secretsService, connector, cfg.authorizer, lggr)
103+
return newGatewayHandlerWithAuthorizer(secretsService, connector, cfg.authorizer, nil, lggr)
102104
}
103105

104-
func newGatewayHandlerWithAuthorizer(secretsService vaulttypes.SecretsService, connector gatewayConnector, authorizer Authorizer, lggr logger.Logger) (*GatewayHandler, error) {
106+
func newGatewayHandlerWithAuthorizer(secretsService vaulttypes.SecretsService, connector gatewayConnector, authorizer Authorizer, jwtAuthService services.Service, lggr logger.Logger) (*GatewayHandler, error) {
105107
metrics, err := newMetrics()
106108
if err != nil {
107109
return nil, fmt.Errorf("failed to create metrics: %w", err)
@@ -111,6 +113,7 @@ func newGatewayHandlerWithAuthorizer(secretsService vaulttypes.SecretsService, c
111113
secretsService: secretsService,
112114
gatewayConnector: connector,
113115
authorizer: authorizer,
116+
jwtAuthService: jwtAuthService,
114117
lggr: lggr.Named(HandlerName),
115118
metrics: metrics,
116119
}
@@ -123,17 +126,26 @@ func newGatewayHandlerWithAuthorizer(secretsService vaulttypes.SecretsService, c
123126
}
124127

125128
func (h *GatewayHandler) start(ctx context.Context) error {
129+
if h.jwtAuthService != nil {
130+
if err := h.jwtAuthService.Start(ctx); err != nil {
131+
return fmt.Errorf("failed to start JWTBasedAuth: %w", err)
132+
}
133+
}
126134
if gwerr := h.gatewayConnector.AddHandler(ctx, h.Methods(), h); gwerr != nil {
127135
return fmt.Errorf("failed to add vault handler to connector: %w", gwerr)
128136
}
129137
return nil
130138
}
131139

132140
func (h *GatewayHandler) close() error {
141+
var jwtAuthErr error
142+
if h.jwtAuthService != nil {
143+
jwtAuthErr = h.jwtAuthService.Close()
144+
}
133145
if gwerr := h.gatewayConnector.RemoveHandler(context.Background(), h.Methods()); gwerr != nil {
134-
return fmt.Errorf("failed to remove vault handler from connector: %w", gwerr)
146+
return errors.Join(fmt.Errorf("failed to remove vault handler from connector: %w", gwerr), jwtAuthErr)
135147
}
136-
return nil
148+
return jwtAuthErr
137149
}
138150

139151
func (h *GatewayHandler) ID(ctx context.Context) (string, error) {

core/capabilities/vault/jwt_based_auth.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2"
2020
"github.com/smartcontractkit/chainlink-common/pkg/logger"
21+
"github.com/smartcontractkit/chainlink-common/pkg/services"
2122
"github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings"
2223
"github.com/smartcontractkit/chainlink-common/pkg/settings/limits"
2324
)
@@ -75,11 +76,15 @@ type jsonWebKeySet struct {
7576
//
7677
// Reference: cre-platform-graphql/internal/auth/jwt_auth0.go
7778
type jwtBasedAuth struct {
79+
services.Service
80+
eng *services.Engine
81+
7882
issuerURL string
7983
audience string
8084
jwksURL string
8185
refreshInterval time.Duration
8286
authEnabledGate limits.GateLimiter
87+
refreshEnabled bool
8388

8489
mu sync.RWMutex
8590
keySet *jsonWebKeySet
@@ -145,15 +150,23 @@ func NewJWTBasedAuth(cfg JWTBasedAuthConfig, limitsFactory limits.Factory, lggr
145150
httpClient = &http.Client{Timeout: defaultHTTPTimeout}
146151
}
147152

148-
return &jwtBasedAuth{
153+
v := &jwtBasedAuth{
149154
issuerURL: cfg.IssuerURL,
150155
audience: cfg.Audience,
151156
jwksURL: jwksURL,
152157
refreshInterval: refreshInterval,
153158
authEnabledGate: options.authEnabledGate,
159+
refreshEnabled: !options.skipConfigChecks,
154160
httpClient: httpClient,
155161
lggr: logger.Named(lggr, "VaultJWTBasedAuth"),
156-
}, nil
162+
}
163+
v.Service, v.eng = services.Config{
164+
Name: "VaultJWTBasedAuth",
165+
Start: v.start,
166+
Close: v.close,
167+
}.NewServiceEngine(v.lggr)
168+
169+
return v, nil
157170
}
158171

159172
func newVaultJWTAuthEnabledGateLimiter(limitsFactory limits.Factory, lggr logger.Logger) limits.GateLimiter {
@@ -166,6 +179,24 @@ func newVaultJWTAuthEnabledGateLimiter(limitsFactory limits.Factory, lggr logger
166179
return limiter
167180
}
168181

182+
func (v *jwtBasedAuth) start(context.Context) error {
183+
if !v.refreshEnabled {
184+
v.lggr.Debug("JWTBasedAuth periodic JWKS refresh disabled")
185+
return nil
186+
}
187+
188+
v.eng.GoTick(services.NewTicker(v.refreshInterval), func(ctx context.Context) {
189+
if err := v.refreshJWKS(ctx); err != nil {
190+
v.lggr.Warnw("periodic JWKS refresh failed", "error", err)
191+
}
192+
})
193+
return nil
194+
}
195+
196+
func (v *jwtBasedAuth) close() error {
197+
return v.authEnabledGate.Close()
198+
}
199+
169200
// AuthorizeRequest verifies JWTBasedAuth state and token claims, and returns a common AuthResult.
170201
func (v *jwtBasedAuth) AuthorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (*AuthResult, error) {
171202
isEnabled, err := v.authEnabledGate.Limit(ctx)

core/capabilities/vault/jwt_based_auth_test.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,21 @@ type testJWKSServer struct {
3636
server *httptest.Server
3737
mu sync.Mutex
3838
keys []testRSAKey
39+
hits chan struct{}
3940
}
4041

4142
func newTestJWKSServer(t *testing.T, keys ...testRSAKey) *testJWKSServer {
4243
t.Helper()
43-
s := &testJWKSServer{keys: keys}
44+
s := &testJWKSServer{keys: keys, hits: make(chan struct{}, 32)}
4445
mux := http.NewServeMux()
4546
mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) {
4647
s.mu.Lock()
4748
currentKeys := s.keys
4849
s.mu.Unlock()
50+
select {
51+
case s.hits <- struct{}{}:
52+
default:
53+
}
4954

5055
ks := jsonWebKeySet{}
5156
for _, k := range currentKeys {
@@ -62,6 +67,21 @@ func newTestJWKSServer(t *testing.T, keys ...testRSAKey) *testJWKSServer {
6267

6368
func (s *testJWKSServer) URL() string { return s.server.URL }
6469

70+
func (s *testJWKSServer) waitForHits(t *testing.T, count int, timeout time.Duration) {
71+
t.Helper()
72+
73+
deadline := time.NewTimer(timeout)
74+
defer deadline.Stop()
75+
76+
for range count {
77+
select {
78+
case <-s.hits:
79+
case <-deadline.C:
80+
t.Fatalf("timed out waiting for %d JWKS hits", count)
81+
}
82+
}
83+
}
84+
6585
func (s *testJWKSServer) setKeys(keys ...testRSAKey) {
6686
s.mu.Lock()
6787
defer s.mu.Unlock()
@@ -413,6 +433,34 @@ func TestJWTBasedAuth_JWKSServerUnavailable(t *testing.T) {
413433
assert.ErrorIs(t, err, ErrJWKSKeyNotFound)
414434
}
415435

436+
func TestJWTBasedAuth_StartRefreshesJWKSPeriodically(t *testing.T) {
437+
rsaKey := generateTestRSAKey(t, "key-1")
438+
jwksServer := newTestJWKSServer(t, rsaKey)
439+
440+
v, err := NewJWTBasedAuth(JWTBasedAuthConfig{
441+
IssuerURL: jwksServer.URL() + "/",
442+
Audience: "https://api.test.chain.link",
443+
JWKSRefreshInterval: 10 * time.Millisecond,
444+
}, limits.Factory{Settings: cresettings.DefaultGetter}, logger.TestLogger(t), WithJWTBasedAuthGateLimiter(limits.NewGateLimiter(true)))
445+
require.NoError(t, err)
446+
447+
require.NoError(t, v.Start(t.Context()))
448+
jwksServer.waitForHits(t, 2, time.Second)
449+
require.NoError(t, v.Close())
450+
}
451+
452+
func TestJWTBasedAuth_DisabledStartSkipsPeriodicRefresh(t *testing.T) {
453+
v, err := NewJWTBasedAuth(
454+
JWTBasedAuthConfig{},
455+
limits.Factory{Settings: cresettings.DefaultGetter},
456+
logger.TestLogger(t),
457+
WithDisabledJWTBasedAuth(),
458+
)
459+
require.NoError(t, err)
460+
require.NoError(t, v.Start(t.Context()))
461+
require.NoError(t, v.Close())
462+
}
463+
416464
func TestNewJWTBasedAuth_InvalidConfig(t *testing.T) {
417465
lggr := logger.TestLogger(t)
418466

core/services/gateway/handlers/vault/handler.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ type handler struct {
138138
mu sync.RWMutex
139139
stopCh services.StopChan
140140
authorizer vaultcap.Authorizer
141+
jwtAuth services.Service
141142
*vaultcap.RequestValidator
142143

143144
nodeRateLimiter *ratelimit.RateLimiter
@@ -184,10 +185,14 @@ func NewHandler(methodConfig json.RawMessage, donConfig *config.DONConfig, don g
184185
return nil, fmt.Errorf("failed to create JWTBasedAuth: %w", err)
185186
}
186187
authorizer := vaultcap.NewAuthorizer(allowListBasedAuth, jwtBasedAuth, lggr)
187-
return newHandlerWithAuthorizer(methodConfig, donConfig, don, capabilitiesRegistry, authorizer, lggr, clock, limitsFactory)
188+
return newHandlerWithJWTAuth(methodConfig, donConfig, don, capabilitiesRegistry, authorizer, jwtBasedAuth, lggr, clock, limitsFactory)
188189
}
189190

190191
func newHandlerWithAuthorizer(methodConfig json.RawMessage, donConfig *config.DONConfig, don gwhandlers.DON, capabilitiesRegistry capabilitiesRegistry, authorizer vaultcap.Authorizer, lggr logger.Logger, clock clockwork.Clock, limitsFactory limits.Factory) (*handler, error) {
192+
return newHandlerWithJWTAuth(methodConfig, donConfig, don, capabilitiesRegistry, authorizer, nil, lggr, clock, limitsFactory)
193+
}
194+
195+
func newHandlerWithJWTAuth(methodConfig json.RawMessage, donConfig *config.DONConfig, don gwhandlers.DON, capabilitiesRegistry capabilitiesRegistry, authorizer vaultcap.Authorizer, jwtAuth services.Service, lggr logger.Logger, clock clockwork.Clock, limitsFactory limits.Factory) (*handler, error) {
191196
var cfg Config
192197
if err := json.Unmarshal(methodConfig, &cfg); err != nil {
193198
return nil, fmt.Errorf("failed to unmarshal method config: %w", err)
@@ -232,6 +237,7 @@ func newHandlerWithAuthorizer(methodConfig json.RawMessage, donConfig *config.DO
232237
activeRequests: make(map[string]*activeRequest),
233238
mu: sync.RWMutex{},
234239
authorizer: authorizer,
240+
jwtAuth: jwtAuth,
235241
stopCh: make(services.StopChan),
236242
metrics: metrics,
237243
aggregator: &baseAggregator{capabilitiesRegistry: capabilitiesRegistry},
@@ -243,6 +249,11 @@ func newHandlerWithAuthorizer(methodConfig json.RawMessage, donConfig *config.DO
243249
func (h *handler) Start(_ context.Context) error {
244250
return h.StartOnce("VaultHandler", func() error {
245251
h.lggr.Debug("starting vault handler")
252+
if h.jwtAuth != nil {
253+
if err := h.jwtAuth.Start(context.Background()); err != nil {
254+
return fmt.Errorf("failed to start JWTBasedAuth: %w", err)
255+
}
256+
}
246257
go func() {
247258
ctx, cancel := h.stopCh.NewCtx()
248259
defer cancel()
@@ -270,7 +281,12 @@ func (h *handler) Close() error {
270281
return h.StopOnce("VaultHandler", func() error {
271282
h.lggr.Debug("closing vault handler")
272283
close(h.stopCh)
284+
var jwtAuthErr error
285+
if h.jwtAuth != nil {
286+
jwtAuthErr = h.jwtAuth.Close()
287+
}
273288
return errors.Join(
289+
jwtAuthErr,
274290
h.writeMethodsEnabled.Close(),
275291
h.MaxRequestBatchSizeLimiter.Close(),
276292
)

0 commit comments

Comments
 (0)