Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 118 additions & 2 deletions apps/cli/src/device-login.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,44 @@
import { describe, expect, it } from "@effect/vitest";
import { afterEach, describe, expect, it } from "@effect/vitest";

import { browserOpenCommand } from "./device-login";
import {
browserOpenCommand,
discoverCliLogin,
refreshDeviceTokens,
requestDeviceCode,
type CliLoginDiscovery,
} from "./device-login";

const originalFetch = globalThis.fetch;

interface FetchCall {
readonly url: string;
readonly headers: Record<string, string>;
}

const responseJson = (body: Record<string, unknown>, status = 200): Response =>
new Response(JSON.stringify(body), {
status,
headers: { "content-type": "application/json" },
});

const installFetch = (handler: (url: string, init: RequestInit | undefined) => Response) => {
globalThis.fetch = ((input: RequestInfo | URL, init?: RequestInit) => {
const url =
typeof input === "string" ? input : input instanceof URL ? input.toString() : input.url;
return Promise.resolve(handler(url, init));
}) as typeof fetch;
};

const recordCall = (calls: Array<FetchCall>, url: string, init: RequestInit | undefined): void => {
calls.push({
url,
headers: Object.fromEntries(new Headers(init?.headers).entries()),
});
};

afterEach(() => {
globalThis.fetch = originalFetch;
});

describe("browserOpenCommand", () => {
it("opens Windows browser URLs without cmd.exe", () => {
Expand Down Expand Up @@ -32,3 +70,81 @@ describe("browserOpenCommand", () => {
expect(browserOpenCommand("not a url", "win32")).toBeUndefined();
});
});

describe("device login headers", () => {
it("sends configured headers when discovering CLI login", async () => {
const calls: Array<FetchCall> = [];
installFetch((url, init) => {
recordCall(calls, url, init);
return responseJson({
provider: "better-auth",
deviceAuthorizationEndpoint: "https://executor.example/api/auth/device/code",
tokenEndpoint: "https://executor.example/api/auth/device/token",
clientId: "executor-cli",
requestFormat: "json",
});
});

const discovery = await discoverCliLogin("https://executor.example", {
headers: { "CF-Access-Client-Id": "client-id" },
});

expect(discovery.clientId).toBe("executor-cli");
expect(calls).toHaveLength(1);
expect(calls[0]?.url).toBe("https://executor.example/api/auth/cli-login");
expect(calls[0]?.headers).toMatchObject({
accept: "application/json",
"cf-access-client-id": "client-id",
});
});

it("sends configured headers only to same-origin device endpoints", async () => {
const calls: Array<FetchCall> = [];
installFetch((url, init) => {
recordCall(calls, url, init);
if (url.endsWith("/api/auth/device/code")) {
return responseJson({
device_code: "device-code",
user_code: "USER-CODE",
verification_uri: "https://executor.example/device",
expires_in: 300,
interval: 5,
});
}
return responseJson({
access_token: "access-token",
refresh_token: "refresh-token-2",
expires_in: 600,
});
});
const discovery: CliLoginDiscovery = {
provider: "better-auth",
deviceAuthorizationEndpoint: "https://executor.example/api/auth/device/code",
tokenEndpoint: "https://accounts.example/oauth/token",
clientId: "executor-cli",
requestFormat: "form",
};
const headers = { "CF-Access-Client-Id": "client-id" };

await requestDeviceCode(discovery, { serverOrigin: "https://executor.example", headers });
await refreshDeviceTokens({
tokenEndpoint: discovery.tokenEndpoint,
clientId: discovery.clientId,
refreshToken: "refresh-token",
serverOrigin: "https://executor.example",
headers,
});

expect(calls).toHaveLength(2);
expect(calls[0]?.headers).toMatchObject({
accept: "application/json",
"content-type": "application/x-www-form-urlencoded",
"cf-access-client-id": "client-id",
});
expect(calls[1]?.headers).toMatchObject({
accept: "application/json",
"content-type": "application/x-www-form-urlencoded",
});
expect(calls[1]?.headers["cf-access-client-id"]).toBeUndefined();
});
});
65 changes: 57 additions & 8 deletions apps/cli/src/device-login.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ export interface DeviceTokens {
readonly organizationId?: string;
}

export interface DeviceLoginHttpOptions {
readonly headers?: Readonly<Record<string, string>>;
readonly serverOrigin?: string;
}

export interface PollForDeviceTokensOptions extends DeviceLoginHttpOptions {
readonly now?: () => number;
}

const DEVICE_CODE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code";
const DEFAULT_INTERVAL_SECONDS = 5;

Expand Down Expand Up @@ -126,17 +135,43 @@ const definedFields = (fields: Record<string, string | undefined>): Record<strin
string
>;

const isSameOrigin = (url: string, origin: string): boolean => {
try {
return new URL(url).origin === new URL(origin).origin;
} catch {
return false;
}
};

const headersForUrl = (
url: string,
baseHeaders: Record<string, string>,
options: DeviceLoginHttpOptions = {},
): Record<string, string> => {
const configured =
options.headers && (!options.serverOrigin || isSameOrigin(url, options.serverOrigin))
? options.headers
: {};
return { ...configured, ...baseHeaders };
};

const post = async (
url: string,
fields: Record<string, string | undefined>,
format: "form" | "json",
options: DeviceLoginHttpOptions = {},
) =>
fetch(url, {
method: "POST",
headers: {
"content-type": format === "json" ? "application/json" : "application/x-www-form-urlencoded",
accept: "application/json",
},
headers: headersForUrl(
url,
{
"content-type":
format === "json" ? "application/json" : "application/x-www-form-urlencoded",
accept: "application/json",
},
options,
),
body: format === "json" ? JSON.stringify(definedFields(fields)) : formBody(fields),
});

Expand All @@ -150,10 +185,16 @@ const readJson = async (response: Response): Promise<Record<string, unknown>> =>
}
};

