diff --git a/api-doc.yaml b/api-doc.yaml index 35890498..fa3939ec 100644 --- a/api-doc.yaml +++ b/api-doc.yaml @@ -2342,6 +2342,15 @@ components: responseTimeoutMilliseconds: type: integer description: Query response timeout in milliseconds + oauthPassthrough: + type: boolean + description: >- + When true, the connection accepts per-request OAuth tokens via + the X-Database-Token HTTP header. Queries run as the authenticated + user's Snowflake role instead of using static credentials. The + static credentials (password/privateKey) are used as a fallback + when no token is provided. + default: false TrinoConnection: type: object diff --git a/packages/server/src/request_context.spec.ts b/packages/server/src/request_context.spec.ts new file mode 100644 index 00000000..f745f597 --- /dev/null +++ b/packages/server/src/request_context.spec.ts @@ -0,0 +1,43 @@ +import { describe, expect, it } from "bun:test"; +import { getDatabaseToken, requestContext } from "./request_context"; + +describe("request_context", () => { + describe("getDatabaseToken", () => { + it("should return undefined when no context is active", () => { + expect(getDatabaseToken()).toBeUndefined(); + }); + + it("should return the token from the active context", async () => { + let captured: string | undefined; + await requestContext.run({ databaseToken: "test-token" }, () => { + captured = getDatabaseToken(); + }); + expect(captured).toBe("test-token"); + }); + + it("should return undefined when context has no token", async () => { + let captured: string | undefined = "should-be-overwritten"; + await requestContext.run({}, () => { + captured = getDatabaseToken(); + }); + expect(captured).toBeUndefined(); + }); + + it("should isolate tokens across concurrent contexts", async () => { + const results: (string | undefined)[] = []; + await Promise.all([ + requestContext.run({ databaseToken: "token-a" }, async () => { + // Simulate async work + await new Promise((r) => setTimeout(r, 10)); + results.push(getDatabaseToken()); + }), + requestContext.run({ databaseToken: "token-b" }, async () => { + results.push(getDatabaseToken()); + }), + ]); + expect(results).toContain("token-a"); + expect(results).toContain("token-b"); + expect(results).toHaveLength(2); + }); + }); +}); diff --git a/packages/server/src/request_context.ts b/packages/server/src/request_context.ts new file mode 100644 index 00000000..8f5350ba --- /dev/null +++ b/packages/server/src/request_context.ts @@ -0,0 +1,21 @@ +/** + * Per-request context using AsyncLocalStorage. + * + * Carries request-scoped data (e.g., OAuth tokens) through the call stack + * without threading parameters through every function. Used by the MCP + * endpoint to propagate database tokens for per-user connection creation. + */ + +import { AsyncLocalStorage } from "node:async_hooks"; + +export interface RequestContext { + /** OAuth token for per-request database authentication. */ + databaseToken?: string; +} + +export const requestContext = new AsyncLocalStorage(); + +/** Returns the database token from the current request context, if any. */ +export function getDatabaseToken(): string | undefined { + return requestContext.getStore()?.databaseToken; +} diff --git a/packages/server/src/server.ts b/packages/server/src/server.ts index 72a862f7..01a9ce1c 100644 --- a/packages/server/src/server.ts +++ b/packages/server/src/server.ts @@ -29,6 +29,7 @@ import { import { logger, loggerMiddleware } from "./logger"; import { initializeMcpServer } from "./mcp/server"; +import { requestContext } from "./request_context"; import { ProjectStore } from "./service/project_store"; // Parse command line arguments function parseArgs() { @@ -142,37 +143,55 @@ mcpApp.all(MCP_ENDPOINT, async (req, res) => { try { if (req.method === "POST") { - const transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: undefined, - }); - - transport.onclose = () => { - logger.info( - `[MCP Transport Info] Stateless transport closed for a request.`, - ); - }; - transport.onerror = (err: Error) => { - logger.error(`[MCP Transport Error] Stateless transport error:`, { - error: err, + // Extract database token from request header for per-user OAuth pass-through. + // When present, downstream connections use this token instead of static credentials. + const databaseToken = + (req.headers["x-database-token"] as string) || undefined; + + const handlePost = async () => { + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: undefined, }); - }; - const requestMcpServer = initializeMcpServer(projectStore); - await requestMcpServer.connect(transport); - - res.on("close", () => { - logger.info( - "[MCP Transport Info] Response closed, cleaning up stateless transport.", - ); - transport.close().catch((err) => { + transport.onclose = () => { + logger.info( + `[MCP Transport Info] Stateless transport closed for a request.`, + ); + }; + transport.onerror = (err: Error) => { logger.error( - "[MCP Transport Error] Error closing stateless transport on response close:", - { error: err }, + `[MCP Transport Error] Stateless transport error:`, + { + error: err, + }, ); + }; + + const requestMcpServer = initializeMcpServer(projectStore); + await requestMcpServer.connect(transport); + + res.on("close", () => { + logger.info( + "[MCP Transport Info] Response closed, cleaning up stateless transport.", + ); + transport.close().catch((err) => { + logger.error( + "[MCP Transport Error] Error closing stateless transport on response close:", + { error: err }, + ); + }); }); - }); - await transport.handleRequest(req, res, req.body); + await transport.handleRequest(req, res, req.body); + }; + + // Run the handler within request context so downstream code can + // access the database token via getDatabaseToken(). + if (databaseToken) { + await requestContext.run({ databaseToken }, handlePost); + } else { + await handlePost(); + } } else if (req.method === "GET" || req.method === "DELETE") { logger.warn( `[MCP Transport Warn] Method Not Allowed in Stateless Mode: ${req.method}`, diff --git a/packages/server/src/service/connection.spec.ts b/packages/server/src/service/connection.spec.ts index 38ad6116..d80545af 100644 --- a/packages/server/src/service/connection.spec.ts +++ b/packages/server/src/service/connection.spec.ts @@ -3,7 +3,19 @@ import fs from "fs/promises"; import path from "path"; import sinon from "sinon"; import { DuckDBConnection } from "@malloydata/db-duckdb"; -import { createProjectConnections, testConnectionConfig } from "./connection"; +import { SnowflakeConnection } from "@malloydata/db-snowflake"; +import { Connection } from "@malloydata/malloy"; +import { + closeOAuthConnections, + createOAuthSnowflakeConnection, + createProjectConnections, + getOAuthConfig, + getRequestConnections, + setOAuthConfig, + SnowflakeOAuthConfig, + testConnectionConfig, +} from "./connection"; +import { requestContext } from "../request_context"; import { components } from "../api"; type ApiConnection = components["schemas"]["Connection"]; @@ -1329,3 +1341,161 @@ describe("connection integration tests", () => { ); }); }); + +describe("OAuth pass-through helpers", () => { + const baseConfig: SnowflakeOAuthConfig = { + name: "sf-conn", + account: "test-account.us-east-1", + warehouse: "COMPUTE_WH", + database: "ANALYTICS", + schema: "PUBLIC", + role: "ANALYST", + }; + + afterEach(() => { + sinon.restore(); + }); + + describe("setOAuthConfig / getOAuthConfig", () => { + it("should store and retrieve OAuth config by name", () => { + setOAuthConfig("my-conn", baseConfig); + const stored = getOAuthConfig("my-conn"); + expect(stored).toEqual(baseConfig); + }); + + it("should return undefined for unknown connection name", () => { + expect(getOAuthConfig("non-existent")).toBeUndefined(); + }); + + it("should overwrite previous config for the same name", () => { + setOAuthConfig("my-conn", baseConfig); + const updated = { ...baseConfig, warehouse: "LARGE_WH" }; + setOAuthConfig("my-conn", updated); + expect(getOAuthConfig("my-conn")?.warehouse).toBe("LARGE_WH"); + }); + }); + + describe("createOAuthSnowflakeConnection", () => { + it("should create a SnowflakeConnection instance", () => { + const conn = createOAuthSnowflakeConnection( + baseConfig, + "oauth-token-123", + ); + expect(conn).toBeInstanceOf(SnowflakeConnection); + }); + + it("should work with minimal config (no optional fields)", () => { + const minimalConfig: SnowflakeOAuthConfig = { + name: "minimal", + account: "acct", + warehouse: "WH", + }; + const conn = createOAuthSnowflakeConnection(minimalConfig, "token"); + expect(conn).toBeInstanceOf(SnowflakeConnection); + }); + }); + + describe("getRequestConnections", () => { + it("should return original map when no database token is present", () => { + // No requestContext.run() — getDatabaseToken() returns undefined + const original = new Map([ + ["conn1", { name: "conn1" } as unknown as Connection], + ]); + const result = getRequestConnections(original); + + expect(result.connections).toBe(original); + expect(result.oauthConnections).toHaveLength(0); + }); + + it("should override connections that have OAuth config when token present", () => { + const connName = "sf-conn-override-test"; + setOAuthConfig(connName, baseConfig); + + const mockSfConn = { + name: connName, + } as unknown as Connection; + const mockOtherConn = { + name: "duckdb", + } as unknown as Connection; + const original = new Map([ + [connName, mockSfConn], + ["duckdb", mockOtherConn], + ]); + + // Run inside requestContext so getDatabaseToken() returns a token + requestContext.run({ databaseToken: "user-token" }, () => { + const result = getRequestConnections(original); + + // The overridden map should have the same keys + expect(result.connections.size).toBe(2); + // sf-conn should be replaced (different instance) + expect(result.connections.get(connName)).not.toBe(mockSfConn); + expect(result.connections.get(connName)).toBeInstanceOf( + SnowflakeConnection, + ); + // duckdb should be unchanged + expect(result.connections.get("duckdb")).toBe(mockOtherConn); + // One OAuth connection created + expect(result.oauthConnections).toHaveLength(1); + }); + }); + + it("should not override connections without OAuth config", () => { + // Use a unique connection name that has no OAuth config registered + const original = new Map([ + [ + "sf-conn-no-config", + { name: "sf-conn-no-config" } as unknown as Connection, + ], + ]); + + requestContext.run({ databaseToken: "user-token" }, () => { + const result = getRequestConnections(original); + + expect(result.connections.size).toBe(1); + expect(result.oauthConnections).toHaveLength(0); + }); + }); + }); + + describe("closeOAuthConnections", () => { + it("should call close() on each connection", async () => { + const conn1 = { close: sinon.stub().resolves() }; + const conn2 = { close: sinon.stub().resolves() }; + + await closeOAuthConnections([ + conn1 as unknown as Connection, + conn2 as unknown as Connection, + ]); + + expect(conn1.close.calledOnce).toBe(true); + expect(conn2.close.calledOnce).toBe(true); + }); + + it("should handle connections without a close method", async () => { + const conn = {} as unknown as Connection; + // Should not throw + await closeOAuthConnections([conn]); + }); + + it("should catch errors and continue closing", async () => { + const failing = { + close: sinon.stub().rejects(new Error("close failed")), + }; + const succeeding = { close: sinon.stub().resolves() }; + + await closeOAuthConnections([ + failing as unknown as Connection, + succeeding as unknown as Connection, + ]); + + expect(failing.close.calledOnce).toBe(true); + expect(succeeding.close.calledOnce).toBe(true); + }); + + it("should do nothing for an empty array", async () => { + // Should not throw + await closeOAuthConnections([]); + }); + }); +}); diff --git a/packages/server/src/service/connection.ts b/packages/server/src/service/connection.ts index da32ee39..a321885b 100644 --- a/packages/server/src/service/connection.ts +++ b/packages/server/src/service/connection.ts @@ -13,8 +13,116 @@ import { v4 as uuidv4 } from "uuid"; import { components } from "../api"; import { TEMP_DIR_PATH } from "../constants"; import { logAxiosError, logger } from "../logger"; +import { getDatabaseToken } from "../request_context"; import { CloudStorageCredentials } from "./gcs_s3_utils"; +/** Configuration needed to create per-request OAuth Snowflake connections. */ +export interface SnowflakeOAuthConfig { + name: string; + account: string; + warehouse: string; + database?: string; + schema?: string; + role?: string; + responseTimeoutMilliseconds?: number; +} + +/** + * Creates a per-request SnowflakeConnection using an OAuth token. + * The caller is responsible for closing the connection after use. + */ +export function createOAuthSnowflakeConnection( + config: SnowflakeOAuthConfig, + token: string, +): SnowflakeConnection { + return new SnowflakeConnection(config.name, { + connOptions: { + account: config.account, + authenticator: "OAUTH", + token, + warehouse: config.warehouse, + database: config.database, + schema: config.schema, + role: config.role, + timeout: config.responseTimeoutMilliseconds, + }, + }); +} + +/** + * Metadata attached to Snowflake connections in the connection map when + * oauthPassthrough is enabled. Allows downstream code to create per-request + * OAuth connections using the same base config. + */ +const oauthConfigMap = new Map(); + +/** Store OAuth config for a named connection. */ +export function setOAuthConfig( + connectionName: string, + config: SnowflakeOAuthConfig, +): void { + oauthConfigMap.set(connectionName, config); +} + +/** Get OAuth config for a named connection, if oauthPassthrough is enabled. */ +export function getOAuthConfig( + connectionName: string, +): SnowflakeOAuthConfig | undefined { + return oauthConfigMap.get(connectionName); +} + +/** + * If the current request has a database token and any connection in the map + * has oauthPassthrough enabled, returns a new connection map with per-request + * OAuth Snowflake connections. Otherwise returns the original map unchanged. + * + * The caller must call closeOAuthConnections() on the returned connections + * after the query completes. + */ +export function getRequestConnections(connections: Map): { + connections: Map; + oauthConnections: Connection[]; +} { + const token = getDatabaseToken(); + if (!token) { + return { connections, oauthConnections: [] }; + } + + const overridden = new Map(connections); + const oauthConnections: Connection[] = []; + + for (const [name] of connections) { + const config = getOAuthConfig(name); + if (config) { + const oauthConn = createOAuthSnowflakeConnection(config, token); + overridden.set(name, oauthConn); + oauthConnections.push(oauthConn); + logger.info( + `[OAuth] Created per-request Snowflake connection for '${name}'`, + ); + } + } + + return { connections: overridden, oauthConnections }; +} + +/** Close per-request OAuth connections to prevent session leaks. */ +export async function closeOAuthConnections( + oauthConnections: Connection[], +): Promise { + for (const conn of oauthConnections) { + try { + if ("close" in conn && typeof conn.close === "function") { + await conn.close(); + } + } catch (err) { + logger.warn(`[OAuth] Error closing per-request connection`, { + error: err, + }); + } + } +} + type AttachedDatabase = components["schemas"]["AttachedDatabase"]; type ApiConnection = components["schemas"]["Connection"]; type ApiConnectionAttributes = components["schemas"]["ConnectionAttributes"]; @@ -1054,6 +1162,23 @@ export async function createProjectConnections( throw new Error("Snowflake warehouse is required."); } + // Store OAuth config for per-request connection creation + if (connection.snowflakeConnection.oauthPassthrough) { + setOAuthConfig(connection.name, { + name: connection.name, + account: connection.snowflakeConnection.account, + warehouse: connection.snowflakeConnection.warehouse, + database: connection.snowflakeConnection.database, + schema: connection.snowflakeConnection.schema, + role: connection.snowflakeConnection.role, + responseTimeoutMilliseconds: + connection.snowflakeConnection.responseTimeoutMilliseconds, + }); + logger.info( + `[OAuth] Snowflake connection '${connection.name}' has oauthPassthrough enabled`, + ); + } + let privateKeyPath = undefined; if (connection.snowflakeConnection.privateKey) { diff --git a/packages/server/src/service/model.spec.ts b/packages/server/src/service/model.spec.ts index 6b17eb04..1ef07d47 100644 --- a/packages/server/src/service/model.spec.ts +++ b/packages/server/src/service/model.spec.ts @@ -1,9 +1,10 @@ import { MalloyError, Runtime } from "@malloydata/malloy"; -import { describe, expect, it } from "bun:test"; +import { afterEach, describe, expect, it } from "bun:test"; import fs from "fs/promises"; import sinon from "sinon"; import { BadRequestError, ModelNotFoundError } from "../errors"; +// requestContext import not needed — tests exercise the no-token (default) path import { Model, ModelType } from "./model"; describe("service/model", () => { @@ -12,6 +13,10 @@ describe("service/model", () => { const mockPackagePath = "mockPackagePath"; const mockModelPath = "mockModel.malloy"; + afterEach(() => { + sinon.restore(); + }); + it("should create a Model instance", async () => { sinon.stub(Model, "getModelRuntime").resolves({ runtime: sinon.createStubInstance(Runtime), @@ -34,8 +39,6 @@ describe("service/model", () => { ); expect(model).toBeInstanceOf(Model); expect(model.getPath()).toBe(mockModelPath); - - sinon.restore(); }); it("should handle ModelNotFoundError correctly", async () => { @@ -47,8 +50,6 @@ describe("service/model", () => { new Map(), ); }).toThrowError(`${mockModelPath} does not exist.`); - - sinon.restore(); }); describe("instance methods", () => { @@ -56,6 +57,7 @@ describe("service/model", () => { it("should return the correct modelPath", async () => { const model = new Model( packageName, + mockPackagePath, mockModelPath, {}, "model", @@ -69,8 +71,6 @@ describe("service/model", () => { ); expect(model.getPath()).toBe(mockModelPath); - - sinon.restore(); }); }); @@ -79,6 +79,7 @@ describe("service/model", () => { const modelType = "model"; const model = new Model( packageName, + mockPackagePath, mockModelPath, {}, modelType, @@ -92,8 +93,6 @@ describe("service/model", () => { ); expect(model.getType()).toBe(modelType); - - sinon.restore(); }); }); @@ -101,6 +100,7 @@ describe("service/model", () => { it("should throw ModelCompilationError if a compilation error exists", async () => { const model = new Model( packageName, + mockPackagePath, mockModelPath, {}, "model", @@ -116,13 +116,12 @@ describe("service/model", () => { await expect(async () => { await model.getModel(); }).toThrowError(MalloyError); - - sinon.restore(); }); it("should throw ModelNotFoundError for invalid modelType", async () => { const model = new Model( packageName, + mockPackagePath, mockModelPath, {}, "notebook" as ModelType, @@ -138,8 +137,6 @@ describe("service/model", () => { await expect(async () => { await model.getModel(); }).toThrowError(ModelNotFoundError); - - sinon.restore(); }); }); @@ -147,6 +144,7 @@ describe("service/model", () => { it("should throw ModelCompilationError if a compilation error exists", async () => { const model = new Model( packageName, + mockPackagePath, mockModelPath, {}, "notebook", @@ -162,13 +160,12 @@ describe("service/model", () => { await expect(async () => { await model.getNotebook(); }).toThrowError(Error); - - sinon.restore(); }); it("should throw ModelNotFoundError for invalid modelType", async () => { const model = new Model( packageName, + mockPackagePath, mockModelPath, {}, "model" as ModelType, @@ -184,8 +181,6 @@ describe("service/model", () => { await expect(async () => { await model.getNotebook(); }).toThrowError(ModelNotFoundError); - - sinon.restore(); }); }); @@ -194,6 +189,7 @@ describe("service/model", () => { const error = new Error("Compilation error"); const model = new Model( packageName, + mockPackagePath, mockModelPath, {}, "model", @@ -209,13 +205,12 @@ describe("service/model", () => { await expect(async () => { await model.getQueryResults(); }).toThrowError(BadRequestError); - - sinon.restore(); }); it("should throw BadRequestError if no queryable entities exist", async () => { const model = new Model( packageName, + mockPackagePath, mockModelPath, {}, "model", @@ -231,8 +226,29 @@ describe("service/model", () => { await expect(async () => { await model.getQueryResults(); }).toThrowError(BadRequestError); + }); - sinon.restore(); + it("should not create OAuth runtime when no token present", async () => { + // No requestContext.run() — getDatabaseToken() returns undefined + // so the OAuth path is skipped entirely. + const model = new Model( + packageName, + mockPackagePath, + mockModelPath, + {}, + "model", + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + undefined, + ); + + await expect(async () => { + await model.getQueryResults(); + }).toThrowError(BadRequestError); }); }); }); @@ -249,8 +265,6 @@ describe("service/model", () => { new Map(), ); }).toThrowError(ModelNotFoundError); - - sinon.restore(); }); }); }); diff --git a/packages/server/src/service/model.ts b/packages/server/src/service/model.ts index 992265a1..b9bf6543 100644 --- a/packages/server/src/service/model.ts +++ b/packages/server/src/service/model.ts @@ -28,6 +28,8 @@ import { metrics } from "@opentelemetry/api"; import * as fs from "fs/promises"; import * as path from "path"; import { fileURLToPath } from "url"; +import { closeOAuthConnections, getRequestConnections } from "./connection"; +import { getDatabaseToken } from "../request_context"; import { components } from "../api"; import { MODEL_FILE_SUFFIX, @@ -70,6 +72,7 @@ interface RunnableNotebookCell { export class Model { private packageName: string; + private packagePath: string; private modelPath: string; private dataStyles: DataStyles; private modelType: ModelType; @@ -81,6 +84,8 @@ export class Model { private sourceInfos: Malloy.SourceInfo[] | undefined; private runnableNotebookCells: RunnableNotebookCell[] | undefined; private compilationError: MalloyError | Error | undefined; + /** Retained for per-request OAuth connection override. */ + private connections: Map | undefined; private meter = metrics.getMeter("publisher"); private queryExecutionHistogram = this.meter.createHistogram( "malloy_model_query_duration", @@ -92,6 +97,7 @@ export class Model { constructor( packageName: string, + packagePath: string, modelPath: string, dataStyles: DataStyles, modelType: ModelType, @@ -103,8 +109,10 @@ export class Model { sourceInfos: Malloy.SourceInfo[] | undefined, runnableNotebookCells: RunnableNotebookCell[] | undefined, compilationError: MalloyError | Error | undefined, + connections?: Map, ) { this.packageName = packageName; + this.packagePath = packagePath; this.modelPath = modelPath; this.dataStyles = dataStyles; this.modelType = modelType; @@ -115,6 +123,7 @@ export class Model { this.sourceInfos = sourceInfos; this.runnableNotebookCells = runnableNotebookCells; this.compilationError = compilationError; + this.connections = connections; this.modelInfo = this.modelDef ? modelDefToModelInfo(this.modelDef) : undefined; @@ -197,6 +206,7 @@ export class Model { return new Model( packageName, + packagePath, modelPath, dataStyles, modelType, @@ -207,6 +217,7 @@ export class Model { sourceInfos.length > 0 ? sourceInfos : undefined, runnableNotebookCells, undefined, + connections, ); } catch (error) { let computedError = error; @@ -223,6 +234,7 @@ export class Model { } return new Model( packageName, + packagePath, modelPath, dataStyles, modelType, @@ -312,112 +324,150 @@ export class Model { `Model compilation failed: ${this.compilationError.message}`, ); } - let runnable: QueryMaterializer; if (!this.modelMaterializer || !this.modelDef || !this.modelInfo) throw new BadRequestError("Model has no queryable entities."); - // Wrap loadQuery calls in try-catch to handle query parsing errors - try { - if (!sourceName && !queryName && query) { - runnable = this.modelMaterializer.loadQuery("\n" + query); - } else if (queryName && !query) { - runnable = this.modelMaterializer.loadQuery( - `\nrun: ${sourceName ? sourceName + "->" : ""}${queryName}`, + // When an OAuth token is present in the request context, create a new + // Runtime with per-request Snowflake connections so the query executes + // under the authenticated user's role. + let activeMaterializer = this.modelMaterializer; + let oauthConnections: Connection[] = []; + + const databaseToken = getDatabaseToken(); + if (databaseToken && this.connections) { + const result = getRequestConnections(this.connections); + if (result.oauthConnections.length > 0) { + oauthConnections = result.oauthConnections; + const oauthRuntime = new Runtime({ + urlReader: URL_READER, + connections: new FixedConnectionMap( + result.connections, + "duckdb", + ), + }); + // Re-materialize the model with OAuth connections using the full + // filesystem path so the Runtime can locate the .malloy file. + const fullModelPath = path.join(this.packagePath, this.modelPath); + activeMaterializer = oauthRuntime.loadModel( + new URL(`file://${fullModelPath}`), ); - } else { - const endTime = performance.now(); - const executionTime = endTime - startTime; - this.queryExecutionHistogram.record(executionTime, { + } + } + + try { + let runnable: QueryMaterializer; + + // Wrap loadQuery calls in try-catch to handle query parsing errors + try { + if (!sourceName && !queryName && query) { + runnable = activeMaterializer.loadQuery("\n" + query); + } else if (queryName && !query) { + runnable = activeMaterializer.loadQuery( + `\nrun: ${sourceName ? sourceName + "->" : ""}${queryName}`, + ); + } else { + const endTime = performance.now(); + const executionTime = endTime - startTime; + this.queryExecutionHistogram.record(executionTime, { + "malloy.model.path": this.modelPath, + "malloy.model.query.name": queryName, + "malloy.model.query.source": sourceName, + "malloy.model.query.query": query, + "malloy.model.query.status": "error", + }); + throw new BadRequestError( + "Invalid query request. (Query AND !sourceName) OR (queryName AND sourceName) must be defined.", + ); + } + } catch (error) { + // Re-throw BadRequestError as-is + if (error instanceof BadRequestError) { + throw error; + } + // Re-throw MalloyError as-is (maps to 400) + if (error instanceof MalloyError) { + throw error; + } + // For other query parsing errors, wrap as BadRequestError + const errorMessage = + error instanceof Error ? error.message : String(error); + logger.error("Query parsing error", { + error, + errorMessage, + projectName: this.packageName, + modelPath: this.modelPath, + query, + queryName, + sourceName, + }); + throw new BadRequestError(`Invalid query: ${errorMessage}`); + } + + const rowLimit = + (await runnable.getPreparedResult()).resultExplore.limit || + ROW_LIMIT; + const endTime = performance.now(); + const executionTime = endTime - startTime; + + let queryResults; + try { + queryResults = await runnable.run({ rowLimit }); + } catch (error) { + // Record error metrics + const errorEndTime = performance.now(); + const errorExecutionTime = errorEndTime - startTime; + this.queryExecutionHistogram.record(errorExecutionTime, { "malloy.model.path": this.modelPath, "malloy.model.query.name": queryName, "malloy.model.query.source": sourceName, "malloy.model.query.query": query, "malloy.model.query.status": "error", }); + + // Re-throw Malloy errors as-is (they will be handled by error handler) + if (error instanceof MalloyError) { + throw error; + } + + // For other runtime errors (like divide by zero), throw as BadRequestError + const errorMessage = + error instanceof Error ? error.message : String(error); + logger.error("Query execution error", { + error, + errorMessage, + projectName: this.packageName, + modelPath: this.modelPath, + query, + queryName, + sourceName, + }); throw new BadRequestError( - "Invalid query request. (Query AND !sourceName) OR (queryName AND sourceName) must be defined.", + `Query execution failed: ${errorMessage}`, ); } - } catch (error) { - // Re-throw BadRequestError as-is - if (error instanceof BadRequestError) { - throw error; - } - // Re-throw MalloyError as-is (maps to 400) - if (error instanceof MalloyError) { - throw error; - } - // For other query parsing errors, wrap as BadRequestError - const errorMessage = - error instanceof Error ? error.message : String(error); - logger.error("Query parsing error", { - error, - errorMessage, - projectName: this.packageName, - modelPath: this.modelPath, - query, - queryName, - sourceName, - }); - throw new BadRequestError(`Invalid query: ${errorMessage}`); - } - const rowLimit = - (await runnable.getPreparedResult()).resultExplore.limit || ROW_LIMIT; - const endTime = performance.now(); - const executionTime = endTime - startTime; - - let queryResults; - try { - queryResults = await runnable.run({ rowLimit }); - } catch (error) { - // Record error metrics - const errorEndTime = performance.now(); - const errorExecutionTime = errorEndTime - startTime; - this.queryExecutionHistogram.record(errorExecutionTime, { + this.queryExecutionHistogram.record(executionTime, { "malloy.model.path": this.modelPath, "malloy.model.query.name": queryName, "malloy.model.query.source": sourceName, "malloy.model.query.query": query, - "malloy.model.query.status": "error", + "malloy.model.query.rows_limit": rowLimit, + "malloy.model.query.rows_total": queryResults.totalRows, + "malloy.model.query.connection": queryResults.connectionName, + "malloy.model.query.status": "success", }); - - // Re-throw Malloy errors as-is (they will be handled by error handler) - if (error instanceof MalloyError) { - throw error; + return { + result: API.util.wrapResult(queryResults), + compactResult: queryResults.data.value, + modelInfo: this.modelInfo, + dataStyles: this.dataStyles, + }; + } finally { + // Close per-request OAuth connections to prevent Snowflake session leaks. + if (oauthConnections.length > 0) { + await closeOAuthConnections(oauthConnections); } - - // For other runtime errors (like divide by zero), throw as BadRequestError - const errorMessage = - error instanceof Error ? error.message : String(error); - logger.error("Query execution error", { - error, - errorMessage, - projectName: this.packageName, - modelPath: this.modelPath, - query, - queryName, - sourceName, - }); - throw new BadRequestError(`Query execution failed: ${errorMessage}`); } - - this.queryExecutionHistogram.record(executionTime, { - "malloy.model.path": this.modelPath, - "malloy.model.query.name": queryName, - "malloy.model.query.source": sourceName, - "malloy.model.query.query": query, - "malloy.model.query.rows_limit": rowLimit, - "malloy.model.query.rows_total": queryResults.totalRows, - "malloy.model.query.connection": queryResults.connectionName, - "malloy.model.query.status": "success", - }); - return { - result: API.util.wrapResult(queryResults), - compactResult: queryResults.data.value, - modelInfo: this.modelInfo, - dataStyles: this.dataStyles, - }; } private getStandardModel(): ApiCompiledModel {