Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,44 @@ jest.mock('@aws-amplify/core', () => ({
}));
jest.mock('../../../../src/providers/cognito/utils/oauth/oAuthStore');

// Helper function for creating test tokens with configurable expiration
function createTokens({
accessTokenExpired = false,
idTokenExpired = false,
overrides = {},
}: {
accessTokenExpired?: boolean;
idTokenExpired?: boolean;
overrides?: Partial<CognitoAuthTokens>;
} = {}): CognitoAuthTokens {
const now = Math.floor(Date.now() / 1000);
const pastExp = now - 3600; // 1 hour ago
const futureExp = now + 3600; // 1 hour from now

return {
accessToken: {
payload: {
exp: accessTokenExpired ? pastExp : futureExp,
iat: accessTokenExpired ? pastExp - 3600 : now,
},
toString: () =>
accessTokenExpired ? 'mock-expired-access-token' : 'mock-access-token',
},
idToken: {
payload: {
exp: idTokenExpired ? pastExp : futureExp,
iat: idTokenExpired ? pastExp - 3600 : now,
},
toString: () =>
idTokenExpired ? 'mock-expired-id-token' : 'mock-id-token',
},
refreshToken: 'mock-refresh-token',
clockDrift: 0,
username: 'testuser',
...overrides,
};
}