export const discoverCliLogin = async (origin: string): Promise<CliLoginDiscovery> => {
export const discoverCliLogin = async (
origin: string,
options: DeviceLoginHttpOptions = {},
): Promise<CliLoginDiscovery> => {
let response: Response;
const url = cliLoginUrl(origin);
try {
response = await fetch(cliLoginUrl(origin), { headers: { accept: "application/json" } });
response = await fetch(url, {
headers: headersForUrl(url, { accept: "application/json" }, options),
});
} catch (cause) {
throw new DeviceLoginError(
`Could not reach ${origin} to start login: ${cause instanceof Error ? cause.message : String(cause)}`,
Expand Down Expand Up @@ -182,11 +223,15 @@ export const discoverCliLogin = async (origin: string): Promise<CliLoginDiscover
};
};

export const requestDeviceCode = async (discovery: CliLoginDiscovery): Promise<DeviceCodeGrant> => {
export const requestDeviceCode = async (
discovery: CliLoginDiscovery,
options: DeviceLoginHttpOptions = {},
): Promise<DeviceCodeGrant> => {
const response = await post(
discovery.deviceAuthorizationEndpoint,
{ client_id: discovery.clientId, scope: discovery.scope },
discovery.requestFormat,
options,
);
const body = await readJson(response);
if (!response.ok) {
Expand Down Expand Up @@ -220,7 +265,7 @@ const sleep = (ms: number): Promise<void> => new Promise((resolve) => setTimeout
export const pollForDeviceTokens = async (
discovery: CliLoginDiscovery,
grant: DeviceCodeGrant,
options: { readonly now?: () => number } = {},
options: PollForDeviceTokensOptions = {},
): Promise<DeviceTokens> => {
const now = options.now ?? (() => Date.now());
const deadline = now() + grant.expiresInSeconds * 1000;
Expand All @@ -240,6 +285,7 @@ export const pollForDeviceTokens = async (
client_id: discovery.clientId,
},
discovery.requestFormat,
options,
);
const body = await readJson(response);

Expand Down Expand Up @@ -286,6 +332,8 @@ export const refreshDeviceTokens = async (input: {
readonly tokenEndpoint: string;
readonly clientId: string;
readonly refreshToken: string;
readonly headers?: Readonly<Record<string, string>>;
readonly serverOrigin?: string;
}): Promise<DeviceTokens> => {
const response = await post(
input.tokenEndpoint,
Expand All @@ -295,6 +343,7 @@ export const refreshDeviceTokens = async (input: {
client_id: input.clientId,
},
"form",
input,
);
const body = await readJson(response);
if (!response.ok) {
Expand Down
Loading