From 780cbe1a86c093d875c1ea2d31fe794b886436c5 Mon Sep 17 00:00:00 2001 From: "maksim.nabokikh" Date: Mon, 30 Mar 2026 11:36:19 +0200 Subject: [PATCH 1/4] feat: disconnect upstream refreshing Signed-off-by: maksim.nabokikh --- server/refreshhandlers.go | 42 +++++++++- server/refreshhandlers_test.go | 144 +++++++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+), 2 deletions(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 48fc39f130..a4295586ad 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -10,6 +10,7 @@ import ( "time" "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/pkg/featureflags" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) @@ -107,6 +108,10 @@ func newInternalServerError() *refreshError { return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} } +func newUpstreamRefreshError(desc string) *refreshError { + return &refreshError{msg: errInvalidGrant, desc: desc, code: http.StatusBadGateway} +} + func newBadRequestError(desc string) *refreshError { return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} } @@ -271,7 +276,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident) if err != nil { s.logger.ErrorContext(ctx, "failed to refresh identity", "err", err) - return ident, newInternalServerError() + return ident, newUpstreamRefreshError(err.Error()) } return newIdent, nil @@ -327,6 +332,20 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( Groups: rCtx.storageToken.Claims.Groups, } + // Pre-fetch UserIdentity outside the storage transaction to avoid deadlocks with + // storage backends that use a single lock (e.g., memory storage). + // This is used as a fallback when the upstream connector refresh fails. + var cachedIdentity *storage.UserIdentity + if featureflags.SessionsEnabled.Enabled() { + ui, err := s.storage.GetUserIdentity(ctx, rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) + if err != nil { + s.logger.WarnContext(ctx, "failed to pre-fetch user identity for upstream refresh fallback", + "user_id", rCtx.storageToken.Claims.UserID, "connector_id", rCtx.storageToken.ConnectorID, "err", err) + } else { + cachedIdentity = &ui + } + } + refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { rotationEnabled := s.refreshTokenPolicy.RotationEnabled() reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) @@ -373,7 +392,26 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( // Dex will call the connector's Refresh method only once if request is not in reuse interval. ident, rerr = s.refreshWithConnector(ctx, rCtx, ident) if rerr != nil { - return old, rerr + // When sessions are enabled and the upstream provider fails (e.g., expired upstream + // refresh token), fall back to claims stored in UserIdentity instead of failing the + // entire refresh. This matches the behavior of other identity brokers (Keycloak, Auth0) + // that do not contact the upstream on every downstream refresh. + if cachedIdentity != nil { + s.logger.WarnContext(ctx, "upstream refresh failed, using cached identity from last login", + "err", rerr, "user_id", cachedIdentity.Claims.UserID, "connector_id", rCtx.storageToken.ConnectorID) + ident = connector.Identity{ + UserID: cachedIdentity.Claims.UserID, + Username: cachedIdentity.Claims.Username, + PreferredUsername: cachedIdentity.Claims.PreferredUsername, + Email: cachedIdentity.Claims.Email, + EmailVerified: cachedIdentity.Claims.EmailVerified, + Groups: cachedIdentity.Claims.Groups, + } + rerr = nil + } + if rerr != nil { + return old, rerr + } } // Update the claims of the refresh token. diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go index 8db80c31eb..eb2177ada8 100644 --- a/server/refreshhandlers_test.go +++ b/server/refreshhandlers_test.go @@ -2,8 +2,10 @@ package server import ( "bytes" + "context" "encoding/base64" "encoding/json" + "errors" "log/slog" "net/http" "net/http/httptest" @@ -16,6 +18,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/dexidp/dex/connector" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) @@ -347,6 +350,147 @@ func TestRefreshTokenAuthTime(t *testing.T) { } } +// failingRefreshConnector implements connector.CallbackConnector and connector.RefreshConnector +// but always returns an error on Refresh, simulating an upstream provider failure. +type failingRefreshConnector struct { + identity connector.Identity +} + +func (f *failingRefreshConnector) LoginURL(_ connector.Scopes, callbackURL, state string) (string, []byte, error) { + u, _ := url.Parse(callbackURL) + v := u.Query() + v.Set("state", state) + u.RawQuery = v.Encode() + return u.String(), nil, nil +} + +func (f *failingRefreshConnector) HandleCallback(_ connector.Scopes, _ []byte, _ *http.Request) (connector.Identity, error) { + return f.identity, nil +} + +func (f *failingRefreshConnector) Refresh(_ context.Context, _ connector.Scopes, _ connector.Identity) (connector.Identity, error) { + return connector.Identity{}, errors.New("upstream: refresh token expired") +} + +func TestUpstreamRefreshFailureFallsBackToUserIdentity(t *testing.T) { + t0 := time.Now().UTC().Round(time.Second) + loginTime := t0.Add(-10 * time.Minute) + + tests := []struct { + name string + sessionsEnabled bool + createUserIdentity bool + wantOK bool + }{ + { + name: "sessions enabled with user identity - fallback succeeds", + sessionsEnabled: true, + createUserIdentity: true, + wantOK: true, + }, + { + name: "sessions enabled without user identity - fallback fails", + sessionsEnabled: true, + createUserIdentity: false, + wantOK: false, + }, + { + name: "sessions disabled - no fallback, error returned", + sessionsEnabled: false, + createUserIdentity: false, + wantOK: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + setSessionsEnabled(t, tc.sessionsEnabled) + + httpServer, s := newTestServer(t, func(c *Config) { + c.Now = func() time.Time { return t0 } + }) + defer httpServer.Close() + + if tc.sessionsEnabled { + s.sessionConfig = &SessionConfig{ + CookieName: "dex_session", + AbsoluteLifetime: 24 * time.Hour, + } + } + + mockRefreshTokenTestStorage(t, s.storage, false) + + // Replace the connector with one that always fails on Refresh. + // ResourceVersion must match the storage connector (empty by default in + // mockRefreshTokenTestStorage) to prevent getConnector from re-opening it. + s.mu.Lock() + s.connectors["test"] = Connector{ + Connector: &failingRefreshConnector{ + identity: connector.Identity{ + UserID: "0-385-28089-0", + Username: "Kilgore Trout", + Email: "kilgore@kilgore.trout", + }, + }, + } + s.mu.Unlock() + + if tc.createUserIdentity { + err := s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{ + UserID: "1", + ConnectorID: "test", + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + CreatedAt: loginTime, + LastLogin: loginTime, + }) + require.NoError(t, err) + } + + u, err := url.Parse(s.issuerURL.String()) + require.NoError(t, err) + + tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"}) + require.NoError(t, err) + + u.Path = path.Join(u.Path, "/token") + v := url.Values{} + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", tokenData) + + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + req.SetBasicAuth("test", "barfoo") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + if tc.wantOK { + require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String()) + + var resp struct { + IDToken string `json:"id_token"` + } + err = json.Unmarshal(rr.Body.Bytes(), &resp) + require.NoError(t, err) + + // Verify the returned claims match UserIdentity, not the connector. + claims := decodeJWTClaims(t, resp.IDToken) + assert.Equal(t, "jane.doe@example.com", claims["email"]) + assert.Equal(t, "jane", claims["name"]) + } else { + require.NotEqual(t, http.StatusOK, rr.Code, + "expected error when upstream fails without fallback") + } + }) + } +} + func TestRefreshTokenPolicy(t *testing.T) { lastTime := time.Now() l := slog.New(slog.DiscardHandler) From 4d4c58d8fc3bf76dfa888ac24f956af73ff24953 Mon Sep 17 00:00:00 2001 From: "maksim.nabokikh" Date: Mon, 6 Apr 2026 18:36:26 +0200 Subject: [PATCH 2/4] Use the C approach Signed-off-by: maksim.nabokikh --- server/refreshhandlers.go | 42 ++++++++++++++-------------------- server/refreshhandlers_test.go | 16 ++++++------- 2 files changed, 25 insertions(+), 33 deletions(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index a4295586ad..0e0e85aef9 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -332,18 +332,17 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( Groups: rCtx.storageToken.Claims.Groups, } - // Pre-fetch UserIdentity outside the storage transaction to avoid deadlocks with - // storage backends that use a single lock (e.g., memory storage). - // This is used as a fallback when the upstream connector refresh fails. - var cachedIdentity *storage.UserIdentity + // When sessions are enabled, use claims from UserIdentity instead of refreshing + // the upstream token. This disconnects downstream refresh from the upstream provider. + var userIdent *storage.UserIdentity if featureflags.SessionsEnabled.Enabled() { ui, err := s.storage.GetUserIdentity(ctx, rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) if err != nil { - s.logger.WarnContext(ctx, "failed to pre-fetch user identity for upstream refresh fallback", + s.logger.ErrorContext(ctx, "failed to get user identity for refresh", "user_id", rCtx.storageToken.Claims.UserID, "connector_id", rCtx.storageToken.ConnectorID, "err", err) - } else { - cachedIdentity = &ui + return nil, ident, newInternalServerError() } + userIdent = &ui } refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { @@ -390,25 +389,18 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( // Call only once if there is a request which is not in the reuse interval. // This is required to avoid multiple calls to the external IdP for concurrent requests. // Dex will call the connector's Refresh method only once if request is not in reuse interval. - ident, rerr = s.refreshWithConnector(ctx, rCtx, ident) - if rerr != nil { - // When sessions are enabled and the upstream provider fails (e.g., expired upstream - // refresh token), fall back to claims stored in UserIdentity instead of failing the - // entire refresh. This matches the behavior of other identity brokers (Keycloak, Auth0) - // that do not contact the upstream on every downstream refresh. - if cachedIdentity != nil { - s.logger.WarnContext(ctx, "upstream refresh failed, using cached identity from last login", - "err", rerr, "user_id", cachedIdentity.Claims.UserID, "connector_id", rCtx.storageToken.ConnectorID) - ident = connector.Identity{ - UserID: cachedIdentity.Claims.UserID, - Username: cachedIdentity.Claims.Username, - PreferredUsername: cachedIdentity.Claims.PreferredUsername, - Email: cachedIdentity.Claims.Email, - EmailVerified: cachedIdentity.Claims.EmailVerified, - Groups: cachedIdentity.Claims.Groups, - } - rerr = nil + // When sessions are enabled, use cached identity instead of refreshing upstream. + if userIdent != nil { + ident = connector.Identity{ + UserID: userIdent.Claims.UserID, + Username: userIdent.Claims.Username, + PreferredUsername: userIdent.Claims.PreferredUsername, + Email: userIdent.Claims.Email, + EmailVerified: userIdent.Claims.EmailVerified, + Groups: userIdent.Claims.Groups, } + } else { + ident, rerr = s.refreshWithConnector(ctx, rCtx, ident) if rerr != nil { return old, rerr } diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go index eb2177ada8..d24caee596 100644 --- a/server/refreshhandlers_test.go +++ b/server/refreshhandlers_test.go @@ -351,7 +351,7 @@ func TestRefreshTokenAuthTime(t *testing.T) { } // failingRefreshConnector implements connector.CallbackConnector and connector.RefreshConnector -// but always returns an error on Refresh, simulating an upstream provider failure. +// but always returns an error on Refresh, proving that the upstream is not contacted. type failingRefreshConnector struct { identity connector.Identity } @@ -372,7 +372,7 @@ func (f *failingRefreshConnector) Refresh(_ context.Context, _ connector.Scopes, return connector.Identity{}, errors.New("upstream: refresh token expired") } -func TestUpstreamRefreshFailureFallsBackToUserIdentity(t *testing.T) { +func TestRefreshDisconnectsUpstreamWhenSessionsEnabled(t *testing.T) { t0 := time.Now().UTC().Round(time.Second) loginTime := t0.Add(-10 * time.Minute) @@ -383,19 +383,19 @@ func TestUpstreamRefreshFailureFallsBackToUserIdentity(t *testing.T) { wantOK bool }{ { - name: "sessions enabled with user identity - fallback succeeds", + name: "sessions enabled - uses user identity, skips upstream", sessionsEnabled: true, createUserIdentity: true, wantOK: true, }, { - name: "sessions enabled without user identity - fallback fails", + name: "sessions enabled without user identity - fails", sessionsEnabled: true, createUserIdentity: false, wantOK: false, }, { - name: "sessions disabled - no fallback, error returned", + name: "sessions disabled - upstream failure returns error", sessionsEnabled: false, createUserIdentity: false, wantOK: false, @@ -421,8 +421,8 @@ func TestUpstreamRefreshFailureFallsBackToUserIdentity(t *testing.T) { mockRefreshTokenTestStorage(t, s.storage, false) // Replace the connector with one that always fails on Refresh. - // ResourceVersion must match the storage connector (empty by default in - // mockRefreshTokenTestStorage) to prevent getConnector from re-opening it. + // When sessions are enabled this connector should never be called; + // when sessions are disabled, the failure proves the error path works. s.mu.Lock() s.connectors["test"] = Connector{ Connector: &failingRefreshConnector{ @@ -485,7 +485,7 @@ func TestUpstreamRefreshFailureFallsBackToUserIdentity(t *testing.T) { assert.Equal(t, "jane", claims["name"]) } else { require.NotEqual(t, http.StatusOK, rr.Code, - "expected error when upstream fails without fallback") + "expected error when sessions disabled or user identity missing") } }) } From 90bb8eb4bef8b95d8f3ea0384cb1555b74fccc72 Mon Sep 17 00:00:00 2001 From: "maksim.nabokikh" Date: Mon, 6 Apr 2026 18:59:26 +0200 Subject: [PATCH 3/4] Fixes according to codereview comments Signed-off-by: maksim.nabokikh --- server/refreshhandlers.go | 17 ++++++++++------- server/refreshhandlers_test.go | 16 +++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 0e0e85aef9..cf0f1b1c31 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -10,7 +10,6 @@ import ( "time" "github.com/dexidp/dex/connector" - "github.com/dexidp/dex/pkg/featureflags" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) @@ -108,8 +107,8 @@ func newInternalServerError() *refreshError { return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} } -func newUpstreamRefreshError(desc string) *refreshError { - return &refreshError{msg: errInvalidGrant, desc: desc, code: http.StatusBadGateway} +func newUpstreamRefreshError() *refreshError { + return &refreshError{msg: errInvalidGrant, desc: "Upstream identity provider refresh failed.", code: http.StatusBadGateway} } func newBadRequestError(desc string) *refreshError { @@ -276,7 +275,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident) if err != nil { s.logger.ErrorContext(ctx, "failed to refresh identity", "err", err) - return ident, newUpstreamRefreshError(err.Error()) + return ident, newUpstreamRefreshError() } return newIdent, nil @@ -332,10 +331,14 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( Groups: rCtx.storageToken.Claims.Groups, } - // When sessions are enabled, use claims from UserIdentity instead of refreshing - // the upstream token. This disconnects downstream refresh from the upstream provider. + // When sessions are enabled, downstream token refresh is disconnected from the upstream + // identity provider. Instead of calling the connector's Refresh method (which would contact + // the upstream IdP and may fail if the upstream refresh token has expired), we use the claims + // stored in UserIdentity at the time of the last interactive login. This aligns with the + // behavior of other identity brokers (e.g., Keycloak, Auth0) that treat downstream sessions + // independently from the upstream provider session lifetime. var userIdent *storage.UserIdentity - if featureflags.SessionsEnabled.Enabled() { + if s.sessionConfig != nil { ui, err := s.storage.GetUserIdentity(ctx, rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) if err != nil { s.logger.ErrorContext(ctx, "failed to get user identity for refresh", diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go index d24caee596..e870da541f 100644 --- a/server/refreshhandlers_test.go +++ b/server/refreshhandlers_test.go @@ -279,17 +279,17 @@ func TestRefreshTokenAuthTime(t *testing.T) { mockRefreshTokenTestStorage(t, s.storage, false) if tc.createUserIdentity { - // The mock connector returns UserID "0-385-28089-0" on Refresh, - // so the UserIdentity must use that ID to be found by handleRefreshToken. + // UserIdentity must match the refresh token's Claims.UserID ("1") + // because updateRefreshToken looks it up by that ID. err := s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{ - UserID: "0-385-28089-0", + UserID: "1", ConnectorID: "test", Claims: storage.Claims{ - UserID: "0-385-28089-0", - Username: "Kilgore Trout", - Email: "kilgore@kilgore.trout", + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", EmailVerified: true, - Groups: []string{"authors"}, + Groups: []string{"a", "b"}, }, CreatedAt: loginTime, LastLogin: loginTime, @@ -404,8 +404,6 @@ func TestRefreshDisconnectsUpstreamWhenSessionsEnabled(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - setSessionsEnabled(t, tc.sessionsEnabled) - httpServer, s := newTestServer(t, func(c *Config) { c.Now = func() time.Time { return t0 } }) From ec003f5ef91dcbf69fff9deb96275df6d10f00bb Mon Sep 17 00:00:00 2001 From: "maksim.nabokikh" Date: Fri, 17 Apr 2026 07:53:36 +0200 Subject: [PATCH 4/4] Read User Identity once Signed-off-by: maksim.nabokikh --- server/refreshhandlers.go | 43 +++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index cf0f1b1c31..7d3c7d4cd5 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -312,7 +312,7 @@ func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.Refr } // updateRefreshToken updates refresh token and offline session in the storage -func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) { +func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext, userIdent *storage.UserIdentity) (*internal.RefreshToken, connector.Identity, *refreshError) { var rerr *refreshError newToken := &internal.RefreshToken{ @@ -337,17 +337,6 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( // stored in UserIdentity at the time of the last interactive login. This aligns with the // behavior of other identity brokers (e.g., Keycloak, Auth0) that treat downstream sessions // independently from the upstream provider session lifetime. - var userIdent *storage.UserIdentity - if s.sessionConfig != nil { - ui, err := s.storage.GetUserIdentity(ctx, rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) - if err != nil { - s.logger.ErrorContext(ctx, "failed to get user identity for refresh", - "user_id", rCtx.storageToken.Claims.UserID, "connector_id", rCtx.storageToken.ConnectorID, "err", err) - return nil, ident, newInternalServerError() - } - userIdent = &ui - } - refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { rotationEnabled := s.refreshTokenPolicy.RotationEnabled() reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) @@ -457,7 +446,24 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx) + var userIdent *storage.UserIdentity + + if s.sessionConfig != nil { + ui, err := s.storage.GetUserIdentity(r.Context(), rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to get user identity", "err", err) + s.refreshTokenErrHelper(w, newInternalServerError()) + return + } + userIdent = &ui + } + + authTime := time.Time{} + if userIdent != nil { + authTime = userIdent.LastLogin + } + + newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx, userIdent) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return @@ -472,17 +478,6 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Groups: ident.Groups, } - authTime := time.Time{} - if s.sessionConfig != nil { - ui, err := s.storage.GetUserIdentity(r.Context(), ident.UserID, rCtx.storageToken.ConnectorID) - if err != nil { - s.logger.ErrorContext(r.Context(), "failed to get user identity", "err", err) - s.refreshTokenErrHelper(w, newInternalServerError()) - return - } - authTime = ui.LastLogin - } - accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID, authTime) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err)