Skip to content

Commit 743aa5f

Browse files
Add Credentials interface and reimplement PAT authentication. (#1260)
## What changes are proposed in this pull request? This PR adds the missing `Credentials` interface to the experimental auth package. That interface is meant to replace the current `CredentialsProvider` interface. The PR also re-implements PAT using this new interface. This PR also moves the token cache in its own file to better isolate the package's main interfaces. ## How is this tested? Unit + integration tests. NO_CHANGELOG=true
1 parent d4cc97c commit 743aa5f

7 files changed

Lines changed: 343 additions & 202 deletions

File tree

config/auth_pat.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,28 @@ package config
33
import (
44
"context"
55
"fmt"
6-
"net/http"
76

87
"github.com/databricks/databricks-sdk-go/config/credentials"
8+
authcred "github.com/databricks/databricks-sdk-go/config/experimental/auth/credentials"
99
)
1010

11-
type PatCredentials struct {
12-
}
11+
type PatCredentials struct{}
1312

1413
func (c PatCredentials) Name() string {
1514
return "pat"
1615
}
1716

1817
func (c PatCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
19-
if cfg.Token == "" || cfg.Host == "" {
20-
return nil, nil
18+
// Host inference is not supported for PAT authentication. This is
19+
// arguably a redundant check as requests will fail anyway if no
20+
// host is provided. This check exists for backward compatibility
21+
// with the previous PAT credentials implementation.
22+
if cfg.Host == "" {
23+
return nil, fmt.Errorf("host is required for PAT authentication")
2124
}
22-
visitor := func(r *http.Request) error {
23-
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.Token))
24-
return nil
25+
creds, err := authcred.NewPATCredentials(cfg.Token)
26+
if err != nil {
27+
return nil, err
2528
}
26-
return credentials.CredentialsProviderFn(visitor), nil
29+
return credentials.FromCredentials(creds), nil
2730
}

config/credentials/credentials.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,24 @@ func (c CredentialsProviderFn) SetHeaders(r *http.Request) error {
3030
return c(r)
3131
}
3232

33+
// FromCredentials returns a new CredentialsProvider that uses the given
34+
// Credentials to set headers on the request.
35+
//
36+
// The returned CredentialsProvider will override the headers if these
37+
// are already set.
38+
func FromCredentials(c auth.Credentials) CredentialsProvider {
39+
return CredentialsProviderFn(func(r *http.Request) error {
40+
headers, err := c.AuthHeaders(context.Background())
41+
if err != nil {
42+
return err
43+
}
44+
for k, v := range headers {
45+
r.Header.Set(k, v)
46+
}
47+
return nil
48+
})
49+
}
50+
3351
// NewCredentialsProvider returns a new CredentialsProvider that uses the
3452
// provided function to set headers on the request.
3553
//

config/experimental/auth/auth.go

Lines changed: 23 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,28 @@ package auth
66

77
import (
88
"context"
9-
"sync"
10-
"time"
119

1210
"golang.org/x/oauth2"
1311
)
1412

15-
const (
16-
// Default duration for the stale period. The number as been set arbitrarily
17-
// and might be changed in the future.
18-
defaultStaleDuration = 3 * time.Minute
19-
)
13+
// Credentials is anything that can return authentication headers.
14+
type Credentials interface {
15+
AuthHeaders(context.Context) (map[string]string, error)
16+
}
17+
18+
// CredentialsFn is an adapter to allow the use of ordinary functions as
19+
// Credentials.
20+
//
21+
// Example:
22+
//
23+
// creds := CredentialsFn(func(ctx context.Context) (map[string]string, error) {
24+
// return map[string]string{"Authorization": "Bearer " + token}, nil
25+
// })
26+
type CredentialsFn func(context.Context) (map[string]string, error)
27+
28+
func (fn CredentialsFn) AuthHeaders(ctx context.Context) (map[string]string, error) {
29+
return fn(ctx)
30+
}
2031

2132
// A TokenSource is anything that can return a token.
2233
type TokenSource interface {
@@ -39,190 +50,9 @@ func (fn TokenSourceFn) Token(ctx context.Context) (*oauth2.Token, error) {
3950
return fn(ctx)
4051
}
4152

42-
type Option func(*cachedTokenSource)
43-
44-
// WithCachedToken sets the initial token to be used by a cached token source.
45-
func WithCachedToken(t *oauth2.Token) Option {
46-
return func(cts *cachedTokenSource) {
47-
cts.cachedToken = t
48-
}
49-
}
50-
51-
// WithAsyncRefresh enables or disables the asynchronous token refresh.
52-
func WithAsyncRefresh(b bool) Option {
53-
return func(cts *cachedTokenSource) {
54-
cts.disableAsync = !b
55-
}
56-
}
57-
58-
// NewCachedTokenProvider wraps a [TokenSource] to cache the tokens it returns.
59-
// By default, the cache will refresh tokens asynchronously a few minutes before
60-
// they expire.
61-
//
62-
// The token cache is safe for concurrent use by multiple goroutines and will
63-
// guarantee that only one token refresh is triggered at a time.
64-
//
65-
// The token cache does not take care of retries in case the token source
66-
// returns and error; it is the responsibility of the provided token source to
67-
// handle retries appropriately.
68-
//
69-
// If the TokenSource is already a cached token source (obtained by calling this
70-
// function), it is returned as is.
71-
func NewCachedTokenSource(ts TokenSource, opts ...Option) TokenSource {
72-
// This is meant as a niche optimization to avoid double caching of the
73-
// token source in situations where the user calls needs caching guarantees
74-
// but does not know if the token source is already cached.
75-
if cts, ok := ts.(*cachedTokenSource); ok {
76-
return cts
77-
}
78-
79-
cts := &cachedTokenSource{
80-
tokenSource: ts,
81-
staleDuration: defaultStaleDuration,
82-
timeNow: time.Now,
83-
}
84-
85-
for _, opt := range opts {
86-
opt(cts)
87-
}
88-
89-
return cts
90-
}
91-
92-
type cachedTokenSource struct {
93-
// The token source to obtain tokens from.
94-
tokenSource TokenSource
95-
96-
// If true, only refresh the token with a blocking call when it is expired.
97-
disableAsync bool
98-
99-
// Duration during which a token is considered stale, see tokenState.
100-
staleDuration time.Duration
101-
102-
mu sync.Mutex
103-
cachedToken *oauth2.Token
104-
105-
// Indicates that an async refresh is in progress. This is used to prevent
106-
// multiple async refreshes from being triggered at the same time.
107-
isRefreshing bool
108-
109-
// Error returned by the last refresh. Async refreshes are disabled if this
110-
// value is not nil so that the cache does not continue sending request to
111-
// a potentially failing server. The next blocking call will re-enable async
112-
// refreshes by setting this value to nil if it succeeds, or return the
113-
// error if it fails.
114-
refreshErr error
115-
116-
timeNow func() time.Time // for testing
117-
}
118-
119-
// Token returns a token from the cache or fetches a new one if the current
120-
// token is expired.
121-
func (cts *cachedTokenSource) Token(ctx context.Context) (*oauth2.Token, error) {
122-
if cts.disableAsync {
123-
return cts.blockingToken(ctx)
124-
}
125-
return cts.asyncToken(ctx)
126-
}
127-
128-
// tokenState represents the state of the token. Each token can be in one of
129-
// the following three states:
130-
// - fresh: The token is valid.
131-
// - stale: The token is valid but will expire soon.
132-
// - expired: The token has expired and cannot be used.
133-
//
134-
// Token state through time:
135-
//
136-
// issue time expiry time
137-
// v v
138-
// | fresh | stale | expired -> time
139-
// | valid |
140-
type tokenState int
141-
142-
const (
143-
fresh tokenState = iota // The token is valid.
144-
stale // The token is valid but will expire soon.
145-
expired // The token has expired and cannot be used.
146-
)
147-
148-
// tokenState returns the state of the token. The function is not thread-safe
149-
// and should be called with the lock held.
150-
func (c *cachedTokenSource) tokenState() tokenState {
151-
if c.cachedToken == nil {
152-
return expired
153-
}
154-
switch lifeSpan := c.cachedToken.Expiry.Sub(c.timeNow()); {
155-
case lifeSpan <= 0:
156-
return expired
157-
case lifeSpan <= c.staleDuration:
158-
return stale
159-
default:
160-
return fresh
161-
}
162-
}
163-
164-
func (cts *cachedTokenSource) asyncToken(ctx context.Context) (*oauth2.Token, error) {
165-
cts.mu.Lock()
166-
ts := cts.tokenState()
167-
t := cts.cachedToken
168-
cts.mu.Unlock()
169-
170-
switch ts {
171-
case fresh:
172-
return t, nil
173-
case stale:
174-
cts.triggerAsyncRefresh(ctx)
175-
return t, nil
176-
default: // expired
177-
return cts.blockingToken(ctx)
178-
}
179-
}
180-
181-
func (cts *cachedTokenSource) blockingToken(ctx context.Context) (*oauth2.Token, error) {
182-
cts.mu.Lock()
183-
184-
// The lock is kept for the entire operation to ensure that only one
185-
// blockingToken operation is running at a time.
186-
defer cts.mu.Unlock()
187-
188-
// This is important to recover from potential previous failed attempts
189-
// to refresh the token asynchronously, see declaration of refreshErr for
190-
// more information.
191-
cts.isRefreshing = false
192-
cts.refreshErr = nil
193-
194-
// It's possible that the token got refreshed (either by a blockingToken or
195-
// an asyncRefresh call) while this particular call was waiting to acquire
196-
// the mutex. This check avoids refreshing the token again in such cases.
197-
if ts := cts.tokenState(); ts != expired { // fresh or stale
198-
return cts.cachedToken, nil
199-
}
200-
201-
t, err := cts.tokenSource.Token(ctx)
202-
if err != nil {
203-
return nil, err
204-
}
205-
cts.cachedToken = t
206-
return t, nil
207-
}
208-
209-
func (cts *cachedTokenSource) triggerAsyncRefresh(ctx context.Context) {
210-
cts.mu.Lock()
211-
defer cts.mu.Unlock()
212-
if !cts.isRefreshing && cts.refreshErr == nil {
213-
cts.isRefreshing = true
214-
215-
go func() {
216-
t, err := cts.tokenSource.Token(ctx)
217-
218-
cts.mu.Lock()
219-
defer cts.mu.Unlock()
220-
cts.isRefreshing = false
221-
if err != nil {
222-
cts.refreshErr = err
223-
return
224-
}
225-
cts.cachedToken = t
226-
}()
227-
}
53+
// OAuthCredentials is a Credentials and TokenSource that can be used to
54+
// authenticate with OAuth.
55+
type OAuthCredentials interface {
56+
Credentials
57+
TokenSource
22858
}

0 commit comments

Comments
 (0)