Skip to content

Commit 7e8ae96

Browse files
authored
Merge pull request #2352 from trycompai/main
[comp] Production Deploy
2 parents 95325de + 10c467d commit 7e8ae96

12 files changed

Lines changed: 372 additions & 115 deletions

File tree

apps/api/src/auth/auth-server-origins.spec.ts

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Tests for the getTrustedOrigins logic.
2+
* Tests for the getTrustedOrigins / isTrustedOrigin logic.
33
*
44
* Because auth.server.ts has side effects at module load time (better-auth
55
* initialization, DB connections, validateSecurityConfig), we test the logic
@@ -25,6 +25,32 @@ function getTrustedOriginsLogic(authTrustedOrigins: string | undefined): string[
2525
];
2626
}
2727

28+
/**
29+
* Mirror of isStaticTrustedOrigin from auth.server.ts for isolated testing.
30+
* The full isTrustedOrigin is async (checks DB for custom domains) —
31+
* that path is tested via integration tests.
32+
*/
33+
function isStaticTrustedOriginLogic(
34+
origin: string,
35+
trustedOrigins: string[],
36+
): boolean {
37+
if (trustedOrigins.includes(origin)) {
38+
return true;
39+
}
40+
41+
try {
42+
const url = new URL(origin);
43+
return (
44+
url.hostname.endsWith('.trycomp.ai') ||
45+
url.hostname.endsWith('.staging.trycomp.ai') ||
46+
url.hostname.endsWith('.trust.inc') ||
47+
url.hostname === 'trust.inc'
48+
);
49+
} catch {
50+
return false;
51+
}
52+
}
53+
2854
describe('getTrustedOrigins', () => {
2955
it('should return env-configured origins when AUTH_TRUSTED_ORIGINS is set', () => {
3056
const origins = getTrustedOriginsLogic('https://a.com, https://b.com');
@@ -45,17 +71,47 @@ describe('getTrustedOrigins', () => {
4571
const origins = getTrustedOriginsLogic(' https://a.com , https://b.com ');
4672
expect(origins).toEqual(['https://a.com', 'https://b.com']);
4773
});
74+
});
75+
76+
describe('isStaticTrustedOrigin', () => {
77+
const defaults = getTrustedOriginsLogic(undefined);
78+
79+
it('should allow static trusted origins', () => {
80+
expect(isStaticTrustedOriginLogic('https://app.trycomp.ai', defaults)).toBe(true);
81+
});
82+
83+
it('should allow trust portal subdomains of trycomp.ai', () => {
84+
expect(isStaticTrustedOriginLogic('https://security.trycomp.ai', defaults)).toBe(true);
85+
expect(isStaticTrustedOriginLogic('https://acme.trycomp.ai', defaults)).toBe(true);
86+
});
87+
88+
it('should allow trust portal subdomains of staging.trycomp.ai', () => {
89+
expect(isStaticTrustedOriginLogic('https://security.staging.trycomp.ai', defaults)).toBe(true);
90+
});
91+
92+
it('should allow trust.inc and its subdomains', () => {
93+
expect(isStaticTrustedOriginLogic('https://trust.inc', defaults)).toBe(true);
94+
expect(isStaticTrustedOriginLogic('https://acme.trust.inc', defaults)).toBe(true);
95+
});
96+
97+
it('should reject unknown origins', () => {
98+
expect(isStaticTrustedOriginLogic('https://evil.com', defaults)).toBe(false);
99+
expect(isStaticTrustedOriginLogic('https://trycomp.ai.evil.com', defaults)).toBe(false);
100+
});
101+
102+
it('should handle invalid origins gracefully', () => {
103+
expect(isStaticTrustedOriginLogic('not-a-url', defaults)).toBe(false);
104+
});
48105

49-
it('main.ts should use getTrustedOrigins instead of origin: true', () => {
50-
// Validate the CORS config change was made correctly by checking file content
106+
it('main.ts should use isTrustedOrigin for CORS', () => {
51107
const fs = require('fs');
52108
const path = require('path');
53109
const mainTs = fs.readFileSync(
54110
path.join(__dirname, '..', 'main.ts'),
55111
'utf-8',
56112
) as string;
57113
expect(mainTs).not.toContain('origin: true');
58-
expect(mainTs).toContain('origin: getTrustedOrigins()');
59-
expect(mainTs).toContain("import { getTrustedOrigins } from './auth/auth.server'");
114+
expect(mainTs).toContain('isTrustedOrigin');
115+
expect(mainTs).toContain("import { isTrustedOrigin } from './auth/auth.server'");
60116
});
61117
});

apps/api/src/auth/auth.server.ts

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import {
1515
} from 'better-auth/plugins';
1616
import { ac, allRoles } from '@trycompai/auth';
1717
import { createAuthMiddleware } from 'better-auth/api';
18+
import { Redis } from '@upstash/redis';
1819

1920
const MAGIC_LINK_EXPIRES_IN_SECONDS = 60 * 60; // 1 hour
2021

@@ -56,6 +57,93 @@ export function getTrustedOrigins(): string[] {
5657
];
5758
}
5859

60+
/**
61+
* Check if an origin matches a known trusted pattern (static list + subdomains).
62+
* This is a fast synchronous check that doesn't hit the DB.
63+
*/
64+
export function isStaticTrustedOrigin(origin: string): boolean {
65+
const trustedOrigins = getTrustedOrigins();
66+
if (trustedOrigins.includes(origin)) {
67+
return true;
68+
}
69+
70+
try {
71+
const url = new URL(origin);
72+
return (
73+
url.hostname.endsWith('.trycomp.ai') ||
74+
url.hostname.endsWith('.staging.trycomp.ai') ||
75+
url.hostname.endsWith('.trust.inc') ||
76+
url.hostname === 'trust.inc'
77+
);
78+
} catch {
79+
return false;
80+
}
81+
}
82+
83+
// ── Custom domain lookup via Redis cache ─────────────────────────────────────
84+
85+
const CORS_DOMAINS_CACHE_KEY = 'cors:custom-domains';
86+
const CORS_DOMAINS_CACHE_TTL_SECONDS = 5 * 60; // 5 minutes
87+
88+
const corsRedisClient = new Redis({
89+
url: process.env.UPSTASH_REDIS_REST_URL!,
90+
token: process.env.UPSTASH_REDIS_REST_TOKEN!,
91+
});
92+
93+
async function getCustomDomains(): Promise<Set<string>> {
94+
try {
95+
// Try Redis cache first
96+
const cached = await corsRedisClient.get<string[]>(CORS_DOMAINS_CACHE_KEY);
97+
if (cached) {
98+
return new Set(cached);
99+
}
100+
101+
// Cache miss — query DB and store in Redis
102+
const trusts = await db.trust.findMany({
103+
where: {
104+
domain: { not: null },
105+
domainVerified: true,
106+
status: 'published',
107+
},
108+
select: { domain: true },
109+
});
110+
111+
const domains = trusts
112+
.map((t) => t.domain)
113+
.filter((d): d is string => d !== null);
114+
115+
await corsRedisClient.set(CORS_DOMAINS_CACHE_KEY, domains, {
116+
ex: CORS_DOMAINS_CACHE_TTL_SECONDS,
117+
});
118+
119+
return new Set(domains);
120+
} catch (error) {
121+
console.error('[CORS] Failed to fetch custom domains:', error);
122+
return new Set();
123+
}
124+
}
125+
126+
/**
127+
* Check if an origin is trusted. Checks (in order):
128+
* 1. Static trusted origins list
129+
* 2. *.trycomp.ai / *.trust.inc subdomains
130+
* 3. Verified custom domains from the DB (cached in Redis, TTL 5 min)
131+
*/
132+
export async function isTrustedOrigin(origin: string): Promise<boolean> {
133+
if (isStaticTrustedOrigin(origin)) {
134+
return true;
135+
}
136+
137+
// Check verified custom domains from DB via Redis cache
138+
try {
139+
const url = new URL(origin);
140+
const customDomains = await getCustomDomains();
141+
return customDomains.has(url.hostname);
142+
} catch {
143+
return false;
144+
}
145+
}
146+
59147
// Build social providers config
60148
const socialProviders: Record<string, unknown> = {};
61149

apps/api/src/auth/origin-check.middleware.spec.ts

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
import { originCheckMiddleware } from './origin-check.middleware';
22

3-
// Mock getTrustedOrigins
3+
// Mock isTrustedOrigin (async version)
44
jest.mock('./auth.server', () => ({
5-
getTrustedOrigins: () => [
6-
'http://localhost:3000',
7-
'http://localhost:3002',
8-
'https://app.trycomp.ai',
9-
'https://portal.trycomp.ai',
10-
],
5+
isTrustedOrigin: async (origin: string) => {
6+
const staticOrigins = [
7+
'http://localhost:3000',
8+
'http://localhost:3002',
9+
'https://app.trycomp.ai',
10+
'https://portal.trycomp.ai',
11+
];
12+
if (staticOrigins.includes(origin)) return true;
13+
try {
14+
const url = new URL(origin);
15+
return (
16+
url.hostname.endsWith('.trycomp.ai') ||
17+
url.hostname.endsWith('.staging.trycomp.ai') ||
18+
url.hostname.endsWith('.trust.inc') ||
19+
url.hostname === 'trust.inc'
20+
);
21+
} catch {
22+
return false;
23+
}
24+
},
1125
}));
1226

1327
function createMockReq(
@@ -22,6 +36,9 @@ function createMockReq(
2236
};
2337
}
2438

39+
/** Flush the microtask queue so async middleware completes. */
40+
const flushPromises = () => new Promise((resolve) => setImmediate(resolve));
41+
2542
function createMockRes(): Record<string, unknown> & { statusCode?: number; body?: unknown } {
2643
const res: Record<string, unknown> & { statusCode?: number; body?: unknown } = {};
2744
res.status = jest.fn().mockImplementation((code: number) => {
@@ -66,44 +83,48 @@ describe('originCheckMiddleware', () => {
6683
expect(next).toHaveBeenCalled();
6784
});
6885

69-
it('should allow POST from trusted origin', () => {
86+
it('should allow POST from trusted origin', async () => {
7087
const req = createMockReq('POST', '/v1/organization/api-keys', 'http://localhost:3000');
7188
const res = createMockRes();
7289
const next = jest.fn();
7390

7491
originCheckMiddleware(req as any, res as any, next);
92+
await flushPromises();
7593

7694
expect(next).toHaveBeenCalled();
7795
});
7896

79-
it('should block POST from untrusted origin', () => {
97+
it('should block POST from untrusted origin', async () => {
8098
const req = createMockReq('POST', '/v1/organization/transfer-ownership', 'http://evil.com');
8199
const res = createMockRes();
82100
const next = jest.fn();
83101

84102
originCheckMiddleware(req as any, res as any, next);
103+
await flushPromises();
85104

86105
expect(next).not.toHaveBeenCalled();
87106
expect(res.status).toHaveBeenCalledWith(403);
88107
});
89108

90-
it('should block DELETE from untrusted origin', () => {
109+
it('should block DELETE from untrusted origin', async () => {
91110
const req = createMockReq('DELETE', '/v1/organization', 'http://evil.com');
92111
const res = createMockRes();
93112
const next = jest.fn();
94113

95114
originCheckMiddleware(req as any, res as any, next);
115+
await flushPromises();
96116

97117
expect(next).not.toHaveBeenCalled();
98118
expect(res.status).toHaveBeenCalledWith(403);
99119
});
100120

101-
it('should block PATCH from untrusted origin', () => {
121+
it('should block PATCH from untrusted origin', async () => {
102122
const req = createMockReq('PATCH', '/v1/members/123/role', 'http://evil.com');
103123
const res = createMockRes();
104124
const next = jest.fn();
105125

106126
originCheckMiddleware(req as any, res as any, next);
127+
await flushPromises();
107128

108129
expect(next).not.toHaveBeenCalled();
109130
expect(res.status).toHaveBeenCalledWith(403);
@@ -139,12 +160,13 @@ describe('originCheckMiddleware', () => {
139160
expect(next).toHaveBeenCalled();
140161
});
141162

142-
it('should allow production origins', () => {
163+
it('should allow production origins', async () => {
143164
const req = createMockReq('POST', '/v1/organization/api-keys', 'https://app.trycomp.ai');
144165
const res = createMockRes();
145166
const next = jest.fn();
146167

147168
originCheckMiddleware(req as any, res as any, next);
169+
await flushPromises();
148170

149171
expect(next).toHaveBeenCalled();
150172
});

apps/api/src/auth/origin-check.middleware.ts

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { Request, Response, NextFunction } from 'express';
2-
import { getTrustedOrigins } from './auth.server';
2+
import { isTrustedOrigin } from './auth.server';
33

44
const SAFE_METHODS = new Set(['GET', 'HEAD', 'OPTIONS']);
55

@@ -52,14 +52,21 @@ export function originCheckMiddleware(
5252
return next();
5353
}
5454

55-
// Validate Origin against trusted origins
56-
const trustedOrigins = getTrustedOrigins();
57-
if (trustedOrigins.includes(origin)) {
58-
return next();
59-
}
60-
61-
res.status(403).json({
62-
statusCode: 403,
63-
message: 'Forbidden',
64-
});
55+
// Validate Origin against trusted origins (includes dynamic subdomains + custom domains)
56+
isTrustedOrigin(origin)
57+
.then((trusted) => {
58+
if (trusted) {
59+
return next();
60+
}
61+
res.status(403).json({
62+
statusCode: 403,
63+
message: 'Forbidden',
64+
});
65+
})
66+
.catch(() => {
67+
res.status(403).json({
68+
statusCode: 403,
69+
message: 'Forbidden',
70+
});
71+
});
6572
}

0 commit comments

Comments
 (0)