Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions api-doc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2342,6 +2342,15 @@ components:
responseTimeoutMilliseconds:
type: integer
description: Query response timeout in milliseconds
oauthPassthrough:
type: boolean
description: >-
When true, the connection accepts per-request OAuth tokens via
the X-Database-Token HTTP header. Queries run as the authenticated
user's Snowflake role instead of using static credentials. The
static credentials (password/privateKey) are used as a fallback
when no token is provided.
default: false

TrinoConnection:
type: object
Expand Down
43 changes: 43 additions & 0 deletions packages/server/src/request_context.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { describe, expect, it } from "bun:test";
import { getDatabaseToken, requestContext } from "./request_context";

describe("request_context", () => {
describe("getDatabaseToken", () => {
it("should return undefined when no context is active", () => {
expect(getDatabaseToken()).toBeUndefined();
});

it("should return the token from the active context", async () => {
let captured: string | undefined;
await requestContext.run({ databaseToken: "test-token" }, () => {
captured = getDatabaseToken();
});
expect(captured).toBe("test-token");
});

it("should return undefined when context has no token", async () => {
let captured: string | undefined = "should-be-overwritten";
await requestContext.run({}, () => {
captured = getDatabaseToken();
});
expect(captured).toBeUndefined();
});

it("should isolate tokens across concurrent contexts", async () => {
const results: (string | undefined)[] = [];
await Promise.all([
requestContext.run({ databaseToken: "token-a" }, async () => {
// Simulate async work
await new Promise((r) => setTimeout(r, 10));
results.push(getDatabaseToken());
}),
requestContext.run({ databaseToken: "token-b" }, async () => {
results.push(getDatabaseToken());
}),
]);
expect(results).toContain("token-a");
expect(results).toContain("token-b");
expect(results).toHaveLength(2);
});
});
});
21 changes: 21 additions & 0 deletions packages/server/src/request_context.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/**
* Per-request context using AsyncLocalStorage.
*
* Carries request-scoped data (e.g., OAuth tokens) through the call stack
* without threading parameters through every function. Used by the MCP
* endpoint to propagate database tokens for per-user connection creation.
*/

import { AsyncLocalStorage } from "node:async_hooks";

export interface RequestContext {
/** OAuth token for per-request database authentication. */
databaseToken?: string;
}

export const requestContext = new AsyncLocalStorage<RequestContext>();

/** Returns the database token from the current request context, if any. */
export function getDatabaseToken(): string | undefined {
return requestContext.getStore()?.databaseToken;
}
69 changes: 44 additions & 25 deletions packages/server/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
import { logger, loggerMiddleware } from "./logger";

