Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/clients/auth_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, erro
// Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring
// between retrieving the token and upstream systems validating it.
now := time.Now().Add(tokenExpirationLeeway)
return now.After(expirationTimestampNumeric.Time), nil
return now.After(expirationTimestampNumeric.Time) || now.Equal(expirationTimestampNumeric.Time), nil
}
60 changes: 32 additions & 28 deletions core/clients/continuous_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"testing"
"testing/synctest"
"time"

"github.com/golang-jwt/jwt/v5"
Expand Down Expand Up @@ -93,36 +94,39 @@ func TestContinuousRefreshToken(t *testing.T) {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
t.Parallel()
accessToken, err := signToken(accessTokensTimeToLive)
if err != nil {
t.Fatalf("failed to sign access token: %v", err)
}
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn)
defer cancel()

authFlow := &fakeAuthFlow{
backgroundTokenRefreshContext: ctx,
doError: tt.doError,
accessTokensTimeToLive: accessTokensTimeToLive,
accessToken: accessToken,
}
synctest.Test(t, func(t *testing.T) {
accessToken, err := signToken(accessTokensTimeToLive)
if err != nil {
t.Fatalf("failed to sign access token: %v", err)
}
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn)
defer cancel()

authFlow := &fakeAuthFlow{
backgroundTokenRefreshContext: ctx,
doError: tt.doError,
accessTokensTimeToLive: accessTokensTimeToLive,
accessToken: accessToken,
}

refresher := &continuousTokenRefresher{
flow: authFlow,
timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration,
timeBetweenContextCheck: timeBetweenContextCheck,
timeBetweenTries: timeBetweenTries,
}
refresher := &continuousTokenRefresher{
flow: authFlow,
timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration,
timeBetweenContextCheck: timeBetweenContextCheck,
timeBetweenTries: timeBetweenTries,
}

err = refresher.continuousRefreshToken()
if err == nil {
t.Fatalf("routine finished with non-nil error")
}
numberDoCalls := authFlow.getTokenCalls()
if numberDoCalls != tt.expectedNumberDoCalls {
t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls)
}
err = refresher.continuousRefreshToken()
synctest.Wait()
if err == nil {
t.Fatalf("routine finished with non-nil error")
}
numberDoCalls := authFlow.getTokenCalls()
if numberDoCalls != tt.expectedNumberDoCalls {
t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls)
}
})
})
}
}
Expand Down
Loading