1+ import { and, eq } from "drizzle-orm";
2+ import type { Context } from "hono";
13import { Hono } from "hono";
24
35import { jwtMiddleware } from "../auth";
46import { ApiContext } from "../context";
7+ import { createDatabase } from "../db";
8+ import { memberships } from "../db/schema";
59import { getProvider, OAuthError } from "../oauth";
610
711// Create a new Hono instance for OAuth endpoints
812const oauthRoutes = new Hono<ApiContext>();
913
14+ /**
15+ * Middleware to override organization context from query parameter.
16+ * Used during OAuth initiation when the frontend passes the current org.
17+ * Validates membership before accepting the override.
18+ */
19+ const resolveOrgFromQuery = async (
20+ c: Context<ApiContext>,
21+ next: () => Promise<void>
22+ ) => {
23+ const orgIdFromQuery = c.req.query("organizationId");
24+ if (orgIdFromQuery) {
25+ const payload = c.get("jwtPayload");
26+ if (!payload) {
27+ return c.json({ error: "Not authenticated" }, 401);
28+ }
29+ const db = createDatabase(c.env.DB);
30+
31+ const [membership] = await db
32+ .select({ organizationId: memberships.organizationId })
33+ .from(memberships)
34+ .where(
35+ and(
36+ eq(memberships.userId, payload.sub),
37+ eq(memberships.organizationId, orgIdFromQuery)
38+ )
39+ );
40+
41+ if (!membership) {
42+ return c.json({ error: "Organization not found or access denied" }, 403);
43+ }
44+
45+ c.set("organizationId", membership.organizationId);
46+ }
47+ await next();
48+ };
49+
1050/**
1151 * GET /oauth/:provider/connect
1252 *
@@ -16,59 +56,66 @@ const oauthRoutes = new Hono<ApiContext>();
1656 *
1757 * Supported providers: google-mail, google-calendar, discord, linkedin, reddit, github
1858 */
19- oauthRoutes.get("/:provider/connect", jwtMiddleware, async (c) => {
20- const providerName = c.req.param("provider");
59+ oauthRoutes.get(
60+ "/:provider/connect",
61+ jwtMiddleware,
62+ resolveOrgFromQuery,
63+ async (c) => {
64+ const providerName = c.req.param("provider");
2165
22- try {
23- const provider = getProvider(providerName);
24- const code = c.req.query("code");
66+ try {
67+ const provider = getProvider(providerName);
68+ const code = c.req.query("code");
2569
26- // Callback flow - provider redirected back with authorization code
27- if (code) {
28- const stateParam = c.req.query("state");
29- const error = c.req.query("error");
70+ // Callback flow - provider redirected back with authorization code
71+ if (code) {
72+ const stateParam = c.req.query("state");
73+ const error = c.req.query("error");
3074
31- // Check for authorization errors from provider
32- if (error) {
33- return c.redirect(`${c.env.WEB_HOST}/integrations?error=${error}`);
34- }
75+ // Check for authorization errors from provider
76+ if (error) {
77+ return c.redirect(`${c.env.WEB_HOST}/integrations?error=${error}`);
78+ }
3579
36- // State parameter is required for CSRF protection
37- if (!stateParam) {
38- return c.redirect(`${c.env.WEB_HOST}/integrations?error=oauth_failed`);
39- }
80+ // State parameter is required for CSRF protection
81+ if (!stateParam) {
82+ return c.redirect(
83+ `${c.env.WEB_HOST}/integrations?error=oauth_failed`
84+ );
85+ }
4086
41- // Let the provider handle the complete callback flow:
42- // 1. Validate state (CSRF + organization membership)
43- // 2. Exchange code for access token
44- // 3. Fetch user information
45- // 4. Create integration in database
46- const { orgId } = await provider.handleCallback(c, code, stateParam);
87+ // Let the provider handle the complete callback flow:
88+ // 1. Validate state (CSRF + organization membership)
89+ // 2. Exchange code for access token
90+ // 3. Fetch user information
91+ // 4. Create integration in database
92+ const { orgId } = await provider.handleCallback(c, code, stateParam);
4793
48- return c.redirect(
49- `${c.env.WEB_HOST}/org/${orgId}/integrations?success=${providerName}_connected`
50- );
51- }
94+ return c.redirect(
95+ `${c.env.WEB_HOST}/org/${orgId}/integrations?success=${providerName}_connected`
96+ );
97+ }
5298
53- // Initiation flow - no code parameter, start OAuth flow
54- // Provider will:
55- // 1. Validate organization context
56- // 2. Create secure state with nonce
57- // 3. Build authorization URL with correct parameters
58- const authUrl = await provider.initiateAuth(c);
59- return c.redirect(authUrl);
60- } catch (error) {
61- // Handle OAuth-specific errors with user-friendly redirects
62- if (error instanceof OAuthError) {
63- return c.redirect(
64- `${c.env.WEB_HOST}/integrations?error=${error.redirectError}`
65- );
66- }
99+ // Initiation flow - no code parameter, start OAuth flow
100+ // Provider will:
101+ // 1. Validate organization context
102+ // 2. Create secure state with nonce
103+ // 3. Build authorization URL with correct parameters
104+ const authUrl = await provider.initiateAuth(c);
105+ return c.redirect(authUrl);
106+ } catch (error) {
107+ // Handle OAuth-specific errors with user-friendly redirects
108+ if (error instanceof OAuthError) {
109+ return c.redirect(
110+ `${c.env.WEB_HOST}/integrations?error=${error.redirectError}`
111+ );
112+ }
67113
68- // Log unexpected errors and show generic error
69- console.error(`OAuth error for ${providerName}:`, error);
70- return c.redirect(`${c.env.WEB_HOST}/integrations?error=oauth_failed`);
114+ // Log unexpected errors and show generic error
115+ console.error(`OAuth error for ${providerName}:`, error);
116+ return c.redirect(`${c.env.WEB_HOST}/integrations?error=oauth_failed`);
117+ }
71118 }
72- } );
119+ );
73120
74121export default oauthRoutes;
0 commit comments