-
Notifications
You must be signed in to change notification settings - Fork 258
Expand file tree
/
Copy pathtokenRefresh.ts
More file actions
316 lines (278 loc) · 12.4 KB
/
tokenRefresh.ts
File metadata and controls
316 lines (278 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import { Account, PrismaClient } from '@sourcebot/db';
import {
BitbucketCloudIdentityProviderConfig,
BitbucketServerIdentityProviderConfig,
GitHubIdentityProviderConfig,
GitLabIdentityProviderConfig,
} from '@sourcebot/schemas/v3/index.type';
import {
createLogger,
decryptOAuthToken,
encryptOAuthToken,
env,
getTokenFromConfig,
IdentityProviderType,
loadConfig,
} from '@sourcebot/shared';
import { z } from 'zod';
// Logger for this module, created under the 'backend-ee-token-refresh' name.
const logger = createLogger('backend-ee-token-refresh');
/**
 * Identity providers whose OAuth access tokens this module knows how to refresh.
 * `as const` keeps the entries as literal types; `satisfies` checks each one is a
 * valid IdentityProviderType without widening the tuple.
 */
const SUPPORTED_PROVIDERS = [
    'github',
    'gitlab',
    'bitbucket-cloud',
    'bitbucket-server',
] as const satisfies IdentityProviderType[];

/** Union of the provider ids listed in SUPPORTED_PROVIDERS. */
type SupportedProvider = (typeof SUPPORTED_PROVIDERS)[number];

/**
 * Type guard: true when `provider` is one of SUPPORTED_PROVIDERS.
 */
const isSupportedProvider = (provider: string): provider is SupportedProvider => {
    // View the literal tuple as plain strings so any string can be looked up.
    const refreshableProviders: readonly string[] = SUPPORTED_PROVIDERS;
    return refreshableProviders.includes(provider);
};
/**
 * Runtime schema for a successful OAuth 2.0 token endpoint response.
 * Only `access_token` is required; all other fields are optional.
 * @see: https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
 */
const OAuthTokenResponseSchema = z.object({
    access_token: z.string(),
    token_type: z.string().optional(),
    expires_in: z.number().optional(),
    refresh_token: z.string().optional(),
    scope: z.string().optional(),
});

/** Parsed token endpoint response (inferred from OAuthTokenResponseSchema). */
type OAuthTokenResponse = z.infer<typeof OAuthTokenResponseSchema>;

/**
 * OAuth client credentials used to call a provider's token endpoint, plus an
 * optional base URL for non-default (e.g. self-hosted) instances.
 */
interface ProviderCredentials {
    clientId: string;
    clientSecret: string;
    baseUrl?: string;
}

/** Refresh a token when it is within this many seconds of expiring (5 minutes). */
const EXPIRY_BUFFER_S = 300;
/**
 * Decrypts a stored OAuth access token.
 *
 * @param accountId - Account id, used only in the error message.
 * @param encryptedAccessToken - The encrypted token as stored on the account.
 * @returns The plaintext access token.
 * @throws Error when decryption yields an empty/missing value.
 */
const decryptAccessTokenOrThrow = (accountId: string, encryptedAccessToken: string): string => {
    const token = decryptOAuthToken(encryptedAccessToken);
    if (!token) {
        throw new Error(`Failed to decrypt access token for account ${accountId}.`);
    }
    return token;
};

/**
 * Logs a token-refresh failure, records it as the account's
 * tokenRefreshErrorMessage, and returns an Error carrying the same message for
 * the caller to throw.
 *
 * Recording is best effort: if the DB write itself fails, that failure is logged
 * and the original message is still returned, so a DB problem cannot mask the
 * reason the refresh failed.
 */
const recordTokenRefreshFailure = async (
    accountId: string,
    message: string,
    db: PrismaClient,
): Promise<Error> => {
    logger.error(message);
    try {
        await setTokenRefreshError(accountId, message, db);
    } catch (recordError) {
        logger.error(`Failed to record token refresh error for account ${accountId}:`, recordError);
    }
    return new Error(message);
};

/**
 * Ensures the OAuth access token for a given account is fresh.
 *
 * - Providers not in SUPPORTED_PROVIDERS: the stored token is decrypted and
 *   returned as-is.
 * - Tokens with no expiry (`expires_at` null or <= 0), or more than
 *   EXPIRY_BUFFER_S seconds away from expiry: decrypted and returned as-is.
 * - Otherwise a refresh is attempted using the OAuth client credentials from the
 *   config file (or deprecated env vars). On success the new tokens are
 *   persisted, tokenRefreshErrorMessage is cleared, and the fresh access token
 *   is returned.
 * - On refresh failure, tokenRefreshErrorMessage is set on the account (best
 *   effort) and an Error is thrown, so the calling job fails with a clear error.
 *
 * @param account - Account whose token should be returned.
 * @param db - Prisma client used to persist refreshed tokens / error messages.
 * @returns The plaintext (decrypted) access token.
 * @throws Error when the account has no access token, a stored token cannot be
 *   decrypted, or a needed refresh cannot be performed.
 */
export const ensureFreshAccountToken = async (
    account: Account,
    db: PrismaClient,
): Promise<string> => {
    if (!account.access_token) {
        throw new Error(`Account ${account.id} (${account.provider}) has no access token.`);
    }

    if (!isSupportedProvider(account.provider)) {
        // Non-refreshable provider — just decrypt and return whatever is stored.
        return decryptAccessTokenOrThrow(account.id, account.access_token);
    }

    // expires_at is in epoch seconds; null or <= 0 is treated as "never expires".
    const now = Math.floor(Date.now() / 1000);
    const isExpiredOrNearExpiry =
        account.expires_at !== null &&
        account.expires_at > 0 &&
        now >= account.expires_at - EXPIRY_BUFFER_S;

    if (!isExpiredOrNearExpiry) {
        return decryptAccessTokenOrThrow(account.id, account.access_token);
    }

    if (!account.refresh_token) {
        throw await recordTokenRefreshFailure(
            account.id,
            `Account ${account.id} (${account.provider}) token is expired and has no refresh token.`,
            db,
        );
    }

    const refreshToken = decryptOAuthToken(account.refresh_token);
    if (!refreshToken) {
        throw await recordTokenRefreshFailure(
            account.id,
            `Failed to decrypt refresh token for account ${account.id} (${account.provider}).`,
            db,
        );
    }

    logger.debug(`Refreshing OAuth token for account ${account.id} (${account.provider})...`);
    const refreshResponse = await refreshOAuthToken(account.provider, refreshToken);
    if (!refreshResponse) {
        throw await recordTokenRefreshFailure(
            account.id,
            `OAuth token refresh failed for account ${account.id} (${account.provider}).`,
            db,
        );
    }

    // No (or zero) expires_in: store null, which the expiry check above treats as non-expiring.
    const newExpiresAt = refreshResponse.expires_in
        ? Math.floor(Date.now() / 1000) + refreshResponse.expires_in
        : null;

    await db.account.update({
        where: { id: account.id },
        data: {
            access_token: encryptOAuthToken(refreshResponse.access_token),
            // Only update refresh_token if a new one was provided; preserve the
            // existing one otherwise (some providers use rotating refresh tokens,
            // others reuse the same one).
            ...(refreshResponse.refresh_token !== undefined
                ? { refresh_token: encryptOAuthToken(refreshResponse.refresh_token) }
                : {}),
            expires_at: newExpiresAt,
            tokenRefreshErrorMessage: null,
        },
    });

    logger.debug(`Successfully refreshed OAuth token for account ${account.id} (${account.provider}).`);
    return refreshResponse.access_token;
};
/**
 * Persists `message` as the account's tokenRefreshErrorMessage.
 * Errors from the DB update propagate to the caller.
 */
const setTokenRefreshError = async (accountId: string, message: string, db: PrismaClient) => {
    const where = { id: accountId };
    const data = { tokenRefreshErrorMessage: message };
    await db.account.update({ where, data });
};
/** Identity provider config entries whose OAuth client credentials can refresh a linked account's token. */
type LinkedAccountProviderConfig =
    | GitHubIdentityProviderConfig
    | GitLabIdentityProviderConfig
    | BitbucketCloudIdentityProviderConfig
    | BitbucketServerIdentityProviderConfig;

/**
 * Resolves the client id, client secret and optional base URL declared by a
 * single identity provider config entry.
 */
const resolveProviderConfigCredentials = async (
    providerConfig: LinkedAccountProviderConfig,
): Promise<ProviderCredentials> => {
    const clientId = await getTokenFromConfig(providerConfig.clientId);
    const clientSecret = await getTokenFromConfig(providerConfig.clientSecret);
    const baseUrl = 'baseUrl' in providerConfig ? providerConfig.baseUrl : undefined;
    return { clientId, clientSecret, baseUrl };
};

/**
 * Refreshes a `provider` token with the deprecated AUTH_EE_* env credentials.
 * Used only when the config file declares no identity provider of that type.
 * Returns null (after logging) when no env credentials exist or the refresh fails.
 */
const refreshWithDeprecatedEnvCredentials = async (
    provider: SupportedProvider,
    refreshToken: string,
): Promise<OAuthTokenResponse | null> => {
    const envCredentials = getDeprecatedEnvCredentials(provider);
    if (!envCredentials) {
        logger.error(`No provider config or env credentials found for: ${provider}`);
        return null;
    }

    logger.debug(`Using deprecated env vars for ${provider} token refresh`);
    const refreshed = await tryRefreshToken(provider, refreshToken, envCredentials);
    if (!refreshed) {
        logger.error(`Failed to refresh ${provider} token using deprecated env credentials`);
    }
    return refreshed;
};

/**
 * Exchanges `refreshToken` for a new token set from `provider`.
 *
 * Every identity provider config of this type in the config file is tried in
 * order until one succeeds; if the config file has none, the deprecated env
 * credentials are used instead. Failures are logged and reported as `null`
 * rather than thrown.
 */
const refreshOAuthToken = async (
    provider: SupportedProvider,
    refreshToken: string,
): Promise<OAuthTokenResponse | null> => {
    try {
        const config = await loadConfig(env.CONFIG_PATH);
        const matchingConfigs = (config?.identityProviders ?? []).filter(
            (idp) => idp.provider === provider,
        );

        if (matchingConfigs.length === 0) {
            // `return await` keeps rejections inside this try/catch.
            return await refreshWithDeprecatedEnvCredentials(provider, refreshToken);
        }

        // Why every matching config is tried: there can be several providers of the
        // same type (e.g. multiple GitLab instances), and an Account row does not
        // record which config it came from. The client id/secret live only in the
        // config (deliberately not copied into the Account table), and only the
        // credentials that issued this refresh token will be accepted.
        // NOTE(review): this sends the refresh token to the token endpoint of every
        // matching config, including other hosts — confirm that is acceptable.
        for (const candidate of matchingConfigs) {
            try {
                const credentials = await resolveProviderConfigCredentials(
                    candidate as LinkedAccountProviderConfig,
                );
                const refreshed = await tryRefreshToken(provider, refreshToken, credentials);
                if (refreshed) {
                    return refreshed;
                }
            } catch (configError) {
                // One broken config entry must not stop the remaining ones from being tried.
                logger.debug(`Error trying provider config for ${provider}:`, configError);
            }
        }

        logger.error(`All provider configs failed for: ${provider}`);
        return null;
    } catch (e) {
        logger.error(`Error refreshing ${provider} token:`, e);
        return null;
    }
};
/**
 * Resolves the OAuth token endpoint for `provider`.
 *
 * With a `baseUrl`, the endpoint path is resolved relative to it, preserving any
 * context path (e.g. https://example.com/bitbucket/). Without one, the public
 * default endpoint is used.
 *
 * @returns The endpoint URL, or null when there is no baseUrl and no default
 *   endpoint is known for the provider (bitbucket-server has no default here).
 */
const resolveTokenEndpoint = (provider: SupportedProvider, baseUrl: string | undefined): string | null => {
    if (baseUrl) {
        // Trailing slash so relative paths append instead of replacing the last segment.
        const base = baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`;
        let path: string;
        if (provider === 'github') {
            path = 'login/oauth/access_token';
        } else if (provider === 'bitbucket-server') {
            path = 'rest/oauth2/latest/token';
        } else {
            // NOTE(review): bitbucket-cloud with a baseUrl also lands here ('oauth/token'),
            // unlike its default endpoint below ('site/oauth2/access_token') — confirm intended.
            path = 'oauth/token';
        }
        return new URL(path, base).toString();
    }

    switch (provider) {
        case 'github':
            return 'https://github.com/login/oauth/access_token';
        case 'gitlab':
            return 'https://gitlab.com/oauth/token';
        case 'bitbucket-cloud':
            return 'https://bitbucket.org/site/oauth2/access_token';
        default:
            return null;
    }
};

/**
 * Extracts an RFC 6749 §5.2 error (`error` plus optional `error_description`)
 * from a token endpoint response body.
 *
 * @returns A printable description, or null when the body carries no string `error`.
 */
const describeOAuthError = (body: unknown): string | null => {
    if (typeof body !== 'object' || body === null) {
        return null;
    }
    const { error, error_description: description } = body as Record<string, unknown>;
    if (typeof error !== 'string') {
        return null;
    }
    return typeof description === 'string' ? `${error}: ${description}` : error;
};

/**
 * Performs a single OAuth refresh-token grant against `provider` with the given
 * client credentials.
 *
 * @param provider - Provider type; selects the endpoint and client-auth style.
 * @param refreshToken - Plaintext refresh token.
 * @param credentials - Client id/secret and optional base URL for the instance.
 * @returns The parsed token response, or null (after logging) when the endpoint
 *   cannot be determined, the request is rejected, the body is not JSON, or the
 *   body does not match the expected token response.
 * @throws Whatever `fetch` or `new URL` throw (e.g. network errors, malformed baseUrl).
 */
const tryRefreshToken = async (
    provider: SupportedProvider,
    refreshToken: string,
    credentials: ProviderCredentials,
): Promise<OAuthTokenResponse | null> => {
    const { clientId, clientSecret, baseUrl } = credentials;

    const url = resolveTokenEndpoint(provider, baseUrl);
    if (!url) {
        logger.error(`Cannot refresh ${provider} token: no baseUrl is configured and there is no default token endpoint for this provider.`);
        return null;
    }

    // Bitbucket requires client credentials via HTTP Basic Auth rather than request body params.
    // @see: https://support.atlassian.com/bitbucket-cloud/docs/use-oauth-on-bitbucket-cloud/
    const useBasicAuth = provider === 'bitbucket-cloud';

    const bodyParams: Record<string, string> = {
        // @see: https://datatracker.ietf.org/doc/html/rfc6749#section-6 (refresh token grant)
        grant_type: 'refresh_token',
        refresh_token: refreshToken,
    };
    if (!useBasicAuth) {
        // @see: https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 (client authentication)
        bodyParams.client_id = clientId;
        bodyParams.client_secret = clientSecret;
    }

    // GitLab requires redirect_uri to match the original authorization request
    // even when refreshing tokens. Use URL constructor to handle trailing slashes.
    if (provider === 'gitlab') {
        bodyParams.redirect_uri = new URL('/api/auth/callback/gitlab', env.AUTH_URL).toString();
    }

    const response = await fetch(url, {
        method: 'POST',
        headers: {
            'Content-Type': 'application/x-www-form-urlencoded',
            'Accept': 'application/json',
            ...(useBasicAuth ? {
                Authorization: `Basic ${Buffer.from(`${clientId}:${clientSecret}`).toString('base64')}`,
            } : {}),
        },
        body: new URLSearchParams(bodyParams),
    });

    if (!response.ok) {
        const errorText = await response.text();
        logger.error(`Failed to refresh ${provider} token: ${response.status} ${errorText}`);
        return null;
    }

    // Read as text first so a non-JSON body is reported instead of throwing a SyntaxError.
    const rawBody = await response.text();
    let json: unknown;
    try {
        json = JSON.parse(rawBody);
    } catch {
        logger.error(`Non-JSON OAuth token response from ${provider} (status ${response.status}): ${rawBody}`);
        return null;
    }

    const result = OAuthTokenResponseSchema.safeParse(json);
    if (!result.success) {
        // Some token endpoints (GitHub, for example) report refresh failures with a
        // 2xx status and an RFC 6749 §5.2 error body; surface that error when present.
        const oauthError = describeOAuthError(json);
        logger.error(
            oauthError
                ? `Failed to refresh ${provider} token: ${oauthError}`
                : `Invalid OAuth token response from ${provider}:\n${result.error.message}`,
        );
        return null;
    }

    return result.data;
};
/**
 * Reads OAuth client credentials from the deprecated AUTH_EE_* environment variables.
 *
 * Kept for backwards compatibility with deployments that configure OAuth through
 * env vars instead of the config file. Only GitHub and GitLab have such variables,
 * and both the client id and the client secret must be non-empty.
 *
 * @param provider - Identity provider type, e.g. 'github'.
 * @returns The credentials (baseUrl may be undefined), or null when unavailable.
 */
const getDeprecatedEnvCredentials = (provider: string): ProviderCredentials | null => {
    switch (provider) {
        case 'github': {
            const {
                AUTH_EE_GITHUB_CLIENT_ID: clientId,
                AUTH_EE_GITHUB_CLIENT_SECRET: clientSecret,
            } = env;
            return clientId && clientSecret
                ? { clientId, clientSecret, baseUrl: env.AUTH_EE_GITHUB_BASE_URL }
                : null;
        }
        case 'gitlab': {
            const {
                AUTH_EE_GITLAB_CLIENT_ID: clientId,
                AUTH_EE_GITLAB_CLIENT_SECRET: clientSecret,
            } = env;
            return clientId && clientSecret
                ? { clientId, clientSecret, baseUrl: env.AUTH_EE_GITLAB_BASE_URL }
                : null;
        }
        default:
            return null;
    }
};