From 39767f03f2b045aa5bc2140d4dbb2155f832ab27 Mon Sep 17 00:00:00 2001 From: Tung Wu Date: Thu, 14 May 2026 15:09:24 +0800 Subject: [PATCH] Fix Redis mutex unlock ignoring context cancellation Use context.WithoutCancel in the deferred unlock so the lock is always released even when the request context is cancelled (e.g. client disconnect). Introduce Handle.WithMutex to centralise the lock/unlock pattern and avoid repeating the fix at every call site. --- pkg/lib/infra/redis/handle.go | 12 + pkg/lib/oauth/redis/store.go | 410 ++++++++++--------------- pkg/lib/session/idpsession/provider.go | 117 +++---- 3 files changed, 224 insertions(+), 315 deletions(-) diff --git a/pkg/lib/infra/redis/handle.go b/pkg/lib/infra/redis/handle.go index 14b640a7c42..ac7f1d48772 100644 --- a/pkg/lib/infra/redis/handle.go +++ b/pkg/lib/infra/redis/handle.go @@ -60,3 +60,15 @@ func (h *Handle) NewMutex(name string) *redsync.Mutex { ) return mutex } + +func (h *Handle) WithMutex(ctx context.Context, name string, do func() error) error { + mutex := h.NewMutex(name) + if err := mutex.LockContext(ctx); err != nil { + return err + } + unlockCtx := context.WithoutCancel(ctx) + defer func() { + _, _ = mutex.UnlockContext(unlockCtx) + }() + return do() +} diff --git a/pkg/lib/oauth/redis/store.go b/pkg/lib/oauth/redis/store.go index dbfdbea362a..76dadaf9bdf 100644 --- a/pkg/lib/oauth/redis/store.go +++ b/pkg/lib/oauth/redis/store.go @@ -293,106 +293,74 @@ func (s *Store) CreateOfflineGrant(ctx context.Context, grant *oauth.OfflineGran func (s *Store) UpdateOfflineGrantWithMutator(ctx context.Context, grantID string, expireAt time.Time, mutator func(*oauth.OfflineGrant) *oauth.OfflineGrant) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant = mutator(grant) - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant = mutator(grant) + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) UpdateOfflineGrantDeviceInfo(ctx context.Context, grantID string, deviceInfo map[string]any, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant.DeviceInfo = deviceInfo - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant.DeviceInfo = deviceInfo + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) UpdateOfflineGrantAuthenticatedAt(ctx context.Context, grantID string, authenticatedAt time.Time, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant.AuthenticatedAt = authenticatedAt - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant.AuthenticatedAt = authenticatedAt + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) UpdateOfflineGrantApp2AppDeviceKey(ctx context.Context, grantID string, newKey string, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant.App2AppDeviceKeyJWKJSON = newKey - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant.App2AppDeviceKeyJWKJSON = newKey + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) UpdateOfflineGrantDeviceSecretHash( @@ -402,29 +370,21 @@ func (s *Store) UpdateOfflineGrantDeviceSecretHash( dpopJKT string, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant.DeviceSecretHash = newDeviceSecretHash - grant.DeviceSecretDPoPJKT = dpopJKT - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant.DeviceSecretHash = newDeviceSecretHash + grant.DeviceSecretDPoPJKT = dpopJKT + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) AddOfflineGrantSAMLServiceProviderParticipant( @@ -433,29 +393,22 @@ func (s *Store) AddOfflineGrantSAMLServiceProviderParticipant( newServiceProviderID string, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - newParticipatedSAMLServiceProviderIDs := grant.GetParticipatedSAMLServiceProviderIDsSet() - newParticipatedSAMLServiceProviderIDs.Add(newServiceProviderID) - grant.ParticipatedSAMLServiceProviderIDs = newParticipatedSAMLServiceProviderIDs.Keys() - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + newParticipatedSAMLServiceProviderIDs := grant.GetParticipatedSAMLServiceProviderIDsSet() + newParticipatedSAMLServiceProviderIDs.Add(newServiceProviderID) + grant.ParticipatedSAMLServiceProviderIDs = newParticipatedSAMLServiceProviderIDs.Keys() + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) AddOfflineGrantRefreshToken( @@ -463,40 +416,31 @@ func (s *Store) AddOfflineGrantRefreshToken( options oauth.AddOfflineGrantRefreshTokenOptions, ) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), options.OfflineGrantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, options.OfflineGrantID) - if err != nil { - return nil, err - } - - now := s.Clock.NowUTC() - - newRefreshToken := oauth.OfflineGrantRefreshToken{ - InitialTokenHash: options.TokenHash, - ClientID: options.ClientID, - CreatedAt: now, - Scopes: options.Scopes, - AuthorizationID: options.AuthorizationID, - DPoPJKT: options.DPoPJKT, - AccessInfo: &options.AccessInfo, - ExpireAt: options.ShortLivedRefreshTokenExpireAt, - } - - grant.RefreshTokens = append(grant.RefreshTokens, newRefreshToken) - err = s.updateOfflineGrant(ctx, grant, options.OfflineGrantExpireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, options.OfflineGrantID) + if err != nil { + return err + } + now := s.Clock.NowUTC() + newRefreshToken := oauth.OfflineGrantRefreshToken{ + InitialTokenHash: options.TokenHash, + ClientID: options.ClientID, + CreatedAt: now, + Scopes: options.Scopes, + AuthorizationID: options.AuthorizationID, + DPoPJKT: options.DPoPJKT, + AccessInfo: &options.AccessInfo, + ExpireAt: options.ShortLivedRefreshTokenExpireAt, + } + grant.RefreshTokens = append(grant.RefreshTokens, newRefreshToken) + if err = s.updateOfflineGrant(ctx, grant, options.OfflineGrantExpireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) RotateOfflineGrantRefreshToken( @@ -505,91 +449,73 @@ func (s *Store) RotateOfflineGrantRefreshToken( expireAt time.Time, ) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), opts.OfflineGrantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, opts.OfflineGrantID) - if err != nil { - return nil, err - } - - var tokenToRotate *oauth.OfflineGrantRefreshToken - tokenIndex := -1 - for i, token := range grant.RefreshTokens { - if token.MatchInitialHash(opts.InitialRefreshTokenHash) { - tokenToRotate = &grant.RefreshTokens[i] - tokenIndex = i - break + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, opts.OfflineGrantID) + if err != nil { + return err } - } - - if tokenToRotate == nil { - return nil, oauth.ErrGrantNotFound - } - - tokenToRotate.RotatedTokenHash = &opts.NewRefreshTokenHash - t := s.Clock.NowUTC() - tokenToRotate.RotatedAt = &t - grant.RefreshTokens[tokenIndex] = *tokenToRotate - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var tokenToRotate *oauth.OfflineGrantRefreshToken + tokenIndex := -1 + for i, token := range grant.RefreshTokens { + if token.MatchInitialHash(opts.InitialRefreshTokenHash) { + tokenToRotate = &grant.RefreshTokens[i] + tokenIndex = i + break + } + } + if tokenToRotate == nil { + return oauth.ErrGrantNotFound + } + tokenToRotate.RotatedTokenHash = &opts.NewRefreshTokenHash + t := s.Clock.NowUTC() + tokenToRotate.RotatedAt = &t + grant.RefreshTokens[tokenIndex] = *tokenToRotate + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) RemoveOfflineGrantRefreshTokens(ctx context.Context, grantID string, initialTokenHashes []string, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - tokenHashesSet := map[string]any{} - for _, hash := range initialTokenHashes { - tokenHashesSet[hash] = hash - } - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - newRefreshTokens := []oauth.OfflineGrantRefreshToken{} - for _, token := range grant.RefreshTokens { - if _, exist := tokenHashesSet[token.InitialTokenHash]; !exist { - newRefreshTokens = append(newRefreshTokens, token) + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + tokenHashesSet := map[string]any{} + for _, hash := range initialTokenHashes { + tokenHashesSet[hash] = hash } - } - grant.RefreshTokens = newRefreshTokens - if grant.HasValidTokens() { - err = s.updateOfflineGrant(ctx, grant, expireAt) + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) if err != nil { - return nil, err + return err } - } else { - // Remove the offline grant if it has no valid tokens - err = s.DeleteOfflineGrant(ctx, grant) - if err != nil { - return nil, err + + newRefreshTokens := []oauth.OfflineGrantRefreshToken{} + for _, token := range grant.RefreshTokens { + if _, exist := tokenHashesSet[token.InitialTokenHash]; !exist { + newRefreshTokens = append(newRefreshTokens, token) + } } - return nil, nil - } - return grant, nil + grant.RefreshTokens = newRefreshTokens + if grant.HasValidTokens() { + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + } else { + // Remove the offline grant if it has no valid tokens + if err = s.DeleteOfflineGrant(ctx, grant); err != nil { + return err + } + } + return nil + }) + return result, err } func (s *Store) updateOfflineGrant(ctx context.Context, grant *oauth.OfflineGrant, expireAt time.Time) error { diff --git a/pkg/lib/session/idpsession/provider.go b/pkg/lib/session/idpsession/provider.go index 7867752cc60..d78201c00ae 100644 --- a/pkg/lib/session/idpsession/provider.go +++ b/pkg/lib/session/idpsession/provider.go @@ -71,34 +71,24 @@ func (p *Provider) MakeSession(attrs *session.Attrs) (*IDPSession, string) { return session, token } -func (p *Provider) Reauthenticate(ctx context.Context, id string, amr []string) (err error) { +func (p *Provider) Reauthenticate(ctx context.Context, id string, amr []string) error { mutexName := sessionMutexName(p.AppID, id) - mutex := p.Redis.NewMutex(mutexName) - err = mutex.LockContext(ctx) - if err != nil { - return - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - s, err := p.Get(ctx, id) - if err != nil { - return - } - - now := p.Clock.NowUTC() - s.AuthenticatedAt = now - s.Attrs.SetAMR(amr) + return p.Redis.WithMutex(ctx, mutexName, func() error { + s, err := p.Get(ctx, id) + if err != nil { + return err + } - setSessionExpireAtForResolvedSession(s, p.Config) - err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession) - if err != nil { - err = fmt.Errorf("failed to update session: %w", err) - return err - } + now := p.Clock.NowUTC() + s.AuthenticatedAt = now + s.Attrs.SetAMR(amr) - return nil + setSessionExpireAtForResolvedSession(s, p.Config) + if err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession); err != nil { + return fmt.Errorf("failed to update session: %w", err) + } + return nil + }) } func (p *Provider) Create(ctx context.Context, session *IDPSession) error { @@ -172,36 +162,23 @@ func (p *Provider) AccessWithID(ctx context.Context, id string, accessEvent acce func (p *Provider) accessWithID(ctx context.Context, id string, accessEvent access.Event) (s *IDPSession, err error) { mutexName := sessionMutexName(p.AppID, id) - mutex := p.Redis.NewMutex(mutexName) - err = mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - s, err = p.Get(ctx, id) - if err != nil { - return nil, err - } - - s.AccessInfo.LastAccess = accessEvent - defer func() { - if err == nil && s != nil { - err = p.accessSideEffects(ctx, s, accessEvent) + err = p.Redis.WithMutex(ctx, mutexName, func() error { + session, err := p.Get(ctx, id) + if err != nil { + return err } - }() - setSessionExpireAtForResolvedSession(s, p.Config) + session.AccessInfo.LastAccess = accessEvent + setSessionExpireAtForResolvedSession(session, p.Config) - err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession) - if err != nil { - err = fmt.Errorf("failed to update session: %w", err) - return nil, err - } + if err = p.Store.Update(ctx, session, session.ExpireAtForResolvedSession); err != nil { + return fmt.Errorf("failed to update session: %w", err) + } - return s, nil + s = session + return p.accessSideEffects(ctx, session, accessEvent) + }) + return } func (p *Provider) accessSideEffects(ctx context.Context, session *IDPSession, accessEvent access.Event) error { @@ -221,29 +198,23 @@ func (p *Provider) accessSideEffects(ctx context.Context, session *IDPSession, a func (p *Provider) AddSAMLServiceProviderParticipant(ctx context.Context, session *IDPSession, serviceProviderID string) (*IDPSession, error) { mutexName := sessionMutexName(p.AppID, session.ID) - mutex := p.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - s, err := p.Get(ctx, session.ID) - if err != nil { - return nil, err - } - newParticipatedSAMLServiceProviderIDs := s.GetParticipatedSAMLServiceProviderIDsSet() - newParticipatedSAMLServiceProviderIDs.Add(serviceProviderID) - s.ParticipatedSAMLServiceProviderIDs = newParticipatedSAMLServiceProviderIDs.Keys() - err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession) - if err != nil { - err = fmt.Errorf("failed to update session: %w", err) - return nil, err - } + var result *IDPSession + err := p.Redis.WithMutex(ctx, mutexName, func() error { + s, err := p.Get(ctx, session.ID) + if err != nil { + return err + } + newParticipatedSAMLServiceProviderIDs := s.GetParticipatedSAMLServiceProviderIDsSet() + newParticipatedSAMLServiceProviderIDs.Add(serviceProviderID) + s.ParticipatedSAMLServiceProviderIDs = newParticipatedSAMLServiceProviderIDs.Keys() + if err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession); err != nil { + return fmt.Errorf("failed to update session: %w", err) + } + result = s + return nil + }) - return s, nil + return result, err } func (p *Provider) CheckSessionExpired(session *IDPSession) (expired bool) {