Skip to content

Commit b21d81e

Browse files
committed
feat(cloud-security): implement organization-based authorization for scan endpoints
1 parent ee69414 commit b21d81e

5 files changed

Lines changed: 97 additions & 23 deletions

File tree

apps/api/src/cloud-security/cloud-security.controller.ts

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ import {
33
Post,
44
Get,
55
Param,
6-
Body,
76
Headers,
87
Logger,
98
HttpException,
109
HttpStatus,
10+
UseGuards,
1111
} from '@nestjs/common';
12+
import { HybridAuthGuard } from '../auth/hybrid-auth.guard';
13+
import { OrganizationId } from '../auth/auth-context.decorator';
1214
import { CloudSecurityService } from './cloud-security.service';
1315

1416
@Controller({ path: 'cloud-security', version: '1' })
@@ -57,17 +59,11 @@ export class CloudSecurityController {
5759
}
5860

5961
@Post('trigger/:connectionId')
62+
@UseGuards(HybridAuthGuard)
6063
async triggerScan(
6164
@Param('connectionId') connectionId: string,
62-
@Body('organizationId') organizationId: string,
65+
@OrganizationId() organizationId: string,
6366
) {
64-
if (!organizationId) {
65-
throw new HttpException(
66-
'Organization ID required',
67-
HttpStatus.BAD_REQUEST,
68-
);
69-
}
70-
7167
this.logger.log(
7268
`Cloud security scan trigger requested for connection ${connectionId}`,
7369
);
@@ -86,9 +82,16 @@ export class CloudSecurityController {
8682
}
8783

8884
@Get('runs/:runId')
89-
async getRunStatus(@Param('runId') runId: string) {
85+
@UseGuards(HybridAuthGuard)
86+
async getRunStatus(
87+
@Param('runId') runId: string,
88+
@OrganizationId() organizationId: string,
89+
) {
9090
try {
91-
return await this.cloudSecurityService.getRunStatus(runId);
91+
return await this.cloudSecurityService.getRunStatus(
92+
runId,
93+
organizationId,
94+
);
9295
} catch (error) {
9396
const message =
9497
error instanceof Error ? error.message : 'Failed to get run status';

apps/api/src/cloud-security/cloud-security.module.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ import { GCPSecurityService } from './providers/gcp-security.service';
55
import { AWSSecurityService } from './providers/aws-security.service';
66
import { AzureSecurityService } from './providers/azure-security.service';
77
import { IntegrationPlatformModule } from '../integration-platform/integration-platform.module';
8+
import { AuthModule } from '../auth/auth.module';
89

910
@Module({
10-
imports: [IntegrationPlatformModule],
11+
imports: [IntegrationPlatformModule, AuthModule],
1112
controllers: [CloudSecurityController],
1213
providers: [
1314
CloudSecurityService,

apps/api/src/cloud-security/cloud-security.service.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ export interface ScanResult {
3333
export class CloudSecurityService {
3434
private readonly logger = new Logger(CloudSecurityService.name);
3535

36+
// Track which organization owns each trigger.dev run for authorization
37+
private readonly runOwnership = new Map<
38+
string,
39+
{ organizationId: string; createdAt: number }
40+
>();
41+
42+
// Clean up stale entries older than 10 minutes
43+
private readonly RUN_OWNERSHIP_TTL_MS = 10 * 60 * 1000;
44+
3645
constructor(
3746
private readonly credentialVaultService: CredentialVaultService,
3847
private readonly oauthCredentialsService: OAuthCredentialsService,
@@ -250,21 +259,49 @@ export class CloudSecurityService {
250259
runId: handle.id,
251260
});
252261

262+
// Track ownership for authorization on status checks
263+
this.runOwnership.set(handle.id, {
264+
organizationId,
265+
createdAt: Date.now(),
266+
});
267+
this.cleanupStaleRuns();
268+
253269
return { runId: handle.id };
254270
}
255271

256272
async getRunStatus(
257273
runId: string,
274+
organizationId: string,
258275
): Promise<{ completed: boolean; success: boolean; output: unknown }> {
276+
// Verify the caller's organization owns this run
277+
const ownership = this.runOwnership.get(runId);
278+
if (!ownership || ownership.organizationId !== organizationId) {
279+
throw new Error('Run not found');
280+
}
281+
259282
const run = await runs.retrieve(runId);
260283

284+
// Clean up completed runs from the ownership map
285+
if (run.isCompleted) {
286+
this.runOwnership.delete(runId);
287+
}
288+
261289
return {
262290
completed: run.isCompleted,
263291
success: run.isCompleted ? run.isSuccess : false,
264292
output: run.isCompleted ? run.output : null,
265293
};
266294
}
267295

296+
private cleanupStaleRuns(): void {
297+
const now = Date.now();
298+
for (const [runId, entry] of this.runOwnership) {
299+
if (now - entry.createdAt > this.RUN_OWNERSHIP_TTL_MS) {
300+
this.runOwnership.delete(runId);
301+
}
302+
}
303+
}
304+
268305
private async storeFindings(
269306
connectionId: string,
270307
provider: string,

apps/app/src/app/(app)/[orgId]/cloud-tests/actions/run-platform-scan.ts

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,28 @@ import { headers } from 'next/headers';
88
const MAX_POLL_ATTEMPTS = 150; // Max 5 minutes (150 * 2 seconds)
99
const POLL_INTERVAL_MS = 2000;
1010

11+
/**
12+
* Get auth headers for calling guarded API endpoints server-side.
13+
* Uses Better Auth's jwt plugin to generate a Bearer token from the session cookie.
14+
*/
15+
async function getAuthHeaders(organizationId: string): Promise<Record<string, string>> {
16+
const reqHeaders = await headers();
17+
const authHeaders: Record<string, string> = {
18+
'X-Organization-Id': organizationId,
19+
};
20+
21+
// Get a JWT from Better Auth using the session cookie
22+
const tokenResponse = await auth.api.getToken({
23+
headers: reqHeaders,
24+
});
25+
26+
if (tokenResponse?.token) {
27+
authHeaders['Authorization'] = `Bearer ${tokenResponse.token}`;
28+
}
29+
30+
return authHeaders;
31+
}
32+
1133
/**
1234
* Run cloud security scan for a new platform connection.
1335
* Triggers a background task via the NestJS API and polls for completion,
@@ -36,10 +58,13 @@ export const runPlatformScan = async (connectionId: string) => {
3658
}
3759

3860
try {
61+
const authHeaders = await getAuthHeaders(orgId);
62+
3963
// Trigger the scan via API (task is defined in the API's trigger.dev project)
4064
const triggerResponse = await serverApi.post<{ runId: string }>(
4165
`/v1/cloud-security/trigger/${connectionId}`,
42-
{ organizationId: orgId },
66+
undefined,
67+
authHeaders,
4368
);
4469

4570
if (triggerResponse.error || !triggerResponse.data?.runId) {
@@ -64,7 +89,7 @@ export const runPlatformScan = async (connectionId: string) => {
6489
provider?: string;
6590
scannedAt?: string;
6691
} | null;
67-
}>(`/v1/cloud-security/runs/${runId}`);
92+
}>(`/v1/cloud-security/runs/${runId}`, authHeaders);
6893

6994
if (statusResponse.error) {
7095
return {

apps/app/src/lib/server-api-client.ts

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ interface ApiResponse<T = unknown> {
1111
interface CallOptions {
1212
method?: 'GET' | 'POST' | 'PUT' | 'PATCH' | 'DELETE';
1313
body?: unknown;
14+
headers?: Record<string, string>;
1415
}
1516

1617
/**
@@ -21,7 +22,7 @@ async function call<T = unknown>(
2122
endpoint: string,
2223
options: CallOptions = {},
2324
): Promise<ApiResponse<T>> {
24-
const { method = 'GET', body } = options;
25+
const { method = 'GET', body, headers: customHeaders } = options;
2526

2627
const requestHeaders: Record<string, string> = {
2728
'Content-Type': 'application/json',
@@ -33,6 +34,11 @@ async function call<T = unknown>(
3334
requestHeaders['Cookie'] = cookieHeader;
3435
}
3536

37+
// Apply custom headers (e.g. Authorization, X-Organization-Id)
38+
if (customHeaders) {
39+
Object.assign(requestHeaders, customHeaders);
40+
}
41+
3642
try {
3743
const response = await fetch(`${API_BASE_URL}${endpoint}`, {
3844
method,
@@ -67,16 +73,18 @@ async function call<T = unknown>(
6773
}
6874

6975
export const serverApi = {
70-
get: <T = unknown>(endpoint: string) => call<T>(endpoint, { method: 'GET' }),
76+
get: <T = unknown>(endpoint: string, headers?: Record<string, string>) =>
77+
call<T>(endpoint, { method: 'GET', headers }),
7178

72-
post: <T = unknown>(endpoint: string, body?: unknown) =>
73-
call<T>(endpoint, { method: 'POST', body }),
79+
post: <T = unknown>(endpoint: string, body?: unknown, headers?: Record<string, string>) =>
80+
call<T>(endpoint, { method: 'POST', body, headers }),
7481

75-
put: <T = unknown>(endpoint: string, body?: unknown) =>
76-
call<T>(endpoint, { method: 'PUT', body }),
82+
put: <T = unknown>(endpoint: string, body?: unknown, headers?: Record<string, string>) =>
83+
call<T>(endpoint, { method: 'PUT', body, headers }),
7784

78-
patch: <T = unknown>(endpoint: string, body?: unknown) =>
79-
call<T>(endpoint, { method: 'PATCH', body }),
85+
patch: <T = unknown>(endpoint: string, body?: unknown, headers?: Record<string, string>) =>
86+
call<T>(endpoint, { method: 'PATCH', body, headers }),
8087

81-
delete: <T = unknown>(endpoint: string) => call<T>(endpoint, { method: 'DELETE' }),
88+
delete: <T = unknown>(endpoint: string, headers?: Record<string, string>) =>
89+
call<T>(endpoint, { method: 'DELETE', headers }),
8290
};

0 commit comments

Comments
 (0)