@@ -12,7 +12,7 @@ package credentials
1212
1313import (
1414 "context"
15- "sync"
15+ "sync/atomic "
1616 "time"
1717
1818 "go.mongodb.org/mongo-driver/v2/internal/aws/awserr"
@@ -50,6 +50,18 @@ type Value struct {
5050 ProviderName string
5151}
5252
53+ // Expired returns if the credentials have expired.
54+ func (v Value ) Expired () bool {
55+ if v .CanExpire {
56+ // Calling Round(0) on the current time will truncate the monotonic
57+ // reading only. Ensures credential expiry time is always based on
58+ // reported wall-clock time.
59+ return ! v .Expires .After (time .Now ().Round (0 ))
60+ }
61+
62+ return false
63+ }
64+
5365// HasKeys returns if the credentials Value has both AccessKeyID and
5466// SecretAccessKey value set.
5567func (v Value ) HasKeys () bool {
@@ -66,10 +78,6 @@ type Provider interface {
6678 // Retrieve returns nil if it successfully retrieved the value.
6779 // Error is returned if the value were not obtainable, or empty.
6880 Retrieve (context.Context ) (Value , error )
69-
70- // IsExpired returns if the credentials are no longer valid, and need
71- // to be retrieved.
72- IsExpired () bool
7381}
7482
7583// A Credentials provides concurrency safe retrieval of AWS credentials Value.
@@ -85,13 +93,12 @@ type Provider interface {
8593//
8694// The first Credentials.Get() will always call Provider.Retrieve() to get the
8795// first instance of the credentials Value. All calls to Get() after that
88- // will return the cached credentials Value until IsExpired () returns true.
96+ // will return the cached credentials Value until Expired () returns true.
8997type Credentials struct {
90- sf singleflight.Group
91-
92- m sync.RWMutex
93- creds Value
9498 provider Provider
99+
100+ creds atomic.Value
101+ sf singleflight.Group
95102}
96103
97104// NewCredentials returns a pointer to a new Credentials with the provider set.
@@ -102,34 +109,18 @@ func NewCredentials(provider Provider) *Credentials {
102109 return c
103110}
104111
105- // GetWithContext returns the credentials value, or error if the credentials
112+ // Get returns the credentials value, or error if the credentials
106113// Value failed to be retrieved. Will return early if the passed in context is
107114// canceled.
108115//
109116// Will return the cached credentials Value if it has not expired. If the
110117// credentials Value has expired the Provider's Retrieve() will be called
111118// to refresh the credentials.
112- //
113- // If Credentials.Expire() was called the credentials Value will be force
114- // expired, and the next call to Get() will cause them to be refreshed.
115- func (c * Credentials ) GetWithContext (ctx context.Context ) (Value , error ) {
116- // Check if credentials are cached, and not expired.
117- select {
118- case curCreds , ok := <- c .asyncIsExpired ():
119- // ok will only be true, of the credentials were not expired. ok will
120- // be false and have no value if the credentials are expired.
121- if ok {
122- return curCreds , nil
123- }
124- case <- ctx .Done ():
125- return Value {}, awserr .New ("RequestCanceled" ,
126- "request context canceled" , ctx .Err ())
127- }
128-
119+ func (c * Credentials ) Get (ctx context.Context ) (Value , error ) {
129120 // Cannot pass context down to the actual retrieve, because the first
130121 // context would cancel the whole group when there is not direct
131122 // association of items in the group.
132- resCh := c .sf .DoChan ("" , func () (interface {} , error ) {
123+ resCh := c .sf .DoChan ("" , func () (any , error ) {
133124 return c .singleRetrieve (& suppressedContext {ctx })
134125 })
135126 select {
@@ -141,43 +132,33 @@ func (c *Credentials) GetWithContext(ctx context.Context) (Value, error) {
141132 }
142133}
143134
144- func (c * Credentials ) singleRetrieve (ctx context.Context ) (interface {}, error ) {
145- c .m .Lock ()
146- defer c .m .Unlock ()
147-
148- if curCreds := c .creds ; ! c .isExpiredLocked (curCreds ) {
149- return curCreds , nil
135+ func (c * Credentials ) singleRetrieve (ctx context.Context ) (any , error ) {
136+ if currCreds , ok := c .getCreds (); ok && ! currCreds .Expired () {
137+ return currCreds , nil
150138 }
151139
152- creds , err := c .provider .Retrieve (ctx )
140+ newCreds , err := c .provider .Retrieve (ctx )
153141 if err == nil {
154- c .creds = creds
142+ c .creds . Store ( & newCreds )
155143 }
156144
157- return creds , err
145+ return newCreds , err
158146}
159147
160- // asyncIsExpired returns a channel of credentials Value. If the channel is
161- // closed the credentials are expired and credentials value are not empty.
162- func (c * Credentials ) asyncIsExpired () <- chan Value {
163- ch := make (chan Value , 1 )
164- go func () {
165- c .m .RLock ()
166- defer c .m .RUnlock ()
167-
168- if curCreds := c .creds ; ! c .isExpiredLocked (curCreds ) {
169- ch <- curCreds
170- }
171-
172- close (ch )
173- }()
148+ // getCreds returns the currently stored credentials and true. Returning false
149+ // if no credentials were stored.
150+ func (c * Credentials ) getCreds () (Value , bool ) {
151+ v := c .creds .Load ()
152+ if v == nil {
153+ return Value {}, false
154+ }
174155
175- return ch
176- }
156+ val := v .(* Value )
157+ if val == nil || ! val .HasKeys () {
158+ return Value {}, false
159+ }
177160
178- // isExpiredLocked helper method wrapping the definition of expired credentials.
179- func (c * Credentials ) isExpiredLocked (creds interface {}) bool {
180- return creds == nil || creds .(Value ) == Value {} || c .provider .IsExpired ()
161+ return * val , true
181162}
182163
183164type suppressedContext struct {
0 commit comments