Skip to content

Commit ec2df09

Browse files
committed
update provider interface
1 parent 38501d3 commit ec2df09

24 files changed

Lines changed: 88 additions & 518 deletions

internal/aws/credentials/chain_provider.go

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,19 @@ package credentials
1212

1313
import (
1414
"context"
15+
"errors"
1516

1617
"go.mongodb.org/mongo-driver/v2/internal/aws/awserr"
1718
)
1819

19-
// A ChainProvider will search for a provider which returns credentials
20-
// and cache that provider until Retrieve is called again.
21-
//
2220
// The ChainProvider provides a way of chaining multiple providers together
2321
// which will pick the first available using priority order of the Providers
2422
// in the list.
2523
//
2624
// If none of the Providers retrieve valid credentials Value, ChainProvider's
27-
// Retrieve() will return the error ErrNoValidProvidersFoundInChain.
28-
//
29-
// If a Provider is found which returns valid credentials Value ChainProvider
30-
// will cache that Provider for all calls to IsExpired(), until Retrieve is
31-
// called again.
25+
// Retrieve() will return an error.
3226
type ChainProvider struct {
3327
Providers []Provider
34-
curr Provider
3528
}
3629

3730
// NewChainCredentials returns a pointer to a new Credentials object
@@ -44,31 +37,19 @@ func NewChainCredentials(providers []Provider) *Credentials {
4437

4538
// Retrieve returns the credentials value or error if no provider returned
4639
// without error.
47-
//
48-
// If a provider is found it will be cached and any calls to IsExpired()
49-
// will return the expired state of the cached provider.
5040
func (c *ChainProvider) Retrieve(ctx context.Context) (Value, error) {
5141
var errs = make([]error, 0, len(c.Providers))
5242
for _, p := range c.Providers {
5343
creds, err := p.Retrieve(ctx)
5444
if err == nil {
55-
c.curr = p
56-
return creds, nil
45+
if !creds.Expired() && creds.HasKeys() {
46+
return creds, nil
47+
}
48+
err = errors.New("credentials are invalid")
5749
}
5850
errs = append(errs, err)
5951
}
60-
c.curr = nil
6152

6253
var err = awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
6354
return Value{}, err
6455
}
65-
66-
// IsExpired will returned the expired state of the currently cached provider
67-
// if there is one. If there is no current provider, true will be returned.
68-
func (c *ChainProvider) IsExpired() bool {
69-
if c.curr != nil {
70-
return c.curr.IsExpired()
71-
}
72-
73-
return true
74-
}

internal/aws/credentials/chain_provider_test.go

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,14 @@ import (
1919
)
2020

2121
type secondStubProvider struct {
22-
creds Value
23-
expired bool
24-
err error
22+
creds Value
23+
err error
2524
}
2625

2726
func (s *secondStubProvider) Retrieve(_ context.Context) (Value, error) {
28-
s.expired = false
2927
s.creds.ProviderName = "secondStubProvider"
3028
return s.creds, s.err
3129
}
32-
func (s *secondStubProvider) IsExpired() bool {
33-
return s.expired
34-
}
3530

3631
func TestChainProviderWithNames(t *testing.T) {
3732
p := &ChainProvider{
@@ -106,49 +101,11 @@ func TestChainProviderGet(t *testing.T) {
106101
}
107102
}
108103

109-
func TestChainProviderIsExpired(t *testing.T) {
110-
stubProvider := &stubProvider{expired: true}
111-
p := &ChainProvider{
112-
Providers: []Provider{
113-
stubProvider,
114-
},
115-
}
116-
117-
ctx := context.Background()
118-
119-
if !p.IsExpired() {
120-
t.Errorf("Expect expired to be true before any Retrieve")
121-
}
122-
_, err := p.Retrieve(ctx)
123-
if err != nil {
124-
t.Errorf("Expect no error, got %v", err)
125-
}
126-
if p.IsExpired() {
127-
t.Errorf("Expect not expired after retrieve")
128-
}
129-
130-
stubProvider.expired = true
131-
if !p.IsExpired() {
132-
t.Errorf("Expect return of expired provider")
133-
}
134-
135-
_, err = p.Retrieve(ctx)
136-
if err != nil {
137-
t.Errorf("Expect no error, got %v", err)
138-
}
139-
if p.IsExpired() {
140-
t.Errorf("Expect not expired after retrieve")
141-
}
142-
}
143-
144104
func TestChainProviderWithNoProvider(t *testing.T) {
145105
p := &ChainProvider{
146106
Providers: []Provider{},
147107
}
148108

149-
if !p.IsExpired() {
150-
t.Errorf("Expect expired with no providers")
151-
}
152109
_, err := p.Retrieve(context.Background())
153110
if err.Error() != "NoCredentialProviders: no valid providers in chain" {
154111
t.Errorf("Expect no providers error returned, got %v", err)
@@ -167,9 +124,6 @@ func TestChainProviderWithNoValidProvider(t *testing.T) {
167124
},
168125
}
169126

170-
if !p.IsExpired() {
171-
t.Errorf("Expect expired with no providers")
172-
}
173127
_, err := p.Retrieve(context.Background())
174128

175129
expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)

internal/aws/credentials/credentials.go

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ package credentials
1212

1313
import (
1414
"context"
15-
"sync"
15+
"sync/atomic"
1616
"time"
1717

1818
"go.mongodb.org/mongo-driver/v2/internal/aws/awserr"
@@ -50,6 +50,18 @@ type Value struct {
5050
ProviderName string
5151
}
5252

53+
// Expired returns if the credentials have expired.
54+
func (v Value) Expired() bool {
55+
if v.CanExpire {
56+
// Calling Round(0) on the current time will truncate the monotonic
57+
// reading only. Ensures credential expiry time is always based on
58+
// reported wall-clock time.
59+
return !v.Expires.After(time.Now().Round(0))
60+
}
61+
62+
return false
63+
}
64+
5365
// HasKeys returns if the credentials Value has both AccessKeyID and
5466
// SecretAccessKey value set.
5567
func (v Value) HasKeys() bool {
@@ -66,10 +78,6 @@ type Provider interface {
6678
// Retrieve returns nil if it successfully retrieved the value.
6779
// Error is returned if the value were not obtainable, or empty.
6880
Retrieve(context.Context) (Value, error)
69-
70-
// IsExpired returns if the credentials are no longer valid, and need
71-
// to be retrieved.
72-
IsExpired() bool
7381
}
7482

7583
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
@@ -85,13 +93,12 @@ type Provider interface {
8593
//
8694
// The first Credentials.Get() will always call Provider.Retrieve() to get the
8795
// first instance of the credentials Value. All calls to Get() after that
88-
// will return the cached credentials Value until IsExpired() returns true.
96+
// will return the cached credentials Value until Expired() returns true.
8997
type Credentials struct {
90-
sf singleflight.Group
91-
92-
m sync.RWMutex
93-
creds Value
9498
provider Provider
99+
100+
creds atomic.Value
101+
sf singleflight.Group
95102
}
96103

97104
// NewCredentials returns a pointer to a new Credentials with the provider set.
@@ -102,34 +109,18 @@ func NewCredentials(provider Provider) *Credentials {
102109
return c
103110
}
104111

105-
// GetWithContext returns the credentials value, or error if the credentials
112+
// Get returns the credentials value, or error if the credentials
106113
// Value failed to be retrieved. Will return early if the passed in context is
107114
// canceled.
108115
//
109116
// Will return the cached credentials Value if it has not expired. If the
110117
// credentials Value has expired the Provider's Retrieve() will be called
111118
// to refresh the credentials.
112-
//
113-
// If Credentials.Expire() was called the credentials Value will be force
114-
// expired, and the next call to Get() will cause them to be refreshed.
115-
func (c *Credentials) GetWithContext(ctx context.Context) (Value, error) {
116-
// Check if credentials are cached, and not expired.
117-
select {
118-
case curCreds, ok := <-c.asyncIsExpired():
119-
// ok will only be true, of the credentials were not expired. ok will
120-
// be false and have no value if the credentials are expired.
121-
if ok {
122-
return curCreds, nil
123-
}
124-
case <-ctx.Done():
125-
return Value{}, awserr.New("RequestCanceled",
126-
"request context canceled", ctx.Err())
127-
}
128-
119+
func (c *Credentials) Get(ctx context.Context) (Value, error) {
129120
// Cannot pass context down to the actual retrieve, because the first
130121
// context would cancel the whole group when there is not direct
131122
// association of items in the group.
132-
resCh := c.sf.DoChan("", func() (interface{}, error) {
123+
resCh := c.sf.DoChan("", func() (any, error) {
133124
return c.singleRetrieve(&suppressedContext{ctx})
134125
})
135126
select {
@@ -141,43 +132,33 @@ func (c *Credentials) GetWithContext(ctx context.Context) (Value, error) {
141132
}
142133
}
143134

144-
func (c *Credentials) singleRetrieve(ctx context.Context) (interface{}, error) {
145-
c.m.Lock()
146-
defer c.m.Unlock()
147-
148-
if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
149-
return curCreds, nil
135+
func (c *Credentials) singleRetrieve(ctx context.Context) (any, error) {
136+
if currCreds, ok := c.getCreds(); ok && !currCreds.Expired() {
137+
return currCreds, nil
150138
}
151139

152-
creds, err := c.provider.Retrieve(ctx)
140+
newCreds, err := c.provider.Retrieve(ctx)
153141
if err == nil {
154-
c.creds = creds
142+
c.creds.Store(&newCreds)
155143
}
156144

157-
return creds, err
145+
return newCreds, err
158146
}
159147

160-
// asyncIsExpired returns a channel of credentials Value. If the channel is
161-
// closed the credentials are expired and credentials value are not empty.
162-
func (c *Credentials) asyncIsExpired() <-chan Value {
163-
ch := make(chan Value, 1)
164-
go func() {
165-
c.m.RLock()
166-
defer c.m.RUnlock()
167-
168-
if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
169-
ch <- curCreds
170-
}
171-
172-
close(ch)
173-
}()
148+
// getCreds returns the currently stored credentials and true. Returning false
149+
// if no credentials were stored.
150+
func (c *Credentials) getCreds() (Value, bool) {
151+
v := c.creds.Load()
152+
if v == nil {
153+
return Value{}, false
154+
}
174155

175-
return ch
176-
}
156+
val := v.(*Value)
157+
if val == nil || !val.HasKeys() {
158+
return Value{}, false
159+
}
177160

178-
// isExpiredLocked helper method wrapping the definition of expired credentials.
179-
func (c *Credentials) isExpiredLocked(creds interface{}) bool {
180-
return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired()
161+
return *val, true
181162
}
182163

183164
type suppressedContext struct {

0 commit comments

Comments
 (0)