describe('tokenOrchestrator', () => {
const mockTokenRefresher = jest.fn();
const mockTokenStore = {
Expand Down Expand Up @@ -225,4 +263,192 @@ describe('tokenOrchestrator', () => {
expect(clearTokensSpy).not.toHaveBeenCalled();
});
});

describe('getTokens method', () => {
const mockAuthConfig = {
Cognito: {
userPoolId: 'us-east-1_testpool',
userPoolClientId: 'testclientid',
},
};

beforeEach(() => {
tokenOrchestrator.setAuthConfig(mockAuthConfig);
jest.clearAllMocks();
(oAuthStore.loadOAuthInFlight as jest.Mock).mockResolvedValue(false);
});

it('should return null when no tokens are stored', async () => {
mockTokenStore.loadTokens.mockResolvedValue(null);

const result = await tokenOrchestrator.getTokens();

expect(result).toBeNull();
expect(mockTokenRefresher).not.toHaveBeenCalled();
});

it('should return tokens without refresh when tokens are valid', async () => {
const validTokens = createTokens();
mockTokenStore.loadTokens.mockResolvedValue(validTokens);
mockTokenStore.getLastAuthUser.mockResolvedValue('testuser');

const result = await tokenOrchestrator.getTokens();

expect(mockTokenRefresher).not.toHaveBeenCalled();
expect(result?.accessToken).toBeDefined();
expect(result?.idToken).toBeDefined();
});

it.each([
[
'access token is expired',
{ accessTokenExpired: true, idTokenExpired: false },
],
[
'ID token is expired',
{ accessTokenExpired: false, idTokenExpired: true },
],
[
'both tokens are expired',
{ accessTokenExpired: true, idTokenExpired: true },
],
])('should trigger refresh when %s', async (_scenario, tokenConfig) => {
const expiredTokens = createTokens(tokenConfig);
const newTokens = createTokens();
mockTokenStore.loadTokens.mockResolvedValue(expiredTokens);
mockTokenStore.getLastAuthUser.mockResolvedValue('testuser');
mockTokenRefresher.mockResolvedValue(newTokens);

const result = await tokenOrchestrator.getTokens();

expect(mockTokenRefresher).toHaveBeenCalledWith(
expect.objectContaining({
tokens: expiredTokens,
username: 'testuser',
}),
);
expect(result?.accessToken).toEqual(newTokens.accessToken);
});

it('should trigger refresh when forceRefresh is true even with valid tokens', async () => {
const validTokens = createTokens();
const newTokens = createTokens();
mockTokenStore.loadTokens.mockResolvedValue(validTokens);
mockTokenStore.getLastAuthUser.mockResolvedValue('testuser');
mockTokenRefresher.mockResolvedValue(newTokens);

const result = await tokenOrchestrator.getTokens({ forceRefresh: true });

expect(mockTokenRefresher).toHaveBeenCalledWith(
expect.objectContaining({
tokens: validTokens,
username: 'testuser',
}),
);
expect(result?.accessToken).toEqual(newTokens.accessToken);
});

it('should preserve signInDetails after token refresh', async () => {
const expiredTokens = createTokens({
accessTokenExpired: true,
overrides: {
signInDetails: {
authFlowType: 'USER_SRP_AUTH',
loginId: 'testuser',
},
},
});
const newTokens = createTokens();

mockTokenStore.loadTokens.mockResolvedValue(expiredTokens);
mockTokenStore.getLastAuthUser.mockResolvedValue('testuser');
mockTokenRefresher.mockResolvedValue(newTokens);

const result = await tokenOrchestrator.getTokens();

expect(result?.signInDetails?.authFlowType).toBe('USER_SRP_AUTH');
expect(result?.signInDetails?.loginId).toBe('testuser');
});

it('should return null when refresh fails with NotAuthorizedException', async () => {
const expiredTokens = createTokens({ accessTokenExpired: true });
mockTokenStore.loadTokens.mockResolvedValue(expiredTokens);
mockTokenStore.getLastAuthUser.mockResolvedValue('testuser');
mockTokenRefresher.mockRejectedValue(
new AmplifyError({
name: 'NotAuthorizedException',
message: 'Refresh token has expired',
}),
);

const result = await tokenOrchestrator.getTokens();

expect(result).toBeNull();
expect(mockTokenStore.clearTokens).toHaveBeenCalled();
});

it('should throw error when refresh fails with network error', async () => {
const expiredTokens = createTokens({ accessTokenExpired: true });
mockTokenStore.loadTokens.mockResolvedValue(expiredTokens);
mockTokenStore.getLastAuthUser.mockResolvedValue('testuser');
mockTokenRefresher.mockRejectedValue(
new AmplifyError({
name: AmplifyErrorCode.NetworkError,
message: 'Network Error',
}),
);

await expect(tokenOrchestrator.getTokens()).rejects.toThrow(
'Network Error',
);
expect(mockTokenStore.clearTokens).not.toHaveBeenCalled();
});

it('should not refresh tokens when idToken is missing but accessToken is valid', async () => {
const tokensWithoutIdToken = createTokens();
delete (tokensWithoutIdToken as any).idToken;
mockTokenStore.loadTokens.mockResolvedValue(tokensWithoutIdToken);
mockTokenStore.getLastAuthUser.mockResolvedValue('testuser');

const result = await tokenOrchestrator.getTokens();

expect(mockTokenRefresher).not.toHaveBeenCalled();
expect(result?.accessToken).toBeDefined();
expect(result?.idToken).toBeUndefined();
});

it('should pass clientMetadata to token refresher', async () => {
const expiredTokens = createTokens({ accessTokenExpired: true });
const newTokens = createTokens();
const clientMetadata = { customKey: 'customValue' };
mockTokenStore.loadTokens.mockResolvedValue(expiredTokens);
mockTokenStore.getLastAuthUser.mockResolvedValue('testuser');
mockTokenRefresher.mockResolvedValue(newTokens);

await tokenOrchestrator.getTokens({ clientMetadata });

expect(mockTokenRefresher).toHaveBeenCalledWith(
expect.objectContaining({
clientMetadata,
}),
);
});

it('should store new tokens after successful refresh', async () => {
const expiredTokens = createTokens({ accessTokenExpired: true });
const newTokens = createTokens();
mockTokenStore.loadTokens.mockResolvedValue(expiredTokens);
mockTokenStore.getLastAuthUser.mockResolvedValue('testuser');
mockTokenRefresher.mockResolvedValue(newTokens);

await tokenOrchestrator.getTokens();

expect(mockTokenStore.storeTokens).toHaveBeenCalledWith(
expect.objectContaining({
accessToken: newTokens.accessToken,
idToken: newTokens.idToken,
}),
);
});
});
});