Skip to content

Commit b2a0199

Browse files
committed
Make refresh tokens more robust when deleting users
1 parent fa95e2c commit b2a0199

8 files changed

Lines changed: 212 additions & 80 deletions

File tree

apps/backend/src/app/api/latest/auth/oauth/callback/[provider_id]/route.tsx

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,29 @@ const handler = createSmartRouteHandler({
354354
throw new KnownErrors.SignUpNotEnabled();
355355
}
356356

357-
const currentUser = projectUserId ? await usersCrudHandlers.adminRead({ tenancy, user_id: projectUserId }) : null;
357+
// Set currentUser to the user that was signed in with the `token` access token during the /authorize request
358+
let currentUser;
359+
if (projectUserId) {
360+
// note that it's possible that the user has been deleted, but the request is still done with a token that was issued before the user was deleted
361+
// (or the user was deleted between the /authorize and /callback requests)
362+
// hence, we catch the error and ignore if that's the case
363+
try {
364+
currentUser = await usersCrudHandlers.adminRead({
365+
tenancy,
366+
user_id: projectUserId,
367+
allowedErrorTypes: [KnownErrors.UserNotFound],
368+
});
369+
} catch (error) {
370+
if (KnownErrors.UserNotFound.isInstance(error)) {
371+
currentUser = null;
372+
} else {
373+
throw error;
374+
}
375+
}
376+
} else {
377+
currentUser = null;
378+
}
379+
358380
const newAccountBeforeAuthMethod = await createOrUpgradeAnonymousUser(
359381
tenancy,
360382
currentUser,

apps/backend/src/app/api/latest/auth/oauth/token/route.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ export const POST = createSmartRouteHandler({
4141
oauthRequest,
4242
oauthResponse,
4343
{
44-
// note the `accessTokenLifetime` won't have any effect here because we set it in the `generateAccessToken` function
44+
// note the `accessTokenLifetime` won't have any effect here because we set it in the `generateAccessTokenFromRefreshTokenIfValid` function
4545
refreshTokenLifetime: 60 * 60 * 24 * 365, // 1 year
4646
alwaysIssueNewRefreshToken: false, // add token rotation later
4747
}

apps/backend/src/app/api/latest/auth/sessions/current/refresh/route.tsx

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { generateAccessToken } from "@/lib/tokens";
1+
import { generateAccessTokenFromRefreshTokenIfValid } from "@/lib/tokens";
22
import { getPrismaClientForTenancy, globalPrismaClient } from "@/prisma-client";
33
import { createSmartRouteHandler } from "@/route-handlers/smart-route-handler";
44
import { KnownErrors } from "@stackframe/stack-shared";
@@ -37,16 +37,15 @@ export const POST = createSmartRouteHandler({
3737
},
3838
});
3939

40-
if (!sessionObj || (sessionObj.expiresAt && sessionObj.expiresAt < new Date())) {
41-
throw new KnownErrors.RefreshTokenNotFoundOrExpired();
42-
}
43-
44-
const accessToken = await generateAccessToken({
40+
const accessToken = await generateAccessTokenFromRefreshTokenIfValid({
4541
tenancy,
46-
userId: sessionObj.projectUserId,
47-
refreshTokenId: sessionObj.id,
42+
refreshTokenObj: sessionObj,
4843
});
4944

45+
if (!accessToken) {
46+
throw new KnownErrors.RefreshTokenNotFoundOrExpired();
47+
}
48+
5049
return {
5150
statusCode: 200,
5251
bodyType: "json",

apps/backend/src/lib/tokens.tsx

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
import { usersCrudHandlers } from '@/app/api/latest/users/crud';
22
import { globalPrismaClient } from '@/prisma-client';
3-
import { Prisma } from '@prisma/client';
43
import { KnownErrors } from '@stackframe/stack-shared';
54
import { yupBoolean, yupNumber, yupObject, yupString } from "@stackframe/stack-shared/dist/schema-fields";
5+
import { AccessTokenPayload } from '@stackframe/stack-shared/dist/sessions';
66
import { generateSecureRandomString } from '@stackframe/stack-shared/dist/utils/crypto';
77
import { getEnvVariable } from '@stackframe/stack-shared/dist/utils/env';
8-
import { StackAssertionError, throwErr } from '@stackframe/stack-shared/dist/utils/errors';
8+
import { StackAssertionError } from '@stackframe/stack-shared/dist/utils/errors';
99
import { getPrivateJwks, getPublicJwkSet, signJWT, verifyJWT } from '@stackframe/stack-shared/dist/utils/jwt';
1010
import { Result } from '@stackframe/stack-shared/dist/utils/results';
1111
import { traceSpan } from '@stackframe/stack-shared/dist/utils/telemetry';
1212
import * as jose from 'jose';
1313
import { JOSEError, JWTExpired } from 'jose/errors';
1414
import { SystemEventTypes, logEvent } from './events';
1515
import { Tenancy } from './tenancies';
16-
import { AccessTokenPayload } from '@stackframe/stack-shared/dist/sessions';
1716

1817
export const authorizationHeaderSchema = yupString().matches(/^StackSession [^ ]+$/);
1918

@@ -116,25 +115,45 @@ export async function decodeAccessToken(accessToken: string, { allowAnonymous }:
116115
});
117116
}
118117

119-
export async function generateAccessToken(options: {
118+
export async function isRefreshTokenValid(options: {
120119
tenancy: Tenancy,
121-
userId: string,
122-
refreshTokenId: string,
120+
refreshTokenObj: null | {
121+
projectUserId: string,
122+
id: string,
123+
expiresAt: Date | null,
124+
},
123125
}) {
126+
return !!await generateAccessTokenFromRefreshTokenIfValid(options);
127+
}
128+
129+
export async function generateAccessTokenFromRefreshTokenIfValid(options: {
130+
tenancy: Tenancy,
131+
refreshTokenObj: null | {
132+
projectUserId: string,
133+
id: string,
134+
expiresAt: Date | null,
135+
},
136+
}) {
137+
if (!options.refreshTokenObj) {
138+
return null;
139+
}
140+
141+
if (options.refreshTokenObj.expiresAt && options.refreshTokenObj.expiresAt < new Date()) {
142+
return null;
143+
}
144+
124145
let user;
125146
try {
126147
user = await usersCrudHandlers.adminRead({
127148
tenancy: options.tenancy,
128-
user_id: options.userId,
149+
user_id: options.refreshTokenObj.projectUserId,
129150
allowedErrorTypes: [KnownErrors.UserNotFound],
130151
});
131152
} catch (error) {
132153
if (error instanceof KnownErrors.UserNotFound) {
133-
throw new StackAssertionError(`User not found in generateAccessToken. Was the user's account deleted?`, {
134-
userId: options.userId,
135-
refreshTokenId: options.refreshTokenId,
136-
tenancy: options.tenancy,
137-
});
154+
// The user was deleted — their refresh token still exists because we don't cascade deletes across source-of-truth/global tables.
155+
// => refresh token is invalid
156+
return null;
138157
}
139158
throw error;
140159
}
@@ -144,17 +163,17 @@ export async function generateAccessToken(options: {
144163
{
145164
projectId: options.tenancy.project.id,
146165
branchId: options.tenancy.branchId,
147-
userId: options.userId,
148-
sessionId: options.refreshTokenId,
166+
userId: options.refreshTokenObj.projectUserId,
167+
sessionId: options.refreshTokenObj.id,
149168
isAnonymous: user.is_anonymous,
150169
}
151170
);
152171

153172
const payload: Omit<AccessTokenPayload, "iss" | "aud"> = {
154-
sub: options.userId,
173+
sub: options.refreshTokenObj.projectUserId,
155174
project_id: options.tenancy.project.id,
156175
branch_id: options.tenancy.branchId,
157-
refresh_token_id: options.refreshTokenId,
176+
refresh_token_id: options.refreshTokenObj.id,
158177
role: 'authenticated',
159178
name: user.display_name,
160179
email: user.primary_email,
@@ -171,44 +190,39 @@ export async function generateAccessToken(options: {
171190
});
172191
}
173192

174-
export async function createAuthTokens(options: {
193+
type CreateRefreshTokenOptions = {
175194
tenancy: Tenancy,
176195
projectUserId: string,
177196
expiresAt?: Date,
178197
isImpersonation?: boolean,
179-
}) {
198+
}
199+
200+
export async function createRefreshTokenObj(options: CreateRefreshTokenOptions) {
180201
options.expiresAt ??= new Date(Date.now() + 1000 * 60 * 60 * 24 * 365);
181202
options.isImpersonation ??= false;
182203

183204
const refreshToken = generateSecureRandomString();
184205

185-
try {
186-
const refreshTokenObj = await globalPrismaClient.projectUserRefreshToken.create({
187-
data: {
188-
tenancyId: options.tenancy.id,
189-
projectUserId: options.projectUserId,
190-
refreshToken: refreshToken,
191-
expiresAt: options.expiresAt,
192-
isImpersonation: options.isImpersonation,
193-
},
194-
});
206+
const refreshTokenObj = await globalPrismaClient.projectUserRefreshToken.create({
207+
data: {
208+
tenancyId: options.tenancy.id,
209+
projectUserId: options.projectUserId,
210+
refreshToken: refreshToken,
211+
expiresAt: options.expiresAt,
212+
isImpersonation: options.isImpersonation,
213+
},
214+
});
195215

196-
const accessToken = await generateAccessToken({
197-
tenancy: options.tenancy,
198-
userId: options.projectUserId,
199-
refreshTokenId: refreshTokenObj.id,
200-
});
216+
return refreshTokenObj;
217+
}
201218

219+
export async function createAuthTokens(options: CreateRefreshTokenOptions) {
220+
const refreshTokenObj = await createRefreshTokenObj(options);
202221

203-
return { refreshToken, accessToken };
222+
const accessToken = await generateAccessTokenFromRefreshTokenIfValid({
223+
tenancy: options.tenancy,
224+
refreshTokenObj: refreshTokenObj,
225+
});
204226

205-
} catch (error) {
206-
if (error instanceof Prisma.PrismaClientKnownRequestError && error.code === 'P2003') {
207-
throwErr(new Error(
208-
`Auth token creation failed for tenancyId ${options.tenancy.id} and projectUserId ${options.projectUserId}: ${error.message}`,
209-
{ cause: error }
210-
));
211-
}
212-
throw error;
213-
}
227+
return { refreshToken: refreshTokenObj.refreshToken, accessToken };
214228
}

apps/backend/src/oauth/model.tsx

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@ import { createMfaRequiredError } from "@/app/api/latest/auth/mfa/sign-in/verifi
22
import { checkApiKeySet } from "@/lib/internal-api-keys";
33
import { validateRedirectUrl } from "@/lib/redirect-urls";
44
import { getSoleTenancyFromProjectBranch, getTenancy } from "@/lib/tenancies";
5-
import { decodeAccessToken, generateAccessToken } from "@/lib/tokens";
5+
import { createRefreshTokenObj, decodeAccessToken, generateAccessTokenFromRefreshTokenIfValid, isRefreshTokenValid } from "@/lib/tokens";
66
import { getPrismaClientForTenancy, globalPrismaClient } from "@/prisma-client";
77
import { AuthorizationCode, AuthorizationCodeModel, Client, Falsey, RefreshToken, Token, User } from "@node-oauth/oauth2-server";
88
import { PrismaClientKnownRequestError } from "@prisma/client/runtime/library";
99
import { KnownErrors } from "@stackframe/stack-shared";
10-
import { generateSecureRandomString } from "@stackframe/stack-shared/dist/utils/crypto";
1110
import { captureError, throwErr } from "@stackframe/stack-shared/dist/utils/errors";
1211
import { getProjectBranchFromClientId } from ".";
1312

@@ -98,46 +97,46 @@ export class OAuthModel implements AuthorizationCodeModel {
9897
assertScopeIsValid(scope);
9998
const tenancy = await getSoleTenancyFromProjectBranch(...getProjectBranchFromClientId(client.id));
10099

101-
if (!user.refreshTokenId) {
102-
// create new refresh token
103-
const refreshToken = await this.generateRefreshToken(client, user, scope);
104-
// save it in user, then we just access it in refresh
105-
// HACK: This is a hack to ensure the refresh token is already there so we can associate the access token with it
106-
const newRefreshToken = await globalPrismaClient.projectUserRefreshToken.create({
107-
data: {
108-
refreshToken,
109-
tenancyId: tenancy.id,
110-
projectUserId: user.id,
111-
expiresAt: new Date(),
112-
},
113-
});
114-
user.refreshTokenId = newRefreshToken.id;
115-
}
100+
const refreshTokenObj = await this._getOrCreateRefreshTokenObj(client, user, scope);
116101

117-
return await generateAccessToken({
102+
return await generateAccessTokenFromRefreshTokenIfValid({
118103
tenancy,
119-
userId: user.id,
120-
refreshTokenId: user.refreshTokenId ?? throwErr("Refresh token ID not found on OAuth user"),
121-
});
104+
refreshTokenObj,
105+
}) ?? throwErr("Get or create refresh token failed; returned refreshTokenObj that's invalid (or maybe it's an ultra-rare race condition and it became invalid in since the function call?)", { refreshTokenObj }); // TODO fix the ultra-rare race condition — although unless we're at gigascale this should basically never happen
122106
}
123107

124-
async generateRefreshToken(client: Client, user: User, scope: string[]): Promise<string> {
125-
assertScopeIsValid(scope);
108+
async _getOrCreateRefreshTokenObj(client: Client, user: User, scope: string[]) {
109+
const tenancy = await getSoleTenancyFromProjectBranch(...getProjectBranchFromClientId(client.id));
126110

111+
// if refresh token already exists and is valid, return it
127112
if (user.refreshTokenId) {
128-
const tenancy = await getSoleTenancyFromProjectBranch(...getProjectBranchFromClientId(client.id));
129-
const refreshToken = await globalPrismaClient.projectUserRefreshToken.findUniqueOrThrow({
113+
const refreshTokenObj = await globalPrismaClient.projectUserRefreshToken.findUnique({
130114
where: {
131115
tenancyId_id: {
132116
tenancyId: tenancy.id,
133117
id: user.refreshTokenId,
134118
},
135119
},
136120
});
137-
return refreshToken.refreshToken;
121+
if (refreshTokenObj && await isRefreshTokenValid({ tenancy, refreshTokenObj })) {
122+
return refreshTokenObj;
123+
}
138124
}
139125

140-
return generateSecureRandomString();
126+
// otherwise, create a new refresh token and set its ID on the user
127+
const refreshTokenObj = await createRefreshTokenObj({
128+
tenancy,
129+
projectUserId: user.id,
130+
});
131+
user.refreshTokenId = refreshTokenObj.id;
132+
return refreshTokenObj;
133+
}
134+
135+
async generateRefreshToken(client: Client, user: User, scope: string[]): Promise<string> {
136+
assertScopeIsValid(scope);
137+
138+
const tokenObj = await this._getOrCreateRefreshTokenObj(client, user, scope);
139+
return tokenObj.refreshToken;
141140
}
142141

143142
async saveToken(token: Token, client: Client, user: User): Promise<Token | Falsey> {

apps/e2e/tests/backend/endpoints/api/v1/auth/oauth/token.test.ts

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,66 @@ describe("with grant_type === 'authorization_code'", async () => {
9595
await Auth.expectToBeSignedIn();
9696
});
9797

98+
it("should sign in a user even when the same OAuth account has been used with a previous user that has been deleted since", async ({ expect }) => {
99+
const response = await Auth.OAuth.signIn();
100+
expect(response.tokenResponse).toMatchInlineSnapshot(`
101+
NiceResponse {
102+
"status": 200,
103+
"body": {
104+
"access_token": <stripped field 'access_token'>,
105+
"afterCallbackRedirectUrl": null,
106+
"after_callback_redirect_url": null,
107+
"expires_in": 3599,
108+
"is_new_user": true,
109+
"newUser": true,
110+
"refresh_token": <stripped field 'refresh_token'>,
111+
"scope": "legacy",
112+
"token_type": "Bearer",
113+
},
114+
"headers": Headers {
115+
"pragma": "no-cache",
116+
<some fields may have been hidden>,
117+
},
118+
}
119+
`);
120+
121+
// delete the user
122+
const deleteUserResponse = await niceBackendFetch("/api/v1/users/me", {
123+
method: "DELETE",
124+
accessType: "server",
125+
});
126+
expect(deleteUserResponse).toMatchInlineSnapshot(`
127+
NiceResponse {
128+
"status": 200,
129+
"body": { "success": true },
130+
"headers": Headers { <some fields may have been hidden> },
131+
}
132+
`);
133+
134+
// sign in again
135+
const response2 = await Auth.OAuth.signIn();
136+
expect(response2.tokenResponse).toMatchInlineSnapshot(`
137+
NiceResponse {
138+
"status": 200,
139+
"body": {
140+
"access_token": <stripped field 'access_token'>,
141+
"afterCallbackRedirectUrl": null,
142+
"after_callback_redirect_url": null,
143+
"expires_in": 3599,
144+
"is_new_user": true,
145+
"newUser": true,
146+
"refresh_token": <stripped field 'refresh_token'>,
147+
"scope": "legacy",
148+
"token_type": "Bearer",
149+
},
150+
"headers": Headers {
151+
"pragma": "no-cache",
152+
<some fields may have been hidden>,
153+
},
154+
}
155+
`);
156+
});
157+
98158
it("should fail when called with an invalid code_challenge", async ({ expect }) => {
99159
const getAuthorizationCodeResult = await Auth.OAuth.getAuthorizationCode();
100160

0 commit comments

Comments
 (0)