Skip to content

Commit 0cbaf44

Browse files
rickyromboCopilotdylanjeffersraymondjacobsongithub-actions[bot]
authored
Add support for SDK to use OAuth2.0 PKCE flow (#13804)
Adds helpers + refresh middleware and updates OAuth service to support PKCE flow. Makes API key only configs use the API-forward SDK type --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Dylan Jeffers <dylan@audius.co> Co-authored-by: Ray Jacobson <ray@audius.co> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
1 parent b0a6970 commit 0cbaf44

16 files changed

Lines changed: 1070 additions & 74 deletions

.changeset/breezy-lions-jog.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@audius/sdk": minor
3+
---
4+
5+
Add support for OAuth2.0 PKCE access/refresh tokens

packages/sdk/src/sdk/createSdkWithServices.ts

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ import { developmentConfig } from './config/development'
2929
import { productionConfig } from './config/production'
3030
import {
3131
addAppInfoMiddleware,
32-
addRequestSignatureMiddleware
32+
addRequestSignatureMiddleware,
33+
addTokenRefreshMiddleware
3334
} from './middleware'
3435
import { OAuth } from './oauth'
36+
import { OAuthTokenStore } from './oauth/tokenStore'
3537
import {
3638
PaymentRouterClient,
3739
getDefaultPaymentRouterClientConfig
@@ -133,7 +135,7 @@ export const createSdkWithServices = (config: SdkConfig) => {
133135
)
134136
}
135137

136-
// Initialize APIs
138+
// Initialize APIs (also creates tokenStore and oauth)
137139
const apis = initializeApis({
138140
config,
139141
apiKey,
@@ -142,18 +144,7 @@ export const createSdkWithServices = (config: SdkConfig) => {
142144
services
143145
})
144146

145-
// Initialize OAuth
146-
const oauth = isBrowser
147-
? new OAuth({
148-
appName,
149-
apiKey,
150-
usersApi: apis.users,
151-
logger: services.logger
152-
})
153-
: undefined
154-
155147
return {
156-
oauth,
157148
...apis
158149
}
159150
}
@@ -460,11 +451,36 @@ const initializeApis = ({
460451
})
461452
]
462453

454+
// Token store for PKCE flow — provides dynamic accessToken to Configuration
455+
const tokenStore = new OAuthTokenStore()
456+
457+
// Auto-refresh middleware — intercepts 401s and retries with a fresh token.
458+
const oauth =
459+
typeof window !== 'undefined'
460+
? new OAuth({
461+
apiKey,
462+
tokenStore,
463+
basePath
464+
})
465+
: undefined
466+
467+
if (apiKey && oauth) {
468+
middleware.push(
469+
addTokenRefreshMiddleware({
470+
oauth
471+
})
472+
)
473+
}
474+
475+
const bearerToken = 'bearerToken' in config ? config.bearerToken : undefined
476+
463477
const apiClientConfig = new Configuration({
464478
fetchApi: fetch,
465479
middleware,
466480
basePath,
467-
accessToken: 'bearerToken' in config ? config.bearerToken : undefined
481+
// Static bearerToken takes precedence; otherwise use the dynamic store
482+
// so PKCE login can inject tokens after construction.
483+
accessToken: bearerToken ?? tokenStore.asAccessTokenProvider()
468484
})
469485

