Skip to content

Commit c665d0d

Browse files
committed
add optional redis for backend storage
1 parent 93994bc commit c665d0d

18 files changed

Lines changed: 426 additions & 48 deletions

File tree

backend/auth/handler.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func (ah *AuthHandler) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
118118
ah.respondWithError(w, authReq.ClientType, "failed to generate PKCE", http.StatusInternalServerError)
119119
return
120120
}
121-
if err := ah.sessionStore.SavePKCEVerifier(authReq.SessionID, verifier); err != nil {
121+
if err := ah.sessionStore.SavePKCEVerifier(r.Context(), authReq.SessionID, verifier); err != nil {
122122
logger.Error(err, "failed to store PKCE verifier")
123123
ah.respondWithError(w, authReq.ClientType, "failed to store PKCE verifier", http.StatusInternalServerError)
124124
return
@@ -194,7 +194,7 @@ func (ah *AuthHandler) HandleCallback(w http.ResponseWriter, r *http.Request) {
194194
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
195195
}
196196

197-
verifier, err := ah.sessionStore.LoadAndDeletePKCEVerifier(authCode.SessionID)
197+
verifier, err := ah.sessionStore.LoadAndDeletePKCEVerifier(r.Context(), authCode.SessionID)
198198
if err != nil || verifier == "" {
199199
logger.Error(err, "PKCE verifier not found for session; cannot exchange code", "sessionID", authCode.SessionID)
200200
msg := "PKCE verifier not found. If you run multiple backend instances, use a shared session store (e.g. Redis) so the instance handling the callback can read the verifier stored at authorize time."
@@ -217,7 +217,7 @@ func (ah *AuthHandler) HandleCallback(w http.ResponseWriter, r *http.Request) {
217217
return
218218
}
219219
// Set session expiration and store in middleware
220-
err = ah.sessionStore.Save(sessionState)
220+
err = ah.sessionStore.Save(r.Context(), sessionState)
221221
if err != nil {
222222
logger.Error(err, "failed to save session state")
223223
http.Error(w, "internal error", http.StatusInternalServerError)

backend/auth/middleware.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ func (am *AuthMiddleware) verifyState(next http.Handler) http.Handler {
182182
return
183183
}
184184

185-
if !am.isValidSession(state.SessionID) {
185+
if !am.isValidSession(r.Context(), state.SessionID) {
186186
logger.V(2).Info("Session expired or invalid", "sessionID", state.SessionID)
187187
writeErrorResponse(w, http.StatusUnauthorized, kubebindv1alpha2.ErrorCodeAuthenticationFailed, "Authentication required", "Session has expired or is invalid")
188188
return
@@ -197,8 +197,8 @@ func (am *AuthMiddleware) verifyState(next http.Handler) http.Handler {
197197
}
198198

199199
// isValidSession checks if a session ID exists and hasn't expired
200-
func (am *AuthMiddleware) isValidSession(sessionID string) bool {
201-
sessionInfo, err := am.sessionStore.Load(sessionID)
200+
func (am *AuthMiddleware) isValidSession(ctx context.Context, sessionID string) bool {
201+
sessionInfo, err := am.sessionStore.Load(ctx, sessionID)
202202
if err != nil {
203203
return false
204204
}

backend/http/handler.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,14 @@ func NewHandler(
112112
mgr *kubernetes.Manager,
113113
frontend string,
114114
tokenExpiry time.Duration,
115+
sessionStore session.Store,
115116
) (*handler, error) {
116117
// Create JWT service for CLI authentication
117118
jwtService, err := auth.NewJWTService("kube-bind-backend")
118119
if err != nil {
119120
return nil, fmt.Errorf("failed to create JWT service: %w", err)
120121
}
121122

122-
sessionStore := session.NewInMemoryStore()
123-
124123
// Create auth middleware for request authentication
125124
authMiddleware := auth.NewAuthMiddleware(jwtService, cookieSigningKey, cookieEncryptionKey, mgr, sessionStore)
126125

backend/options/options.go

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ import (
3333
)
3434

3535
type Options struct {
36-
Logs *logs.Options
37-
OIDC *OIDC
38-
Cookie *Cookie
39-
Serve *Serve
36+
Logs *logs.Options
37+
OIDC *OIDC
38+
Cookie *Cookie
39+
Serve *Serve
40+
Session *Session
4041

4142
ProviderKcp *providerkcp.Options
4243

@@ -81,10 +82,11 @@ type ExtraOptions struct {
8182
}
8283

8384
type completedOptions struct {
84-
Logs *logs.Options
85-
OIDC *OIDC
86-
Cookie *Cookie
87-
Serve *Serve
85+
Logs *logs.Options
86+
OIDC *OIDC
87+
Cookie *Cookie
88+
Serve *Serve
89+
Session *Session
8890

8991
// Provider specific options
9092
ProviderKcp *providerkcp.CompletedOptions
@@ -106,6 +108,7 @@ func NewOptions() *Options {
106108
OIDC: NewOIDC(),
107109
Cookie: NewCookie(),
108110
Serve: NewServe(),
111+
Session: NewSession(),
109112
ProviderKcp: providerkcp.NewOptions(),
110113

111114
ExtraOptions: ExtraOptions{
@@ -155,6 +158,7 @@ func (options *Options) AddFlags(fs *pflag.FlagSet) {
155158
options.OIDC.AddFlags(fs)
156159
options.Cookie.AddFlags(fs)
157160
options.Serve.AddFlags(fs)
161+
options.Session.AddFlags(fs)
158162
options.ProviderKcp.AddFlags(fs)
159163

160164
fs.StringVar(&options.KubeConfig, "kubeconfig", options.KubeConfig, "path to a kubeconfig. Only required if out-of-cluster")
@@ -207,6 +211,9 @@ func (options *Options) Complete() (*CompletedOptions, error) {
207211
if err := options.Cookie.Complete(); err != nil {
208212
return nil, err
209213
}
214+
if err := options.Session.Complete(); err != nil {
215+
return nil, err
216+
}
210217
}
211218

212219
// normalize the scope and the isolation
@@ -243,6 +250,7 @@ func (options *Options) Complete() (*CompletedOptions, error) {
243250
OIDC: options.OIDC,
244251
Cookie: options.Cookie,
245252
Serve: options.Serve,
253+
Session: options.Session,
246254
ExtraOptions: options.ExtraOptions,
247255
},
248256
}
@@ -276,6 +284,9 @@ func (options *CompletedOptions) Validate() error {
276284
if err := options.Cookie.Validate(); err != nil {
277285
return err
278286
}
287+
if err := options.Session.Validate(); err != nil {
288+
return err
289+
}
279290
}
280291

281292
if options.ConsumerScope != string(kubebindv1alpha2.NamespacedScope) && options.ConsumerScope != string(kubebindv1alpha2.ClusterScope) {

backend/options/session.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
Copyright 2026 The Kube Bind Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package options
18+
19+
import (
20+
"fmt"
21+
22+
"github.com/spf13/pflag"
23+
)
24+
25+
type SessionType string
26+
27+
const (
28+
SessionTypeInMemory SessionType = "inmemory"
29+
SessionTypeRedis SessionType = "redis"
30+
)
31+
32+
type Session struct {
33+
Type SessionType
34+
RedisAddress string
35+
RedisPassword string
36+
}
37+
38+
func NewSession() *Session {
39+
return &Session{
40+
Type: SessionTypeInMemory,
41+
}
42+
}
43+
44+
func (options *Session) AddFlags(fs *pflag.FlagSet) {
45+
fs.StringVar((*string)(&options.Type), "session-storage-backend", string(options.Type), "The session storage backend to use. Possible values are: 'inmemory', 'redis'")
46+
fs.StringVar(&options.RedisAddress, "redis-addr", options.RedisAddress, "The redis address (e.g. localhost:6379) to connect to if session-storage-backend=redis")
47+
fs.StringVar(&options.RedisPassword, "redis-password", options.RedisPassword, "The connection password to use if session-storage-backend=redis")
48+
}
49+
50+
func (options *Session) Complete() error {
51+
return nil
52+
}
53+
54+
func (options *Session) Validate() error {
55+
if options.Type != SessionTypeInMemory && options.Type != SessionTypeRedis {
56+
return fmt.Errorf("unknown session storage backend %q, must be either 'inmemory' or 'redis'", options.Type)
57+
}
58+
59+
if options.Type == SessionTypeRedis && options.RedisAddress == "" {
60+
return fmt.Errorf("redis-addr must be specified when using session-storage-backend=redis")
61+
}
62+
63+
return nil
64+
}

backend/server.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ import (
3838
"github.com/kube-bind/kube-bind/backend/controllers/servicenamespace"
3939
http "github.com/kube-bind/kube-bind/backend/http"
4040
kube "github.com/kube-bind/kube-bind/backend/kubernetes"
41+
backendoptions "github.com/kube-bind/kube-bind/backend/options"
42+
"github.com/kube-bind/kube-bind/backend/session"
4143
kubebindv1alpha2 "github.com/kube-bind/kube-bind/sdk/apis/kubebind/v1alpha2"
4244
)
4345

@@ -126,6 +128,11 @@ func NewServer(ctx context.Context, c *Config) (*Server, error) {
126128
}
127129
}
128130

131+
sessionStore := session.NewInMemoryStore()
132+
if c.Options.Session.Type == backendoptions.SessionTypeRedis {
133+
sessionStore = session.NewRedisStore(c.Options.Session.RedisAddress, c.Options.Session.RedisPassword)
134+
}
135+
129136
handler, err := http.NewHandler(
130137
s,
131138
s.Config.Options.OIDC.OIDCServer,
@@ -140,6 +147,7 @@ func NewServer(ctx context.Context, c *Config) (*Server, error) {
140147
s.Kubernetes,
141148
c.Options.Frontend,
142149
c.Options.TokenExpiry,
150+
sessionStore,
143151
)
144152
if err != nil {
145153
return nil, fmt.Errorf("error setting up HTTP Handler: %w", err)

backend/session/redis.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
Copyright 2026 The Kube Bind Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package session
18+
19+
import (
20+
"context"
21+
"errors"
22+
"fmt"
23+
"time"
24+
25+
"github.com/redis/go-redis/v9"
26+
"github.com/vmihailenco/msgpack/v4"
27+
"k8s.io/klog/v2"
28+
)
29+
30+
type RedisStore struct {
31+
client *redis.Client
32+
}
33+
34+
func NewRedisStore(redisAddr string, redisPassword string) Store {
35+
client := redis.NewClient(&redis.Options{
36+
Addr: redisAddr,
37+
Password: redisPassword,
38+
DB: 0,
39+
})
40+
41+
return &RedisStore{
42+
client: client,
43+
}
44+
}
45+
46+
func (s *RedisStore) Save(ctx context.Context, state *State) error {
47+
encoded, err := state.Encode()
48+
if err != nil {
49+
return fmt.Errorf("failed to encode state: %w", err)
50+
}
51+
52+
key := fmt.Sprintf("session:%s", state.SessionID)
53+
54+
var ttl time.Duration = 0
55+
if !state.ExpiresAt.IsZero() {
56+
ttl = time.Until(state.ExpiresAt)
57+
if ttl <= 0 {
58+
klog.FromContext(context.Background()).V(4).Info("Session already expired, skipping saving to redis", "sessionID", state.SessionID)
59+
return nil
60+
}
61+
}
62+
63+
err = s.client.Set(ctx, key, encoded, ttl).Err()
64+
if err != nil {
65+
return fmt.Errorf("failed to save session to redis: %w", err)
66+
}
67+
return nil
68+
}
69+
70+
func (s *RedisStore) Load(ctx context.Context, sessionID string) (*State, error) {
71+
key := fmt.Sprintf("session:%s", sessionID)
72+
73+
val, err := s.client.Get(ctx, key).Bytes()
74+
if err != nil {
75+
if errors.Is(err, redis.Nil) {
76+
return nil, ErrSessionNotFound
77+
}
78+
return nil, fmt.Errorf("failed to load session from redis: %w", err)
79+
}
80+
81+
var state State
82+
err = msgpack.Unmarshal(val, &state)
83+
if err != nil {
84+
return nil, fmt.Errorf("failed to decode state from redis: %w", err)
85+
}
86+
87+
return &state, nil
88+
}
89+
90+
func (s *RedisStore) Delete(ctx context.Context, sessionID string) error {
91+
key := fmt.Sprintf("session:%s", sessionID)
92+
err := s.client.Del(ctx, key).Err()
93+
if err != nil {
94+
return fmt.Errorf("failed to delete session from redis: %w", err)
95+
}
96+
return nil
97+
}
98+
99+
func (s *RedisStore) SavePKCEVerifier(ctx context.Context, sessionID, verifier string) error {
100+
if sessionID == "" || verifier == "" {
101+
return errors.New("sessionID and verifier cannot be empty")
102+
}
103+
104+
key := fmt.Sprintf("pkce:%s", sessionID)
105+
106+
err := s.client.Set(ctx, key, verifier, 10*time.Minute).Err()
107+
if err != nil {
108+
return fmt.Errorf("failed to save pkce to redis: %w", err)
109+
}
110+
return nil
111+
}
112+
113+
func (s *RedisStore) LoadAndDeletePKCEVerifier(ctx context.Context, sessionID string) (string, error) {
114+
if sessionID == "" {
115+
return "", ErrPKCEVerifierNotFound
116+
}
117+
118+
key := fmt.Sprintf("pkce:%s", sessionID)
119+
120+
val, err := s.client.Get(ctx, key).Result()
121+
if err != nil {
122+
if errors.Is(err, redis.Nil) {
123+
return "", ErrPKCEVerifierNotFound
124+
}
125+
return "", fmt.Errorf("failed to load pkce from redis: %w", err)
126+
}
127+
128+
_ = s.client.Del(ctx, key).Err()
129+
130+
return val, nil
131+
}

0 commit comments

Comments
 (0)