diff --git a/cloud_function/index.js b/cloud_function/index.js index c59fdedb..91dac953 100644 --- a/cloud_function/index.js +++ b/cloud_function/index.js @@ -111,9 +111,25 @@ async function handleCallback(req, res) { finalUrl.searchParams.append('token_type', token_type); finalUrl.searchParams.append('expiry_date', expiry_date.toString()); - // SECURITY: Pass the CSRF token back to the client for validation. - if (payload.csrf) { - finalUrl.searchParams.append('state', payload.csrf); + // SECURITY: Pass the original base64-encoded state back to the client + // for validation. The client decodes it and extracts the `csrf` field itself. + // + // Previously, this code extracted only `payload.csrf` (a raw hex string) and + // returned it as `state`. That worked for workspace-server ≤ v0.0.7, which + // compared the returned value directly against the local csrf token. + // + // Starting with v0.0.9, the local server expects to receive the full base64 + // JSON state back, then decodes it to extract the `csrf` field: + // + // const decoded = JSON.parse(Buffer.from(returnedState, 'base64').toString('utf8')); + // if (decoded.csrf !== localCsrfToken) → "State mismatch. Possible CSRF attack." + // + // Returning the raw hex here causes `JSON.parse` to fail, setting csrf to null + // and triggering a state mismatch on every auth attempt in v0.0.9+. + // + // Fix: return the original state parameter unchanged so the client can decode it. + if (state) { + finalUrl.searchParams.append('state', state); } return res.redirect(302, finalUrl.toString()); diff --git a/cloud_function/index.test.js b/cloud_function/index.test.js new file mode 100644 index 00000000..53be8de5 --- /dev/null +++ b/cloud_function/index.test.js @@ -0,0 +1,94 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Tests for the OAuth cloud function — specifically the `state` passthrough + * fix introduced to support workspace-server v0.0.9+. + * + * The critical invariant: the cloud function must pass the original base64 + * state back to the client unchanged, so the client can decode it and extract + * the csrf field for CSRF validation. + */ + +// --------------------------------------------------------------------------- +// Unit tests for the state passthrough logic (no HTTP server needed) +// --------------------------------------------------------------------------- + +describe('state parameter passthrough in handleCallback', () => { + /** + * Mirrors the state-handling block in handleCallback to test it in isolation. + * Returns the value that would be appended as the `state` query param on the + * redirect URL, or null if no state would be appended. + */ + function buildRedirectState(stateParam) { + if (!stateParam) return null; + if (stateParam.length > 4096) throw new Error('State parameter exceeds size limit of 4KB.'); + + const payload = JSON.parse(Buffer.from(stateParam, 'base64').toString('utf8')); + + if (payload && payload.manual === false && payload.uri) { + // The fix: return `stateParam` unchanged, NOT `payload.csrf` + return stateParam; + } + return null; + } + + const CSRF = 'deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef'; + + it('returns the full base64 state unchanged (v0.0.9+ client can decode it)', () => { + const payload = { uri: 'http://localhost:54321/oauth2callback', manual: false, csrf: CSRF }; + const state = Buffer.from(JSON.stringify(payload)).toString('base64'); + + const result = buildRedirectState(state); + + expect(result).toBe(state); + + // Verify the client can decode csrf from the returned value + const decoded = JSON.parse(Buffer.from(result, 'base64').toString('utf8')); + expect(decoded.csrf).toBe(CSRF); + }); + + it('returned state must NOT be just the raw hex csrf (old buggy behaviour)', () => { + const payload = { uri: 'http://localhost:54321/oauth2callback', manual: false, csrf: CSRF }; + const state = Buffer.from(JSON.stringify(payload)).toString('base64'); + + const result = buildRedirectState(state); + + // The old code returned `payload.csrf` — that must no longer happen + expect(result).not.toBe(CSRF); + }); + + it('returns null when manual=true (manual flow, no redirect)', () => { + const payload = { manual: true, csrf: CSRF }; + const state = Buffer.from(JSON.stringify(payload)).toString('base64'); + + expect(buildRedirectState(state)).toBeNull(); + }); + + it('returns null when state param is absent', () => { + expect(buildRedirectState(null)).toBeNull(); + }); + + it('throws when state exceeds 4KB', () => { + const oversized = 'a'.repeat(4097); + expect(() => buildRedirectState(oversized)).toThrow('4KB'); + }); + + it('preserves uri, manual and csrf through encode/decode roundtrip', () => { + const payload = { + uri: 'http://127.0.0.1:12345/oauth2callback', + manual: false, + csrf: CSRF, + }; + const state = Buffer.from(JSON.stringify(payload)).toString('base64'); + const returned = buildRedirectState(state); + + const decoded = JSON.parse(Buffer.from(returned, 'base64').toString('utf8')); + expect(decoded.uri).toBe(payload.uri); + expect(decoded.manual).toBe(false); + expect(decoded.csrf).toBe(CSRF); + }); +}); diff --git a/cloud_function/package.json b/cloud_function/package.json index 483a3bd3..5b6fed13 100644 --- a/cloud_function/package.json +++ b/cloud_function/package.json @@ -2,9 +2,15 @@ "name": "oauth-handler", "version": "1.0.0", "main": "index.js", + "scripts": { + "test": "jest" + }, "dependencies": { "@google-cloud/functions-framework": "^3.0.0", "@google-cloud/secret-manager": "^5.0.0", "axios": "^1.0.0" + }, + "devDependencies": { + "jest": "^29.0.0" } } diff --git a/workspace-server/src/__tests__/auth/AuthManager.test.ts b/workspace-server/src/__tests__/auth/AuthManager.test.ts index dddaa1d6..009c1e0d 100644 --- a/workspace-server/src/__tests__/auth/AuthManager.test.ts +++ b/workspace-server/src/__tests__/auth/AuthManager.test.ts @@ -14,6 +14,31 @@ jest.mock('googleapis'); jest.mock('../../utils/logger'); jest.mock('../../utils/secure-browser-launcher'); +/** + * Helper that mirrors the CSRF extraction logic in AuthManager.authWithWeb, + * so we can unit-test all state-format permutations without spinning up an + * HTTP server or touching the private method directly. + * + * Supports two formats returned by the cloud function: + * - v0.0.9+: full base64-encoded JSON {"uri":…,"manual":…,"csrf":""} + * - ≤v0.0.7: raw hex CSRF string + */ +function extractCsrfFromState(returnedState: string | null): string | null { + if (!returnedState) return null; + try { + const decoded = JSON.parse( + Buffer.from(returnedState, 'base64').toString('utf8'), + ); + if (typeof decoded?.csrf === 'string') { + return decoded.csrf; + } + } catch { + // Not base64 JSON — fall through + } + // Fallback: treat as raw hex token (≤v0.0.7 cloud function) + return returnedState; +} + // Mock fetch globally for refreshToken tests global.fetch = jest.fn(); @@ -261,3 +286,37 @@ describe('AuthManager', () => { ); }); }); + +describe('OAuth state CSRF extraction', () => { + const RAW_CSRF = 'a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2'; + + it('extracts csrf from base64 JSON state (v0.0.9+ cloud function format)', () => { + const payload = { uri: 'http://localhost:54321/oauth2callback', manual: false, csrf: RAW_CSRF }; + const state = Buffer.from(JSON.stringify(payload)).toString('base64'); + expect(extractCsrfFromState(state)).toBe(RAW_CSRF); + }); + + it('returns raw hex as-is when state is not base64 JSON (≤v0.0.7 cloud function format)', () => { + expect(extractCsrfFromState(RAW_CSRF)).toBe(RAW_CSRF); + }); + + it('returns null when state is null', () => { + expect(extractCsrfFromState(null)).toBeNull(); + }); + + it('returns null when state is empty string', () => { + expect(extractCsrfFromState('')).toBeNull(); + }); + + it('falls back to raw value when base64 decodes to JSON without csrf field', () => { + const payload = { uri: 'http://localhost:54321/oauth2callback', manual: false }; + const state = Buffer.from(JSON.stringify(payload)).toString('base64'); + // No csrf field → falls back to treating the whole base64 string as the token + expect(extractCsrfFromState(state)).toBe(state); + }); + + it('falls back to raw value when base64 decodes to non-JSON', () => { + const garbage = Buffer.from('not-json-at-all').toString('base64'); + expect(extractCsrfFromState(garbage)).toBe(garbage); + }); +}); diff --git a/workspace-server/src/auth/AuthManager.ts b/workspace-server/src/auth/AuthManager.ts index 70727cc0..7b91c647 100644 --- a/workspace-server/src/auth/AuthManager.ts +++ b/workspace-server/src/auth/AuthManager.ts @@ -353,8 +353,37 @@ export class AuthManager { .searchParams; // SECURITY: Validate the state parameter to prevent CSRF attacks. + // + // The cloud function may return state in two formats depending on version: + // + // v0.0.9+ cloud function: returns the original base64 JSON state unchanged. + // The caller must decode it and extract `csrf`: + // JSON.parse(Buffer.from(returnedState, 'base64').toString('utf8')).csrf + // + // ≤v0.0.7 cloud function: returned only the raw hex `payload.csrf` string. + // In that case the value is compared directly against csrfToken. + // + // We try the base64 JSON path first; if it fails (parse error or missing + // csrf field) we fall back to treating the value as a raw hex token. const returnedState = qs.get('state'); - if (returnedState !== csrfToken) { + let csrfFromState: string | null = null; + if (returnedState) { + try { + const decoded = JSON.parse( + Buffer.from(returnedState, 'base64').toString('utf8'), + ); + if (typeof decoded?.csrf === 'string') { + csrfFromState = decoded.csrf; + } + } catch { + // Not base64 JSON — fall through to raw hex comparison below + } + if (!csrfFromState) { + // Fallback: treat the returned value as the raw hex token (≤v0.0.7 cloud function) + csrfFromState = returnedState; + } + } + if (!csrfFromState || csrfFromState !== csrfToken) { res.end('State mismatch. Possible CSRF attack.'); reject(new Error('OAuth state mismatch. Possible CSRF attack.')); return;