470486
const tracks = new TracksApi(apiClientConfig, services)
@@ -507,6 +523,8 @@ const initializeApis = ({
507523
const search = new SearchApi(apiClientConfig)
508524

509525
return {
526+
oauth,
527+
tokenStore,
510528
tracks,
511529
users,
512530
albums,

packages/sdk/src/sdk/createSdkWithoutServices.ts

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ import { developmentConfig } from './config/development'
2626
import { productionConfig } from './config/production'
2727
import {
2828
addAppInfoMiddleware,
29-
addRequestSignatureMiddleware
29+
addRequestSignatureMiddleware,
30+
addTokenRefreshMiddleware
3031
} from './middleware'
3132
import { OAuth } from './oauth'
33+
import { OAuthTokenStore } from './oauth/tokenStore'
3234
import { Logger, Storage, StorageNodeSelector } from './services'
3335
import { type SdkConfig } from './types'
3436

@@ -54,6 +56,19 @@ export const createSdkWithoutServices = (config: SdkConfig) => {
5456

5557
const middleware: Middleware[] = []
5658

59+
// Token store for PKCE flow — provides dynamic accessToken to Configuration
60+
const tokenStore = new OAuthTokenStore()
61+
62+
// Initialize OAuth early so it can be passed to middleware
63+
const oauth =
64+
typeof window !== 'undefined'
65+
? new OAuth({
66+
apiKey,
67+
tokenStore,
68+
basePath
69+
})
70+
: undefined
71+
5772
if (apiSecret || services?.audiusWalletClient) {
5873
middleware.push(
5974
addRequestSignatureMiddleware({
@@ -81,25 +96,30 @@ export const createSdkWithoutServices = (config: SdkConfig) => {
8196
)
8297
}
8398

99+
// Auto-refresh middleware — intercepts 401s and retries with a fresh token.
100+
if (apiKey && oauth) {
101+
middleware.push(
102+
addTokenRefreshMiddleware({
103+
oauth
104+
})
105+
)
106+
}
107+
84108
const apiConfig = new Configuration({
85109
fetchApi: fetch,
86110
middleware,
87111
basePath,
88-
accessToken: bearerToken
112+
// Static bearerToken takes precedence; otherwise use the dynamic store
113+
// so PKCE login can inject tokens after construction.
114+
accessToken: bearerToken ?? tokenStore.asAccessTokenProvider()
89115
})
90116

91-
// Initialize OAuth
117+
// Initialize API clients
92118
const usersApi = new UsersApi(apiConfig)
93-
const oauth =
94-
typeof window !== 'undefined'
95-
? new OAuth({
96-
apiKey,
97-
usersApi
98-
})
99-
: undefined
100119

101120
return {
102121
oauth,
122+
tokenStore,
103123
tracks: new TracksApi(apiConfig),
104124
users: usersApi,
105125
// albums
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import { describe, it, expect, vi, beforeEach } from 'vitest'
2+
3+
import type { OAuth } from '../oauth/OAuth'
4+
5+
import { addTokenRefreshMiddleware } from './addTokenRefreshMiddleware'
6+
7+
// Minimal fetch mock helper
8+
function mockResponse(status: number, body?: object): Response {
9+
return new Response(body ? JSON.stringify(body) : null, {
10+
status,
11+
headers: { 'Content-Type': 'application/json' }
12+
})
13+
}
14+
15+
function createMockOAuth(
16+
refreshBehaviour: () => Promise<string | null>,
17+
hasRefreshToken = true
18+
): OAuth {
19+
return {
20+
hasRefreshToken,
21+
refreshAccessToken: refreshBehaviour
22+
} as unknown as OAuth
23+
}
24+
25+
describe('addTokenRefreshMiddleware', () => {
26+
beforeEach(() => {
27+
vi.restoreAllMocks()
28+
})
29+
30+
it('passes through non-401 responses unchanged', async () => {
31+
const oauth = createMockOAuth(async () => 'token')
32+
const mw = addTokenRefreshMiddleware({ oauth })
33+
const response = mockResponse(200, { data: 'ok' })
34+
35+
const result = await mw.post!({
36+
fetch,
37+
url: 'https://api.example.com/v1/users/me',
38+
init: {},
39+
response
40+
})
41+
42+
expect(result).toBe(response)
43+
})
44+
45+
it('passes through 401 without calling refreshAccessToken when unauthenticated', async () => {
46+
const refreshFn = vi.fn()
47+
const oauth = createMockOAuth(refreshFn, false)
48+
const mw = addTokenRefreshMiddleware({ oauth })
49+
const response = mockResponse(401)
50+
51+
const result = await mw.post!({
52+
fetch,
53+
url: 'https://api.example.com/v1/users/me',
54+
init: {},
55+
response
56+
})
57+
58+
expect(result).toBe(response)
59+
expect(refreshFn).not.toHaveBeenCalled()
60+
})
61+
62+
it('passes through 401 when refresh returns null (no refresh token)', async () => {
63+
const oauth = createMockOAuth(async () => null)
64+
const mw = addTokenRefreshMiddleware({ oauth })
65+
const response = mockResponse(401)
66+
67+
const result = await mw.post!({
68+
fetch,
69+
url: 'https://api.example.com/v1/users/me',
70+
init: {},
71+
response
72+
})
73+
74+
expect(result).toBe(response)
75+
})
76+
77+
it('refreshes and retries on 401 when refresh succeeds', async () => {
78+
const oauth = createMockOAuth(async () => 'new-access')
79+
const retryResponse = mockResponse(200, { data: 'success' })
80+
const contextFetch = vi.fn().mockResolvedValueOnce(retryResponse)
81+
82+
const mw = addTokenRefreshMiddleware({ oauth })
83+
84+
const result = await mw.post!({
85+
fetch: contextFetch,
86+
url: 'https://api.example.com/v1/tracks/123',
87+
init: {
88+
method: 'GET',
89+
headers: { Authorization: 'Bearer expired-access' }
90+
},
91+
response: mockResponse(401)
92+
})
93+
94+
// Original request was retried with new token
95+
expect(contextFetch).toHaveBeenCalledWith(
96+
'https://api.example.com/v1/tracks/123',
97+
expect.objectContaining({
98+
headers: expect.objectContaining({
99+
Authorization: 'Bearer new-access'
100+
})
101+
})
102+
)
103+
104+
expect(result).toBe(retryResponse)
105+
})
106+
107+
it('surfaces 401 when refresh fails', async () => {
108+
const oauth = createMockOAuth(async () => null)
109+
const mw = addTokenRefreshMiddleware({ oauth })
110+
const original401 = mockResponse(401)
111+
112+
const result = await mw.post!({
113+
fetch,
114+
url: 'https://api.example.com/v1/tracks/123',
115+
init: {},
116+
response: original401
117+
})
118+
119+
expect(result).toBe(original401)
120+
})
121+
122+
it('surfaces 401 when refreshAccessToken throws', async () => {
123+
const oauth = createMockOAuth(async () => {
124+
throw new Error('network failure')
125+
})
126+
const mw = addTokenRefreshMiddleware({ oauth })
127+
const original401 = mockResponse(401)
128+
129+
const result = await mw.post!({
130+
fetch,
131+
url: 'https://api.example.com/v1/tracks/123',
132+
init: {},
133+
response: original401
134+
})
135+
136+
expect(result).toBe(original401)
137+
})
138+
})
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import type { Middleware, ResponseContext } from '../api/generated/default'
2+
import type { OAuth } from '../oauth/OAuth'
3+
import fetch from '../utils/fetch'
4+
5+
/**
6+
* Middleware that transparently refreshes an expired access token on 401.
7+
*
8+
* When a response comes back with HTTP 401 the middleware delegates to
9+
* `OAuth.refreshAccessToken()` which checks for a refresh token, performs the
10+
* HTTP exchange, and updates the token store. On success the original request
11+
* is retried with the fresh access token. On failure the 401 propagates.
12+
*
13+
* When the client is unauthenticated (no refresh token stored) the middleware
14+
* short-circuits immediately, avoiding noisy error callbacks.
15+
*/
16+
export const addTokenRefreshMiddleware = ({
17+
oauth
18+
}: {
19+
oauth: OAuth
20+
}): Middleware => {
21+
let refreshInFlight: Promise<string | null> | null = null
22+
23+
return {
24+
post: async (context: ResponseContext): Promise<Response | void> => {
25+
if (context.response.status !== 401) {
26+
return context.response
27+
}
28+
29+
// Skip refresh when unauthenticated to avoid noisy error callbacks.
30+
if (!oauth.hasRefreshToken) {
31+
return context.response
32+
}
33+
34+
// Coalesce concurrent 401s into a single refresh call.
35+
if (!refreshInFlight) {
36+
refreshInFlight = oauth
37+
.refreshAccessToken()
38+
.catch(() => null)
39+
.finally(() => {
40+
refreshInFlight = null
41+
})
42+
}
43+
44+
const newAccessToken = await refreshInFlight
45+
if (!newAccessToken) {
46+
return context.response
47+
}
48+
49+
// Retry the original request with the new access token.
50+
const retryInit: RequestInit = {
51+
...context.init,
52+
headers: {
53+
...((context.init.headers as Record<string, string>) ?? {}),
54+
Authorization: `Bearer ${newAccessToken}`
55+
}
56+
}
57+
return fetch(context.url, retryInit)
58+
}
59+
}
60+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
export { addAppInfoMiddleware } from './addAppInfoMiddleware'
22
export { addRequestSignatureMiddleware } from './addRequestSignatureMiddleware'
3+
export { addTokenRefreshMiddleware } from './addTokenRefreshMiddleware'

0 commit comments

Comments
 (0)