Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions apps/api/src/auth/auth-context.decorator.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import { createParamDecorator, ExecutionContext } from '@nestjs/common';
import {
createParamDecorator,
ExecutionContext,
InternalServerErrorException,
} from '@nestjs/common';
import { AuthContext as AuthContextType, AuthenticatedRequest } from './types';

/**
Expand Down Expand Up @@ -46,23 +50,39 @@ export const AuthContext = createParamDecorator(
);

/**
* Parameter decorator to extract just the organization ID
* Parameter decorator to extract just the organization ID.
* Throws when no active organization is present on the request — only use this
* on endpoints that require an active organization. For endpoints decorated
* with @SkipOrgCheck() (e.g. onboarding), use @OrganizationIdOptional() instead.
*/
export const OrganizationId = createParamDecorator(
(data: unknown, ctx: ExecutionContext): string => {
const request = ctx.switchToHttp().getRequest<AuthenticatedRequest>();
const { organizationId } = request;

if (!organizationId) {
throw new Error(
'Organization ID not found. Ensure HybridAuthGuard is applied.',
throw new InternalServerErrorException(
'Organization ID missing on request. If this endpoint is @SkipOrgCheck()-decorated, use @OrganizationIdOptional() instead.',
);
}

return organizationId;
},
);

/**
* Parameter decorator to extract the organization ID when it may be absent.
* Returns `undefined` instead of throwing when no active organization is
* present. Use this on endpoints decorated with @SkipOrgCheck() where the
* user may not yet have an active organization (e.g. during onboarding).
*/
export const OrganizationIdOptional = createParamDecorator(
(data: unknown, ctx: ExecutionContext): string | undefined => {
const request = ctx.switchToHttp().getRequest<AuthenticatedRequest>();
return request.organizationId || undefined;
},
);

/**
* Parameter decorator to extract the user ID (only available for session auth)
*/
Expand Down
27 changes: 27 additions & 0 deletions apps/api/src/frameworks/frameworks.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ describe('FrameworksController', () => {

const mockService = {
findAll: jest.fn(),
findAvailable: jest.fn(),
delete: jest.fn(),
};

Expand Down Expand Up @@ -68,6 +69,32 @@ describe('FrameworksController', () => {
});
});

describe('findAvailable', () => {
// Regression test for the onboarding 500 bug: this endpoint must not throw
// when the authenticated user has no active organization yet (fresh signups
// hitting the first onboarding step). Previously used @OrganizationId(),
// which threw when organizationId was empty → HTTP 500.
it('should return frameworks when user has no active organization', async () => {
const mockFrameworks = [
{ id: 'frk_1', name: 'soc2', visible: true, isCustom: false },
];
mockService.findAvailable.mockResolvedValue(mockFrameworks);

const result = await controller.findAvailable(undefined);

expect(result).toEqual({ data: mockFrameworks, count: 1 });
expect(service.findAvailable).toHaveBeenCalledWith(undefined);
});

it('should pass organizationId to service when user has an active org', async () => {
mockService.findAvailable.mockResolvedValue([]);

await controller.findAvailable('org_1');

expect(service.findAvailable).toHaveBeenCalledWith('org_1');
});
});

describe('delete', () => {
it('should delegate to service and return result', async () => {
mockService.delete.mockResolvedValue({ success: true });
Expand Down
8 changes: 6 additions & 2 deletions apps/api/src/frameworks/frameworks.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ import { HybridAuthGuard } from '../auth/hybrid-auth.guard';
import { PermissionGuard } from '../auth/permission.guard';
import { RequirePermission } from '../auth/require-permission.decorator';
import { SkipOrgCheck } from '../auth/skip-org-check.decorator';
import { AuthContext, OrganizationId } from '../auth/auth-context.decorator';
import {
AuthContext,
OrganizationId,
OrganizationIdOptional,
} from '../auth/auth-context.decorator';
import type { AuthContext as AuthContextType } from '../auth/types';
import { FrameworksService } from './frameworks.service';
import { AddFrameworksDto } from './dto/add-frameworks.dto';
Expand Down Expand Up @@ -57,7 +61,7 @@ export class FrameworksController {
summary:
'List available frameworks (requires session, no active org needed — used during onboarding)',
})
async findAvailable(@OrganizationId() organizationId?: string) {
async findAvailable(@OrganizationIdOptional() organizationId?: string) {
const data = await this.frameworksService.findAvailable(organizationId);
return { data, count: data.length };
}
Expand Down
34 changes: 32 additions & 2 deletions apps/app/src/app/(app)/setup/components/FrameworkSelection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,22 @@ export function FrameworkSelection({ value, onChange, onLoadingChange }: Framewo
const onChangeRef = useRef(onChange);
const valueRef = useRef(value);

const { data: frameworks = [], isLoading } = useSWR<Framework[]>(
const {
data: frameworks = [],
isLoading,
error,
mutate,
} = useSWR<Framework[]>(
'/v1/frameworks/available',
async (endpoint: string) => {
const response = await api.get<{ data: Framework[] }>(endpoint);
return Array.isArray(response.data?.data) ? response.data.data : [];
if (response.error || !response.data) {
throw new Error(
response.error ||
`Failed to load frameworks (HTTP ${response.status})`,
);
}
return Array.isArray(response.data.data) ? response.data.data : [];
},
);

Expand Down Expand Up @@ -52,6 +63,25 @@ export function FrameworkSelection({ value, onChange, onLoadingChange }: Framewo
return null;
}

if (error) {
const message =
error instanceof Error ? error.message : 'Something went wrong.';
return (
<div className="flex flex-col items-start gap-2">
<p className="text-sm text-destructive">
We couldn't load the compliance frameworks. {message}
</p>
<button
type="button"
onClick={() => mutate()}
className="text-sm underline hover:no-underline"
>
Try again
</button>
</div>
);
}

return (
<div className="flex flex-wrap gap-3 overflow-y-auto pr-1">
{frameworks
Expand Down
Loading