Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/quiet-auth-stability.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@tailor-platform/app-shell": patch
---

Fix OAuth callback handling so auth redirects do not re-run unnecessarily when AppShell re-renders.

Auth initialization now also starts from `AuthProvider`, which avoids unresolved auth state when consumers are mounted outside the router-driven AppShell flow.
234 changes: 229 additions & 5 deletions packages/core/src/contexts/auth-context.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@ vi.mock("@tailor-platform/auth-public-client", () => ({
createAuthClient: vi.fn(),
}));

import { AuthProvider, useAuth, useAuthSuspense, type EnhancedAuthClient } from "./auth-context";
import {
AuthProvider,
useAuth,
useEnsureAuthInitialized,
useAuthSuspense,
type EnhancedAuthClient,
} from "./auth-context";
import { useRootRouteContext } from "./root-route-context";

afterEach(() => {
cleanup();
vi.clearAllMocks();
vi.unstubAllGlobals();
window.history.replaceState({}, "", "/");
});

const LoadingGuard = () => <div>Loading...</div>;
Expand All @@ -38,12 +45,27 @@ describe("AuthProvider", () => {
isReady: false,
};

const baseHandleCallback = overrides?.handleCallback ?? vi.fn();
let handleCallbackInFlight: Promise<void> | null = null;
const handleCallback = vi.fn(() => {
if (handleCallbackInFlight) {
return handleCallbackInFlight;
}

const callbackPromise = Promise.resolve(baseHandleCallback()).finally(() => {
handleCallbackInFlight = null;
});
handleCallbackInFlight = callbackPromise;
return callbackPromise;
});

const { handleCallback: _ignoredHandleCallback, ...otherOverrides } = overrides ?? {};

return {
getState: vi.fn(() => state),
login: vi.fn(),
logout: vi.fn(),
getAuthUrl: vi.fn(),
handleCallback: vi.fn(),
checkAuthStatus: vi.fn().mockResolvedValue({
isAuthenticated: false,
error: null,
Expand All @@ -56,10 +78,110 @@ describe("AuthProvider", () => {
getAuthHeaders: vi.fn(),
fetch: vi.fn(),
getAppUri: vi.fn(() => "https://api.test.com"),
...overrides,
...otherOverrides,
handleCallback,
} as EnhancedAuthClient;
};

describe("useEnsureAuthInitialized", () => {
it("should initialize auth status on mount", async () => {
const state = {
isAuthenticated: false,
error: null,
isReady: false,
};
const mockCheckAuthStatus = vi.fn().mockResolvedValue({
isAuthenticated: true,
error: null,
isReady: true,
});

const mockClient = createMockAuthClient(state, {
checkAuthStatus: mockCheckAuthStatus,
});

const { result } = renderHook(() => useEnsureAuthInitialized(mockClient));

await act(async () => {
await result.current();
});

expect(mockCheckAuthStatus).toHaveBeenCalledTimes(1);
});

it("should coalesce overlapping auth initialization checks", async () => {
const state = {
isAuthenticated: false,
error: null,
isReady: false,
};

let resolveCheckAuthStatus: (() => void) | undefined;
const mockCheckAuthStatus = vi.fn(
() =>
new Promise<{
isAuthenticated: boolean;
error: null;
isReady: true;
}>((resolve) => {
resolveCheckAuthStatus = () =>
resolve({
isAuthenticated: true,
error: null,
isReady: true,
});
}),
);

const mockClient = createMockAuthClient(state, {
checkAuthStatus: mockCheckAuthStatus,
});

const { result } = renderHook(() => useEnsureAuthInitialized(mockClient));

const mountRetry = result.current();

await waitFor(() => {
expect(mockCheckAuthStatus).toHaveBeenCalledTimes(1);
});

const firstRetry = result.current();
const secondRetry = result.current();

expect(mockCheckAuthStatus).toHaveBeenCalledTimes(1);

resolveCheckAuthStatus?.();
await Promise.all([mountRetry, firstRetry, secondRetry]);
});

it("should skip auth initialization while handling an OAuth callback", async () => {
window.history.replaceState({}, "", "/?code=auth-code-123&state=abc");

const state = {
isAuthenticated: false,
error: null,
isReady: false,
};
const mockCheckAuthStatus = vi.fn().mockResolvedValue({
isAuthenticated: false,
error: null,
isReady: true,
});

const mockClient = createMockAuthClient(state, {
checkAuthStatus: mockCheckAuthStatus,
});

const { result } = renderHook(() => useEnsureAuthInitialized(mockClient));

await act(async () => {
await result.current();
});

expect(mockCheckAuthStatus).not.toHaveBeenCalled();
});
});

describe("initial state", () => {
it("should render children when not using guard component", () => {
const mockClient = createMockAuthClient();
Expand Down Expand Up @@ -137,7 +259,7 @@ describe("AuthProvider", () => {
});

describe("authentication flow", () => {
it("should check auth status via useRootRouteContext", async () => {
it("should not check auth status via useRootRouteContext on non-callback URLs", async () => {
const state = {
isAuthenticated: false,
error: null,
Expand All @@ -158,8 +280,12 @@ describe("AuthProvider", () => {
});

expect(result.current).not.toBeNull();
await waitFor(() => {
expect(mockCheckAuthStatus).toHaveBeenCalledTimes(1);
});

const response = await result.current!.loader(new URL("http://localhost/"));
expect(mockCheckAuthStatus).toHaveBeenCalled();
expect(mockCheckAuthStatus).toHaveBeenCalledTimes(1);
expect(response).toBeNull();
});

Expand Down Expand Up @@ -187,6 +313,51 @@ describe("AuthProvider", () => {
expect(response).toBeNull();
});

it("should deduplicate concurrent OAuth callback handling", async () => {
const state = {
isAuthenticated: true,
error: null,
isReady: true,
};

let resolveFirstCallback: (() => void) | undefined;
const mockHandleCallback = vi.fn(() => {
if (mockHandleCallback.mock.calls.length === 0) {
return Promise.resolve();
}

if (mockHandleCallback.mock.calls.length === 1) {
return new Promise<void>((resolve) => {
resolveFirstCallback = resolve;
});
}

return Promise.resolve();
});

const mockClient = createMockAuthClient(state, {
handleCallback: mockHandleCallback,
});

const { result } = renderHook(() => useRootRouteContext(), {
wrapper: ({ children }) => <AuthProvider client={mockClient}>{children}</AuthProvider>,
});

expect(result.current).not.toBeNull();

const requestUrl = new URL("http://localhost/?code=auth-code-123&state=abc");
const firstLoad = result.current!.loader(requestUrl);
const secondLoad = result.current!.loader(requestUrl);

expect(mockHandleCallback).toHaveBeenCalledTimes(1);

resolveFirstCallback?.();
await Promise.all([firstLoad, secondLoad]);

await result.current!.loader(requestUrl);
expect(mockHandleCallback).toHaveBeenCalledTimes(2);
});

it("should be authenticated when logged in", async () => {
const state = {
isAuthenticated: true,
Expand Down Expand Up @@ -597,6 +768,59 @@ describe("AuthProvider", () => {
{ timeout: 1000 },
);
});

it("should not login while the current URL is an OAuth callback", async () => {
window.history.replaceState({}, "", "/?code=auth-code-123&state=abc");

let authEventListener: ((event: { type: string; data?: unknown }) => void) | undefined;

const mockAddEventListener = vi.fn(
(listener: (event: { type: string; data?: unknown }) => void) => {
authEventListener = listener;
return () => {};
},
);

let currentState = {
isAuthenticated: false,
error: null as string | null,
isReady: true,
};

const mockLogin = vi.fn().mockResolvedValue(undefined);
const mockClient = createMockAuthClient(undefined, {
login: mockLogin,
addEventListener: mockAddEventListener,
getState: vi.fn(() => currentState),
});

render(
<AuthProvider client={mockClient} autoLogin={true}>
<div>Content</div>
</AuthProvider>,
);

await act(async () => {
await Promise.resolve();
});

expect(mockLogin).not.toHaveBeenCalled();

act(() => {
currentState = {
isAuthenticated: false,
error: null,
isReady: true,
};
authEventListener?.({ type: "auth_state_changed", data: {} });
});

await act(async () => {
await Promise.resolve();
});

expect(mockLogin).not.toHaveBeenCalled();
});
});

describe("event listeners", () => {
Expand Down
Loading
Loading