Skip to content

Commit bd692b8

Browse files
committed
feat: support multiple jwks endpoints
1 parent d20f857 commit bd692b8

5 files changed

Lines changed: 272 additions & 26 deletions

File tree

cmd/serve.go

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import (
44
"errors"
55
"fmt"
66
"net/http"
7+
_ "net/http/pprof"
78
"os"
89
"path/filepath"
9-
_ "net/http/pprof"
1010
"runtime"
1111
"slices"
12+
"strings"
1213
"sync"
1314
"time"
1415

@@ -24,7 +25,8 @@ import (
2425
"go.uber.org/zap"
2526
)
2627

27-
var jwksUrl, userId string
28+
var jwksUrls []string
29+
var userId string
2830
var ignoreVersionCheck bool
2931
var port int
3032
var shutdownWait int
@@ -55,7 +57,8 @@ to interact and monitor the Scroll Application`,
5557
}
5658

5759
logger.Log().Info("Starting Scroll Daemon")
58-
authorizer, err := services.NewAuthorizer(jwksUrl, userId)
60+
jwksURLs := buildJWKSURLs(jwksUrls)
61+
authorizer, err := services.NewAuthorizer(jwksURLs, userId)
5962
if err != nil {
6063
return err
6164
}
@@ -154,7 +157,7 @@ to interact and monitor the Scroll Application`,
154157
signalHandler := signals.NewSignalHandler(ctx, queueManager, processManager, nil, shutdownWait)
155158
daemonHander := handler.NewDaemonHandler(signalHandler)
156159

157-
s := web.NewServer(jwksUrl, scrollHandler, scrollLogHandler, scrollMetricHandler, annotationHandler, processHandler, queueHandler, websocketHandler, portHandler, healthHandler, coldstarterHandler, daemonHander, authorizer, uiDevHandler, cwd, scrollService.GetDir())
160+
s := web.NewServer(jwksURLs, scrollHandler, scrollLogHandler, scrollMetricHandler, annotationHandler, processHandler, queueHandler, websocketHandler, portHandler, healthHandler, coldstarterHandler, daemonHander, authorizer, uiDevHandler, cwd, scrollService.GetDir())
158161

159162
a := s.Initialize()
160163

@@ -312,16 +315,35 @@ to interact and monitor the Scroll Application`,
312315
},
313316
}
314317

318+
func buildJWKSURLs(values []string) []string {
319+
urls := make([]string, 0, len(values))
320+
seen := make(map[string]struct{}, len(values))
321+
322+
for _, url := range values {
323+
url = strings.TrimSpace(url)
324+
if url == "" {
325+
continue
326+
}
327+
if _, ok := seen[url]; ok {
328+
continue
329+
}
330+
seen[url] = struct{}{}
331+
urls = append(urls, url)
332+
}
333+
334+
return urls
335+
}
336+
315337
func init() {
316338
ServeCommand.Flags().StringVarP(&pprofBind, "pprof", "", "", "Enable pprof on the given bind. This is useful for debugging purposes. E.g. --pprof=localhost:6060 or --pprof=:6060")
317339

318340
ServeCommand.Flags().IntVarP(&port, "port", "p", 8081, "Port")
319341

320342
ServeCommand.Flags().IntVarP(&shutdownWait, "shutdown-wait", "", 10, "Wait interval how long the process is allowed to shutdown. First normal shutdown, then forced shutdown")
321343

322-
ServeCommand.Flags().StringVarP(&jwksUrl, "jwks-server", "", "", "JWKS Server to authenticate requests against")
344+
ServeCommand.Flags().StringSliceVarP(&jwksUrls, "jwks-server", "", nil, "JWKS servers to authenticate requests against. Can be comma-separated or set multiple times")
323345

324-
ServeCommand.Flags().StringVarP(&userId, "user-id", "u", "", "Allowed user ID, if JWKS is not set. It checks claims.sub of the JWT token")
346+
ServeCommand.Flags().StringVarP(&userId, "user-id", "u", "", "Allowed user ID. When JWKS authentication is enabled, checks claims.sub of the JWT token")
325347

326348
ServeCommand.Flags().BoolVarP(&idleScroll, "idle", "", false, "Don't start the queue manager, just use coldstarter")
327349

cmd/serve_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package cmd
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
)
7+
8+
func TestBuildJWKSURLsNormalizesAndDeduplicates(t *testing.T) {
9+
got := buildJWKSURLs([]string{
10+
" https://old.example/jwks ",
11+
"",
12+
"https://api.druid.gg/auth/v2/jwks",
13+
"https://next.example/jwks",
14+
" https://old.example/jwks ",
15+
"https://api.druid.gg/auth/v2/jwks",
16+
})
17+
18+
want := []string{
19+
"https://old.example/jwks",
20+
"https://api.druid.gg/auth/v2/jwks",
21+
"https://next.example/jwks",
22+
}
23+
24+
if !reflect.DeepEqual(got, want) {
25+
t.Fatalf("buildJWKSURLs() = %#v, want %#v", got, want)
26+
}
27+
}
28+
29+
func TestBuildJWKSURLsEmptyWhenUnset(t *testing.T) {
30+
got := buildJWKSURLs([]string{"", " "})
31+
if len(got) != 0 {
32+
t.Fatalf("buildJWKSURLs() = %#v, want empty slice", got)
33+
}
34+
}

cmd/server/web/server.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ type Server struct {
4545
}
4646

4747
func NewServer(
48-
jwlsUrl string,
48+
jwksURLs []string,
4949
scrollHandler ports.ScrollHandlerInterface,
5050
scrollLogHandler ports.ScrollLogHandlerInterface,
5151
scrollMetricHandler ports.ScrollMetricHandlerInterface,
@@ -88,9 +88,9 @@ func NewServer(
8888
watchHandler: watchHandler,
8989
}
9090

91-
if jwlsUrl != "" {
91+
if len(jwksURLs) > 0 {
9292
server.jwtMiddleware = jwtware.New(jwtware.Config{
93-
KeySetURLs: []string{jwlsUrl},
93+
KeySetURLs: jwksURLs,
9494
})
9595
}
9696

internal/core/services/authorizer_service.go

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package services
33
import (
44
"encoding/json"
55
"errors"
6+
"fmt"
67
"strings"
78
"time"
89

@@ -16,13 +17,14 @@ import (
1617
)
1718

1819
type AuthorizerService struct {
19-
jwksUrl string
20-
jwks *keyfunc.JWKS
21-
userId string
22-
tokens map[string]time.Time
20+
jwksUrls []string
21+
jwks []*keyfunc.JWKS
22+
userId string
23+
tokens map[string]time.Time
2324
}
2425

25-
func NewAuthorizer(jwksURL string, userId string) (ports.AuthorizerServiceInterface, error) {
26+
func NewAuthorizer(jwksURLs []string, userId string) (ports.AuthorizerServiceInterface, error) {
27+
jwksURLs = normalizeJWKSURLs(jwksURLs)
2628

2729
// Create the keyfunc options. Refresh the JWKS every hour and log errors.
2830
var refreshInterval = time.Hour
@@ -33,18 +35,22 @@ func NewAuthorizer(jwksURL string, userId string) (ports.AuthorizerServiceInterf
3335
},
3436
}
3537

36-
if jwksURL != "" {
37-
// Create the JWKS from the resource at the given URL.
38-
var jwks, err = keyfunc.Get(jwksURL, options)
39-
if err != nil {
40-
return nil, err
38+
if len(jwksURLs) > 0 {
39+
jwksList := make([]*keyfunc.JWKS, 0, len(jwksURLs))
40+
for _, jwksURL := range jwksURLs {
41+
// Create the JWKS from the resource at the given URL.
42+
jwks, err := keyfunc.Get(jwksURL, options)
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to get JWKS from %q: %w", jwksURL, err)
45+
}
46+
jwksList = append(jwksList, jwks)
4147
}
4248

4349
return &AuthorizerService{
44-
jwks: jwks,
45-
jwksUrl: jwksURL,
46-
userId: userId,
47-
tokens: make(map[string]time.Time),
50+
jwks: jwksList,
51+
jwksUrls: jwksURLs,
52+
userId: userId,
53+
tokens: make(map[string]time.Time),
4854
}, nil
4955

5056
} else {
@@ -54,9 +60,26 @@ func NewAuthorizer(jwksURL string, userId string) (ports.AuthorizerServiceInterf
5460
}
5561
}
5662

63+
func normalizeJWKSURLs(jwksURLs []string) []string {
64+
normalized := make([]string, 0, len(jwksURLs))
65+
seen := make(map[string]struct{}, len(jwksURLs))
66+
for _, jwksURL := range jwksURLs {
67+
jwksURL = strings.TrimSpace(jwksURL)
68+
if jwksURL == "" {
69+
continue
70+
}
71+
if _, ok := seen[jwksURL]; ok {
72+
continue
73+
}
74+
seen[jwksURL] = struct{}{}
75+
normalized = append(normalized, jwksURL)
76+
}
77+
return normalized
78+
}
79+
5780
func (auth *AuthorizerService) CheckHeader(c *fiber.Ctx) (*time.Time, error) {
5881

59-
if auth.jwksUrl == "" {
82+
if len(auth.jwksUrls) == 0 {
6083
return nil, nil
6184
}
6285

@@ -71,8 +94,7 @@ func (auth *AuthorizerService) CheckHeader(c *fiber.Ctx) (*time.Time, error) {
7194
jwtToken := splitToken[1]
7295

7396
// Parse the JWT.
74-
token, err := jwt.Parse(jwtToken, auth.jwks.Keyfunc)
75-
97+
token, err := auth.parseJWT(jwtToken)
7698
if err != nil {
7799
return nil, errors.New("Failed to parse the JWT.\nError: " + err.Error())
78100
}
@@ -105,6 +127,22 @@ func (auth *AuthorizerService) CheckHeader(c *fiber.Ctx) (*time.Time, error) {
105127
return &tm, nil
106128
}
107129

130+
func (auth *AuthorizerService) parseJWT(jwtToken string) (*jwt.Token, error) {
131+
var parseErrors []string
132+
for index, jwks := range auth.jwks {
133+
token, err := jwt.Parse(jwtToken, jwks.Keyfunc)
134+
if err == nil && token.Valid {
135+
return token, nil
136+
}
137+
if err != nil {
138+
parseErrors = append(parseErrors, fmt.Sprintf("JWKS %d: %s", index+1, err.Error()))
139+
continue
140+
}
141+
parseErrors = append(parseErrors, fmt.Sprintf("JWKS %d: token is not valid", index+1))
142+
}
143+
return nil, errors.New(strings.Join(parseErrors, "; "))
144+
}
145+
108146
func (auth *AuthorizerService) CheckQuery(token string) (*time.Time, error) {
109147
if validUntil, ok := auth.tokens[token]; ok {
110148
defer delete(auth.tokens, token)

0 commit comments

Comments
 (0)