diff --git a/pkg/lib/infra/redis/handle.go b/pkg/lib/infra/redis/handle.go index 14b640a7c4..ac7f1d4877 100644 --- a/pkg/lib/infra/redis/handle.go +++ b/pkg/lib/infra/redis/handle.go @@ -60,3 +60,15 @@ func (h *Handle) NewMutex(name string) *redsync.Mutex { ) return mutex } + +func (h *Handle) WithMutex(ctx context.Context, name string, do func() error) error { + mutex := h.NewMutex(name) + if err := mutex.LockContext(ctx); err != nil { + return err + } + unlockCtx := context.WithoutCancel(ctx) + defer func() { + _, _ = mutex.UnlockContext(unlockCtx) + }() + return do() +} diff --git a/pkg/lib/oauth/redis/store.go b/pkg/lib/oauth/redis/store.go index dbfdbea362..76dadaf9bd 100644 --- a/pkg/lib/oauth/redis/store.go +++ b/pkg/lib/oauth/redis/store.go @@ -293,106 +293,74 @@ func (s *Store) CreateOfflineGrant(ctx context.Context, grant *oauth.OfflineGran func (s *Store) UpdateOfflineGrantWithMutator(ctx context.Context, grantID string, expireAt time.Time, mutator func(*oauth.OfflineGrant) *oauth.OfflineGrant) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant = mutator(grant) - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant = mutator(grant) + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) UpdateOfflineGrantDeviceInfo(ctx context.Context, grantID string, deviceInfo map[string]any, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant.DeviceInfo = deviceInfo - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant.DeviceInfo = deviceInfo + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) UpdateOfflineGrantAuthenticatedAt(ctx context.Context, grantID string, authenticatedAt time.Time, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant.AuthenticatedAt = authenticatedAt - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant.AuthenticatedAt = authenticatedAt + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) UpdateOfflineGrantApp2AppDeviceKey(ctx context.Context, grantID string, newKey string, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant.App2AppDeviceKeyJWKJSON = newKey - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant.App2AppDeviceKeyJWKJSON = newKey + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) UpdateOfflineGrantDeviceSecretHash( @@ -402,29 +370,21 @@ func (s *Store) UpdateOfflineGrantDeviceSecretHash( dpopJKT string, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - grant.DeviceSecretHash = newDeviceSecretHash - grant.DeviceSecretDPoPJKT = dpopJKT - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + grant.DeviceSecretHash = newDeviceSecretHash + grant.DeviceSecretDPoPJKT = dpopJKT + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) AddOfflineGrantSAMLServiceProviderParticipant( @@ -433,29 +393,22 @@ func (s *Store) AddOfflineGrantSAMLServiceProviderParticipant( newServiceProviderID string, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - newParticipatedSAMLServiceProviderIDs := grant.GetParticipatedSAMLServiceProviderIDsSet() - newParticipatedSAMLServiceProviderIDs.Add(newServiceProviderID) - grant.ParticipatedSAMLServiceProviderIDs = newParticipatedSAMLServiceProviderIDs.Keys() - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) + if err != nil { + return err + } + newParticipatedSAMLServiceProviderIDs := grant.GetParticipatedSAMLServiceProviderIDsSet() + newParticipatedSAMLServiceProviderIDs.Add(newServiceProviderID) + grant.ParticipatedSAMLServiceProviderIDs = newParticipatedSAMLServiceProviderIDs.Keys() + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) AddOfflineGrantRefreshToken( @@ -463,40 +416,31 @@ func (s *Store) AddOfflineGrantRefreshToken( options oauth.AddOfflineGrantRefreshTokenOptions, ) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), options.OfflineGrantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, options.OfflineGrantID) - if err != nil { - return nil, err - } - - now := s.Clock.NowUTC() - - newRefreshToken := oauth.OfflineGrantRefreshToken{ - InitialTokenHash: options.TokenHash, - ClientID: options.ClientID, - CreatedAt: now, - Scopes: options.Scopes, - AuthorizationID: options.AuthorizationID, - DPoPJKT: options.DPoPJKT, - AccessInfo: &options.AccessInfo, - ExpireAt: options.ShortLivedRefreshTokenExpireAt, - } - - grant.RefreshTokens = append(grant.RefreshTokens, newRefreshToken) - err = s.updateOfflineGrant(ctx, grant, options.OfflineGrantExpireAt) - if err != nil { - return nil, err - } - - return grant, nil + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, options.OfflineGrantID) + if err != nil { + return err + } + now := s.Clock.NowUTC() + newRefreshToken := oauth.OfflineGrantRefreshToken{ + InitialTokenHash: options.TokenHash, + ClientID: options.ClientID, + CreatedAt: now, + Scopes: options.Scopes, + AuthorizationID: options.AuthorizationID, + DPoPJKT: options.DPoPJKT, + AccessInfo: &options.AccessInfo, + ExpireAt: options.ShortLivedRefreshTokenExpireAt, + } + grant.RefreshTokens = append(grant.RefreshTokens, newRefreshToken) + if err = s.updateOfflineGrant(ctx, grant, options.OfflineGrantExpireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) RotateOfflineGrantRefreshToken( @@ -505,91 +449,73 @@ func (s *Store) RotateOfflineGrantRefreshToken( expireAt time.Time, ) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), opts.OfflineGrantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, opts.OfflineGrantID) - if err != nil { - return nil, err - } - - var tokenToRotate *oauth.OfflineGrantRefreshToken - tokenIndex := -1 - for i, token := range grant.RefreshTokens { - if token.MatchInitialHash(opts.InitialRefreshTokenHash) { - tokenToRotate = &grant.RefreshTokens[i] - tokenIndex = i - break + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, opts.OfflineGrantID) + if err != nil { + return err } - } - - if tokenToRotate == nil { - return nil, oauth.ErrGrantNotFound - } - - tokenToRotate.RotatedTokenHash = &opts.NewRefreshTokenHash - t := s.Clock.NowUTC() - tokenToRotate.RotatedAt = &t - grant.RefreshTokens[tokenIndex] = *tokenToRotate - - err = s.updateOfflineGrant(ctx, grant, expireAt) - if err != nil { - return nil, err - } - - return grant, nil + var tokenToRotate *oauth.OfflineGrantRefreshToken + tokenIndex := -1 + for i, token := range grant.RefreshTokens { + if token.MatchInitialHash(opts.InitialRefreshTokenHash) { + tokenToRotate = &grant.RefreshTokens[i] + tokenIndex = i + break + } + } + if tokenToRotate == nil { + return oauth.ErrGrantNotFound + } + tokenToRotate.RotatedTokenHash = &opts.NewRefreshTokenHash + t := s.Clock.NowUTC() + tokenToRotate.RotatedAt = &t + grant.RefreshTokens[tokenIndex] = *tokenToRotate + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + return nil + }) + return result, err } func (s *Store) RemoveOfflineGrantRefreshTokens(ctx context.Context, grantID string, initialTokenHashes []string, expireAt time.Time) (*oauth.OfflineGrant, error) { mutexName := offlineGrantMutexName(string(s.AppID), grantID) - mutex := s.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - tokenHashesSet := map[string]any{} - for _, hash := range initialTokenHashes { - tokenHashesSet[hash] = hash - } - - grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) - if err != nil { - return nil, err - } - - newRefreshTokens := []oauth.OfflineGrantRefreshToken{} - for _, token := range grant.RefreshTokens { - if _, exist := tokenHashesSet[token.InitialTokenHash]; !exist { - newRefreshTokens = append(newRefreshTokens, token) + var result *oauth.OfflineGrant + err := s.Redis.WithMutex(ctx, mutexName, func() error { + tokenHashesSet := map[string]any{} + for _, hash := range initialTokenHashes { + tokenHashesSet[hash] = hash } - } - grant.RefreshTokens = newRefreshTokens - if grant.HasValidTokens() { - err = s.updateOfflineGrant(ctx, grant, expireAt) + grant, err := s.GetOfflineGrantWithoutExpireAt(ctx, grantID) if err != nil { - return nil, err + return err } - } else { - // Remove the offline grant if it has no valid tokens - err = s.DeleteOfflineGrant(ctx, grant) - if err != nil { - return nil, err + + newRefreshTokens := []oauth.OfflineGrantRefreshToken{} + for _, token := range grant.RefreshTokens { + if _, exist := tokenHashesSet[token.InitialTokenHash]; !exist { + newRefreshTokens = append(newRefreshTokens, token) + } } - return nil, nil - } - return grant, nil + grant.RefreshTokens = newRefreshTokens + if grant.HasValidTokens() { + if err = s.updateOfflineGrant(ctx, grant, expireAt); err != nil { + return err + } + result = grant + } else { + // Remove the offline grant if it has no valid tokens + if err = s.DeleteOfflineGrant(ctx, grant); err != nil { + return err + } + } + return nil + }) + return result, err } func (s *Store) updateOfflineGrant(ctx context.Context, grant *oauth.OfflineGrant, expireAt time.Time) error { diff --git a/pkg/lib/session/idpsession/provider.go b/pkg/lib/session/idpsession/provider.go index 7867752cc6..d78201c00a 100644 --- a/pkg/lib/session/idpsession/provider.go +++ b/pkg/lib/session/idpsession/provider.go @@ -71,34 +71,24 @@ func (p *Provider) MakeSession(attrs *session.Attrs) (*IDPSession, string) { return session, token } -func (p *Provider) Reauthenticate(ctx context.Context, id string, amr []string) (err error) { +func (p *Provider) Reauthenticate(ctx context.Context, id string, amr []string) error { mutexName := sessionMutexName(p.AppID, id) - mutex := p.Redis.NewMutex(mutexName) - err = mutex.LockContext(ctx) - if err != nil { - return - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - s, err := p.Get(ctx, id) - if err != nil { - return - } - - now := p.Clock.NowUTC() - s.AuthenticatedAt = now - s.Attrs.SetAMR(amr) + return p.Redis.WithMutex(ctx, mutexName, func() error { + s, err := p.Get(ctx, id) + if err != nil { + return err + } - setSessionExpireAtForResolvedSession(s, p.Config) - err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession) - if err != nil { - err = fmt.Errorf("failed to update session: %w", err) - return err - } + now := p.Clock.NowUTC() + s.AuthenticatedAt = now + s.Attrs.SetAMR(amr) - return nil + setSessionExpireAtForResolvedSession(s, p.Config) + if err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession); err != nil { + return fmt.Errorf("failed to update session: %w", err) + } + return nil + }) } func (p *Provider) Create(ctx context.Context, session *IDPSession) error { @@ -172,36 +162,23 @@ func (p *Provider) AccessWithID(ctx context.Context, id string, accessEvent acce func (p *Provider) accessWithID(ctx context.Context, id string, accessEvent access.Event) (s *IDPSession, err error) { mutexName := sessionMutexName(p.AppID, id) - mutex := p.Redis.NewMutex(mutexName) - err = mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - s, err = p.Get(ctx, id) - if err != nil { - return nil, err - } - - s.AccessInfo.LastAccess = accessEvent - defer func() { - if err == nil && s != nil { - err = p.accessSideEffects(ctx, s, accessEvent) + err = p.Redis.WithMutex(ctx, mutexName, func() error { + session, err := p.Get(ctx, id) + if err != nil { + return err } - }() - setSessionExpireAtForResolvedSession(s, p.Config) + session.AccessInfo.LastAccess = accessEvent + setSessionExpireAtForResolvedSession(session, p.Config) - err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession) - if err != nil { - err = fmt.Errorf("failed to update session: %w", err) - return nil, err - } + if err = p.Store.Update(ctx, session, session.ExpireAtForResolvedSession); err != nil { + return fmt.Errorf("failed to update session: %w", err) + } - return s, nil + s = session + return p.accessSideEffects(ctx, session, accessEvent) + }) + return } func (p *Provider) accessSideEffects(ctx context.Context, session *IDPSession, accessEvent access.Event) error { @@ -221,29 +198,23 @@ func (p *Provider) accessSideEffects(ctx context.Context, session *IDPSession, a func (p *Provider) AddSAMLServiceProviderParticipant(ctx context.Context, session *IDPSession, serviceProviderID string) (*IDPSession, error) { mutexName := sessionMutexName(p.AppID, session.ID) - mutex := p.Redis.NewMutex(mutexName) - err := mutex.LockContext(ctx) - if err != nil { - return nil, err - } - defer func() { - _, _ = mutex.UnlockContext(ctx) - }() - - s, err := p.Get(ctx, session.ID) - if err != nil { - return nil, err - } - newParticipatedSAMLServiceProviderIDs := s.GetParticipatedSAMLServiceProviderIDsSet() - newParticipatedSAMLServiceProviderIDs.Add(serviceProviderID) - s.ParticipatedSAMLServiceProviderIDs = newParticipatedSAMLServiceProviderIDs.Keys() - err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession) - if err != nil { - err = fmt.Errorf("failed to update session: %w", err) - return nil, err - } + var result *IDPSession + err := p.Redis.WithMutex(ctx, mutexName, func() error { + s, err := p.Get(ctx, session.ID) + if err != nil { + return err + } + newParticipatedSAMLServiceProviderIDs := s.GetParticipatedSAMLServiceProviderIDsSet() + newParticipatedSAMLServiceProviderIDs.Add(serviceProviderID) + s.ParticipatedSAMLServiceProviderIDs = newParticipatedSAMLServiceProviderIDs.Keys() + if err = p.Store.Update(ctx, s, s.ExpireAtForResolvedSession); err != nil { + return fmt.Errorf("failed to update session: %w", err) + } + result = s + return nil + }) - return s, nil + return result, err } func (p *Provider) CheckSessionExpired(session *IDPSession) (expired bool) {