import { initializeMcpServer } from "./mcp/server";
import { requestContext } from "./request_context";
import { ProjectStore } from "./service/project_store";
// Parse command line arguments
function parseArgs() {
Expand Down Expand Up @@ -142,37 +143,55 @@ mcpApp.all(MCP_ENDPOINT, async (req, res) => {

try {
if (req.method === "POST") {
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: undefined,
});

transport.onclose = () => {
logger.info(
`[MCP Transport Info] Stateless transport closed for a request.`,
);
};
transport.onerror = (err: Error) => {
logger.error(`[MCP Transport Error] Stateless transport error:`, {
error: err,
// Extract database token from request header for per-user OAuth pass-through.
// When present, downstream connections use this token instead of static credentials.
const databaseToken =
(req.headers["x-database-token"] as string) || undefined;

const handlePost = async () => {
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: undefined,
});
};

const requestMcpServer = initializeMcpServer(projectStore);
await requestMcpServer.connect(transport);

res.on("close", () => {
logger.info(
"[MCP Transport Info] Response closed, cleaning up stateless transport.",
);
transport.close().catch((err) => {
transport.onclose = () => {
logger.info(
`[MCP Transport Info] Stateless transport closed for a request.`,
);
};
transport.onerror = (err: Error) => {
logger.error(
"[MCP Transport Error] Error closing stateless transport on response close:",
{ error: err },
`[MCP Transport Error] Stateless transport error:`,
{
error: err,
},
);
};

const requestMcpServer = initializeMcpServer(projectStore);
await requestMcpServer.connect(transport);

res.on("close", () => {
logger.info(
"[MCP Transport Info] Response closed, cleaning up stateless transport.",
);
transport.close().catch((err) => {
logger.error(
"[MCP Transport Error] Error closing stateless transport on response close:",
{ error: err },
);
});
});
});

await transport.handleRequest(req, res, req.body);
await transport.handleRequest(req, res, req.body);
};

// Run the handler within request context so downstream code can
// access the database token via getDatabaseToken().
if (databaseToken) {
await requestContext.run({ databaseToken }, handlePost);
} else {
await handlePost();
}
} else if (req.method === "GET" || req.method === "DELETE") {
logger.warn(
`[MCP Transport Warn] Method Not Allowed in Stateless Mode: ${req.method}`,
Expand Down
172 changes: 171 additions & 1 deletion packages/server/src/service/connection.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@ import fs from "fs/promises";
import path from "path";
import sinon from "sinon";
import { DuckDBConnection } from "@malloydata/db-duckdb";
import { createProjectConnections, testConnectionConfig } from "./connection";
import { SnowflakeConnection } from "@malloydata/db-snowflake";
import { Connection } from "@malloydata/malloy";
import {
closeOAuthConnections,
createOAuthSnowflakeConnection,
createProjectConnections,
getOAuthConfig,
getRequestConnections,
setOAuthConfig,
SnowflakeOAuthConfig,
testConnectionConfig,
} from "./connection";
import { requestContext } from "../request_context";
import { components } from "../api";

type ApiConnection = components["schemas"]["Connection"];
Expand Down Expand Up @@ -1329,3 +1341,161 @@ describe("connection integration tests", () => {
);
});
});

describe("OAuth pass-through helpers", () => {
const baseConfig: SnowflakeOAuthConfig = {
name: "sf-conn",
account: "test-account.us-east-1",
warehouse: "COMPUTE_WH",
database: "ANALYTICS",
schema: "PUBLIC",
role: "ANALYST",
};

afterEach(() => {
sinon.restore();
});

describe("setOAuthConfig / getOAuthConfig", () => {
it("should store and retrieve OAuth config by name", () => {
setOAuthConfig("my-conn", baseConfig);
const stored = getOAuthConfig("my-conn");
expect(stored).toEqual(baseConfig);
});

it("should return undefined for unknown connection name", () => {
expect(getOAuthConfig("non-existent")).toBeUndefined();
});

it("should overwrite previous config for the same name", () => {
setOAuthConfig("my-conn", baseConfig);
const updated = { ...baseConfig, warehouse: "LARGE_WH" };
setOAuthConfig("my-conn", updated);
expect(getOAuthConfig("my-conn")?.warehouse).toBe("LARGE_WH");
});
});

describe("createOAuthSnowflakeConnection", () => {
it("should create a SnowflakeConnection instance", () => {
const conn = createOAuthSnowflakeConnection(
baseConfig,
"oauth-token-123",
);
expect(conn).toBeInstanceOf(SnowflakeConnection);
});

it("should work with minimal config (no optional fields)", () => {
const minimalConfig: SnowflakeOAuthConfig = {
name: "minimal",
account: "acct",
warehouse: "WH",
};
const conn = createOAuthSnowflakeConnection(minimalConfig, "token");
expect(conn).toBeInstanceOf(SnowflakeConnection);
});
});

describe("getRequestConnections", () => {
it("should return original map when no database token is present", () => {
// No requestContext.run() — getDatabaseToken() returns undefined
const original = new Map<string, Connection>([
["conn1", { name: "conn1" } as unknown as Connection],
]);
const result = getRequestConnections(original);

expect(result.connections).toBe(original);
expect(result.oauthConnections).toHaveLength(0);
});

it("should override connections that have OAuth config when token present", () => {
const connName = "sf-conn-override-test";
setOAuthConfig(connName, baseConfig);

const mockSfConn = {
name: connName,
} as unknown as Connection;
const mockOtherConn = {
name: "duckdb",
} as unknown as Connection;
const original = new Map<string, Connection>([
[connName, mockSfConn],
["duckdb", mockOtherConn],
]);

// Run inside requestContext so getDatabaseToken() returns a token
requestContext.run({ databaseToken: "user-token" }, () => {
const result = getRequestConnections(original);

// The overridden map should have the same keys
expect(result.connections.size).toBe(2);
// sf-conn should be replaced (different instance)
expect(result.connections.get(connName)).not.toBe(mockSfConn);
expect(result.connections.get(connName)).toBeInstanceOf(
SnowflakeConnection,
);
// duckdb should be unchanged
expect(result.connections.get("duckdb")).toBe(mockOtherConn);
// One OAuth connection created
expect(result.oauthConnections).toHaveLength(1);
});
});

it("should not override connections without OAuth config", () => {
// Use a unique connection name that has no OAuth config registered
const original = new Map<string, Connection>([
[
"sf-conn-no-config",
{ name: "sf-conn-no-config" } as unknown as Connection,
],
]);

requestContext.run({ databaseToken: "user-token" }, () => {
const result = getRequestConnections(original);

expect(result.connections.size).toBe(1);
expect(result.oauthConnections).toHaveLength(0);
});
});
});

describe("closeOAuthConnections", () => {
it("should call close() on each connection", async () => {
const conn1 = { close: sinon.stub().resolves() };
const conn2 = { close: sinon.stub().resolves() };

await closeOAuthConnections([
conn1 as unknown as Connection,
conn2 as unknown as Connection,
]);

expect(conn1.close.calledOnce).toBe(true);
expect(conn2.close.calledOnce).toBe(true);
});

it("should handle connections without a close method", async () => {
const conn = {} as unknown as Connection;
// Should not throw
await closeOAuthConnections([conn]);
});

it("should catch errors and continue closing", async () => {
const failing = {
close: sinon.stub().rejects(new Error("close failed")),
};
const succeeding = { close: sinon.stub().resolves() };

await closeOAuthConnections([
failing as unknown as Connection,
succeeding as unknown as Connection,
]);

expect(failing.close.calledOnce).toBe(true);
expect(succeeding.close.calledOnce).toBe(true);
});

it("should do nothing for an empty array", async () => {
// Should not throw
await closeOAuthConnections([]);
});
});
});
Loading
Loading