Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion testplanit/app/[locale]/admin/users/AddUser.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ export function AddUser({ open, onClose }: AddUserProps) {

// If the admin skipped the password (allowed when a passwordless login
// method is available), store an unusable random secret — same pattern as
// SAML auto-provisioning in app/api/auth/saml/callback/route.ts. User
// SAML auto-provisioning in app/api/auth/callback/saml/route.ts. User
// signs in via magic link / SSO.
const passwordForApi =
data.password.length > 0 ? data.password : crypto.randomUUID();
Expand Down
118 changes: 118 additions & 0 deletions testplanit/app/api/auth/callback/saml/route.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// @vitest-environment node
import jwt from "jsonwebtoken";
import { beforeEach, describe, expect, it, vi } from "vitest";

vi.hoisted(() => {
process.env.NEXTAUTH_SECRET = "test-secret-key-at-least-32-chars-long";
process.env.NEXTAUTH_URL = "https://app.example.com";
});

const SECRET = "test-secret-key-at-least-32-chars-long";

vi.mock("~/lib/valkey", () => ({ default: null })); // RelayState via signed token

vi.mock("~/server/db", () => ({
db: {
samlConfiguration: { findUnique: vi.fn() },
user: { findUnique: vi.fn(), update: vi.fn(), create: vi.fn() },
account: { upsert: vi.fn() },
roles: { findFirst: vi.fn() },
},
}));

const { validateSAMLResponse } = vi.hoisted(() => ({
validateSAMLResponse: vi.fn(),
}));
vi.mock("~/server/saml-provider", () => ({
createSAMLClient: vi.fn(async () => ({})),
validateSAMLResponse,
}));
vi.mock("~/lib/services/notificationService", () => ({
NotificationService: { createUserRegistrationNotification: vi.fn() },
}));
vi.mock("~/lib/utils/email-domain-validation", () => ({
isEmailDomainAllowed: vi.fn(async () => true),
}));

import { db } from "~/server/db";
import { POST } from "./route";

function makeReq(
relayState: string | null,
samlResponse: string | null = "<saml/>"
) {
const fd = new FormData();
if (samlResponse) fd.set("SAMLResponse", samlResponse);
if (relayState) fd.set("RelayState", relayState);
return {
headers: new Headers(),
formData: async () => fd,
url: "https://cint-prod-pod:3000/api/auth/callback/saml",
} as any;
}

// A RelayState as the init route would mint it without Valkey (signed token).
function relayFor(providerId: string, callbackUrl = "/dash") {
return jwt.sign({ providerId, callbackUrl }, SECRET, { expiresIn: "15m" });
}

describe("POST /api/auth/callback/saml — ACS validator", () => {
beforeEach(() => {
vi.clearAllMocks();
});

it("returns 400 when RelayState is missing or invalid (Bug 4 regression)", async () => {
const res = await POST(makeReq(null));
expect(res.status).toBe(400);
const body = await res.json();
expect(body.error).toMatch(/invalid or expired/i);
});

it("recovers the provider from RelayState and hands off to /api/auth/saml/complete (Bug 3)", async () => {
(db.samlConfiguration.findUnique as any).mockResolvedValue({
id: "cfg",
entryPoint: "e",
cert: "c",
issuer: "i",
attributeMapping: {},
autoProvisionUsers: false,
provider: { name: "okta", enabled: true },
});
validateSAMLResponse.mockResolvedValue({
email: "bob@example.com",
nameID: "bob",
});
(db.user.findUnique as any).mockResolvedValue({
id: "user_7",
email: "bob@example.com",
name: "Bob",
authMethod: "SSO",
externalId: "bob",
});
(db.account.upsert as any).mockResolvedValue({});

const res = await POST(makeReq(relayFor("ssoprovider_9")));

// Bug 2: looked up by the providerId carried in RelayState.
expect(db.samlConfiguration.findUnique).toHaveBeenCalledWith({
where: { providerId: "ssoprovider_9" },
include: { provider: true },
});

// Bug 3: hands off to the real session-minting route on the public origin,
// NOT the dead /api/auth/callback/saml?token= path that hit NextAuth.
const location = res.headers.get("location")!;
expect(
location.startsWith(
"https://app.example.com/api/auth/saml/complete?token="
)
).toBe(true);
expect(location).not.toContain("/api/auth/callback/saml?token=");
});

it("returns 404 when the provider in RelayState is unknown", async () => {
(db.samlConfiguration.findUnique as any).mockResolvedValue(null);
const res = await POST(makeReq(relayFor("nope")));
expect(res.status).toBe(404);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@ import { randomUUID } from "crypto";
import { NextRequest, NextResponse } from "next/server";
import {
checkRateLimit,
consumeSamlRelayState,
createTempSessionToken,
getAppBaseUrl,
getSecurityHeaders,
sanitizeCallbackUrl,
validateSAMLTimestamp,
verifyState,
} from "~/lib/auth-security";
import { NotificationService } from "~/lib/services/notificationService";
import { isEmailDomainAllowed } from "~/lib/utils/email-domain-validation";
import { db } from "~/server/db";
import { createSAMLClient, validateSAMLResponse } from "~/server/saml-provider";

// SAML callback handler
/**
* SAML Assertion Consumer Service (ACS).
*
* The IdP POSTs the SAMLResponse here. This path matches the ACS URL embedded
* in the AuthnRequest (createSAMLClient) and shown in the admin UI. It is a
* static route, so it takes precedence over the NextAuth [...nextauth]
* catch-all for this exact path.
*/
export async function POST(request: NextRequest) {
try {
const clientIp =
Expand Down Expand Up @@ -46,31 +54,27 @@ export async function POST(request: NextRequest) {
);
}

// Get provider and state from cookies
const providerId = request.cookies.get("saml-provider")?.value;
const storedState = request.cookies.get("saml-state")?.value;
const callbackUrl = sanitizeCallbackUrl(
request.cookies.get("saml-callback-url")?.value
// Recover the provider and destination from RelayState (the IdP echoes it
// back here; same-site cookies are not sent on this cross-site POST). The
// token is single-use and short-lived, which is what guards this endpoint.
const relay = await consumeSamlRelayState(
typeof relayState === "string" ? relayState : null
);

if (!providerId) {
if (!relay) {
return NextResponse.json(
{ error: "Provider information not found" },
{ error: "Invalid or expired SAML request" },
{ status: 400 }
);
}

// Verify state if relay state is provided
if (relayState && !verifyState(storedState, relayState as string)) {
return NextResponse.json(
{ error: "Invalid state parameter" },
{ status: 400 }
);
}
const providerId = relay.providerId;
const callbackUrl = sanitizeCallbackUrl(relay.callbackUrl);

// Fetch SAML configuration
// Fetch SAML configuration. RelayState carries the SsoProvider id, which is
// the unique foreign key on SamlConfiguration (not its own id).
const samlConfig = await db.samlConfiguration.findUnique({
where: { id: providerId },
where: { providerId },
include: { provider: true },
});

Expand Down Expand Up @@ -272,19 +276,15 @@ export async function POST(request: NextRequest) {
email: user.email,
});

// Create a session for the user by redirecting to NextAuth callback
// Hand off to the completion route, which verifies the token and mints the
// NextAuth session cookie before redirecting to the final destination.
const response = NextResponse.redirect(
new URL(
`/api/auth/callback/saml?token=${tempToken}&callbackUrl=${encodeURIComponent(callbackUrl)}`,
request.url
`/api/auth/saml/complete?token=${tempToken}&callbackUrl=${encodeURIComponent(callbackUrl)}`,
getAppBaseUrl(request)
)
);

// Clean up cookies
response.cookies.delete("saml-state");
response.cookies.delete("saml-provider");
response.cookies.delete("saml-callback-url");

// Set security headers
const securityHeaders = getSecurityHeaders();
Object.entries(securityHeaders).forEach(([key, value]) => {
Expand Down
10 changes: 6 additions & 4 deletions testplanit/app/api/auth/logout/route.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { NextRequest, NextResponse } from "next/server";
import { getAppBaseUrl } from "~/lib/auth-security";
import { withAuditContext } from "~/lib/auditContextWrappers";
import { auditAuthEvent } from "~/lib/services/auditLog";
import { getServerAuthSession } from "~/server/auth";
Expand Down Expand Up @@ -102,9 +103,10 @@ export const POST = withAuditContext(async (request: NextRequest) => {
"urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
issuer: samlProvider.samlConfig.issuer,
},
"https://" +
request.headers.get("host") +
new URL(
"/api/auth/saml/logout-callback",
getAppBaseUrl(request)
).toString(),
{}
);

Expand Down Expand Up @@ -186,11 +188,11 @@ export const GET = withAuditContext(async (request: NextRequest) => {
}

// Redirect to signin page after successful logout
return NextResponse.redirect(new URL("/signin", request.url));
return NextResponse.redirect(new URL("/signin", getAppBaseUrl(request)));
} catch (error) {
console.error("SAML logout callback error:", error);
return NextResponse.redirect(
new URL("/signin?error=logout-failed", request.url)
new URL("/signin?error=logout-failed", getAppBaseUrl(request))
);
}
});
106 changes: 106 additions & 0 deletions testplanit/app/api/auth/saml/complete/route.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// @vitest-environment node
import jwt from "jsonwebtoken";
import { decode } from "next-auth/jwt";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";

vi.hoisted(() => {
process.env.NEXTAUTH_SECRET = "test-secret-key-at-least-32-chars-long";
process.env.NEXTAUTH_URL = "https://app.example.com";
});

const SECRET = "test-secret-key-at-least-32-chars-long";

// Capture every cookie the route sets so we can decode the session token.
const cookieJar = new Map<string, { value: string }>();
vi.mock("next/headers", () => ({
cookies: vi.fn(async () => ({
set: (name: string, value: string) => cookieJar.set(name, { value }),
get: (name: string) => cookieJar.get(name),
delete: (name: string) => cookieJar.delete(name),
})),
}));

vi.mock("~/server/db", () => ({
db: { user: { findUnique: vi.fn() } },
}));

import { db } from "~/server/db";
import { GET } from "./route";

const user = {
id: "user_42",
email: "alice@example.com",
name: "Alice",
access: "USER",
isApi: false,
passwordChangedAt: null,
mustChangePassword: false,
};

function makeReq(token: string, callbackUrl = "/dashboard") {
return {
nextUrl: { searchParams: new URLSearchParams({ token, callbackUrl }) },
headers: new Headers(),
url: "https://cint-prod-pod-xyz:3000/api/auth/saml/complete",
} as any;
}

function tempToken() {
return jwt.sign(
{ userId: user.id, provider: "saml-okta", email: user.email },
SECRET,
{ expiresIn: "5m" }
);
}

describe("GET /api/auth/saml/complete — post-Okta session handoff", () => {
beforeEach(() => {
cookieJar.clear();
vi.clearAllMocks();
(db.user.findUnique as any).mockResolvedValue(user);
});

afterEach(() => {
vi.unstubAllEnvs();
});

it("mints a session cookie that NextAuth can decode (sub = user id)", async () => {
const res = await GET(makeReq(tempToken()));

// Redirects to the destination on the public origin (not the pod host).
expect(res.status).toBe(307);
expect(res.headers.get("location")).toBe(
"https://app.example.com/dashboard"
);

// The session cookie it set must be a valid NextAuth v4 JWT — decode it the
// exact way NextAuth would on the next request.
const sessionCookie = cookieJar.get("next-auth.session-token");
expect(sessionCookie?.value).toBeTruthy();

const decoded = await decode({
token: sessionCookie!.value,
secret: SECRET,
});
expect(decoded?.sub).toBe(user.id);
expect(decoded?.email).toBe(user.email);
// The token carries access on the first request so middleware (which only
// decodes, never runs the jwt callback) sees it without a session refresh.
expect(decoded?.access).toBe("USER");
});

it("sets the __Secure- cookie name in production (matches NextAuth on https)", async () => {
vi.stubEnv("NODE_ENV", "production");

await GET(makeReq(tempToken()));

expect(
cookieJar.get("__Secure-next-auth.session-token")?.value
).toBeTruthy();
});

it("rejects an invalid/expired token with 401", async () => {
const res = await GET(makeReq("not-a-valid-jwt"));
expect(res.status).toBe(401);
});
});
Loading
Loading