Skip to content

Commit ea2cdae

Browse files
authored
feat(presto/trino-driver): Support custom headers (#10902)
1 parent 6e01f98 commit ea2cdae

5 files changed

Lines changed: 105 additions & 12 deletions

File tree

packages/cubejs-prestodb-driver/src/PrestoDriver.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ export type PrestoDriverConfiguration = PrestoDriverExportBucket & {
5353
dataSource?: string;
5454
queryTimeout?: number;
5555
preAggregations?: boolean;
56+
useSelectTestConnection?: boolean;
57+
// @see https://trino.io/docs/current/develop/client-protocol.html
58+
headers?: Record<string, string>;
5659
};
5760

5861
const SUPPORTED_BUCKET_TYPES = ['gcs', 's3'];
@@ -94,7 +97,8 @@ export class PrestoDriver extends BaseDriver implements DriverInterface {
9497
throw new Error('Both user/password and auth token are set. Please remove password or token.');
9598
}
9699

97-
this.useSelectTestConnection = getEnv('dbUseSelectTestConnection', { dataSource, preAggregations });
100+
this.useSelectTestConnection = config.useSelectTestConnection ??
101+
getEnv('dbUseSelectTestConnection', { dataSource, preAggregations });
98102

99103
this.config = {
100104
host: getEnv('dbHost', { dataSource, preAggregations }),
@@ -174,6 +178,7 @@ export class PrestoDriver extends BaseDriver implements DriverInterface {
174178
this.client.execute({
175179
query,
176180
schema: this.config.schema || 'default',
181+
headers: this.config.headers,
177182
session: this.config.queryTimeout ? `query_max_run_time=${this.config.queryTimeout}s` : undefined,
178183
columns: (error: any, columns: TableStructure) => {
179184
resolve({
@@ -202,6 +207,7 @@ export class PrestoDriver extends BaseDriver implements DriverInterface {
202207
this.client.execute({
203208
query,
204209
schema: this.config.schema || 'default',
210+
headers: this.config.headers,
205211
data: (error: any, data: any[], columns: TableStructure) => {
206212
const normalData = this.normalizeResultOverColumns(data, columns);
207213
fullData = concat(normalData, fullData);
@@ -312,9 +318,7 @@ export class PrestoDriver extends BaseDriver implements DriverInterface {
312318
}
313319

314320
if (!SUPPORTED_BUCKET_TYPES.includes(this.config.bucketType as string)) {
315-
throw new Error(`Unsupported export bucket type: ${
316-
this.config.bucketType
317-
}`);
321+
throw new Error(`Unsupported export bucket type: ${this.config.bucketType}`);
318322
}
319323

320324
const types = options.query
@@ -336,7 +340,7 @@ export class PrestoDriver extends BaseDriver implements DriverInterface {
336340
return { schema, tableName };
337341
}
338342

339-
private generateTableColumnsForExport(types: {name: string, type: string}[]) {
343+
private generateTableColumnsForExport(types: { name: string, type: string }[]) {
340344
return types.map((c) => `CAST(${c.name} AS varchar) ${c.name}`).join(', ');
341345
}
342346

@@ -360,7 +364,7 @@ export class PrestoDriver extends BaseDriver implements DriverInterface {
360364
});
361365
}
362366

363-
private async unloadGeneric(params: {tableFullName: string, typeSql: string, typeParams: any[], fromSql: string, fromParams: any[]}) {
367+
private async unloadGeneric(params: { tableFullName: string, typeSql: string, typeParams: any[], fromSql: string, fromParams: any[] }) {
364368
if (!this.config.exportBucket) {
365369
throw new Error('Export bucket is not configured.');
366370
}

packages/cubejs-trino-driver/package.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
"build": "rm -rf dist && npm run tsc",
2222
"tsc": "tsc",
2323
"watch": "tsc -w",
24-
"integration": "jest dist/test",
25-
"integration:trino": "jest dist/test",
24+
"unit": "jest dist/test/unit",
25+
"integration": "jest dist/test/integration",
26+
"integration:trino": "jest dist/test/integration",
2627
"lint": "eslint src/* --ext .ts",
2728
"lint:fix": "eslint --fix src/* --ext .ts"
2829
},

packages/cubejs-trino-driver/src/TrinoDriver.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ export class TrinoDriver extends PrestoDriver {
1717
return this.testConnectionViaSelect();
1818
}
1919

20-
const { host, port, ssl, basic_auth: basicAuth, custom_auth: customAuth } = this.config;
20+
const { host, port, ssl, basic_auth: basicAuth, custom_auth: customAuth, headers: extraHeaders } = this.config;
2121
const protocol = ssl ? 'https' : 'http';
2222
const url = `${protocol}://${host}:${port}/v1/info`;
23-
const headers: Record<string, string> = {};
23+
const headers: Record<string, string> = { ...extraHeaders };
2424

2525
if (customAuth) {
2626
headers.Authorization = customAuth;

packages/cubejs-trino-driver/test/trino-driver.test.ts renamed to packages/cubejs-trino-driver/test/integration/trino-driver.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { TrinoDriver } from '../src/TrinoDriver';
1+
import { TrinoDriver } from '../../src/TrinoDriver';
22

33
const path = require('path');
44
const { DockerComposeEnvironment, Wait } = require('testcontainers');
@@ -37,7 +37,7 @@ describe('TrinoDriver', () => {
3737
}
3838

3939
const dc = new DockerComposeEnvironment(
40-
path.resolve(path.dirname(__filename), '../../'),
40+
path.resolve(path.dirname(__filename), '../../../'),
4141
'docker-compose.yml'
4242
);
4343

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import { TrinoDriver } from '../../src/TrinoDriver';
2+
3+
const mockFetch: jest.Mock = jest.fn();
4+
const mockExecute: jest.Mock = jest.fn();
5+
6+
jest.mock('node-fetch', () => ({
7+
__esModule: true,
8+
default: (...args: any[]) => mockFetch(...args),
9+
}));
10+
11+
jest.mock('@cubejs-backend/schema-compiler', () => ({
12+
PrestodbQuery: class { },
13+
}));
14+
15+
jest.mock('presto-client', () => ({
16+
Client: jest.fn().mockImplementation(() => ({
17+
execute: (...args: any[]) => mockExecute(...args),
18+
nodes: jest.fn(),
19+
})),
20+
}));
21+
22+
describe('TrinoDriver headers', () => {
23+
beforeEach(() => {
24+
mockFetch.mockReset();
25+
mockFetch.mockResolvedValue({
26+
ok: true,
27+
status: 200,
28+
statusText: 'OK',
29+
text: async () => '',
30+
});
31+
mockExecute.mockReset();
32+
// Default: synthesize a successful query result with no rows.
33+
mockExecute.mockImplementation((opts: any) => {
34+
opts.success?.();
35+
});
36+
});
37+
38+
it('forwards configured custom headers on testConnection()', async () => {
39+
const driver = new TrinoDriver({
40+
host: 'trino.local',
41+
port: '8080',
42+
// See https://trino.io/docs/current/develop/client-protocol.html for
43+
// the upstream list of `X-Trino-*` headers accepted by the coordinator.
44+
headers: {
45+
'X-Trino-Source': 'cube',
46+
'X-Trino-Routing-Group': 'etl',
47+
'X-Trino-Client-Tags': 'user=alice@example.com',
48+
'X-Mozart-User-Token': 'abc.def.ghi',
49+
},
50+
});
51+
52+
await driver.testConnection();
53+
54+
expect(mockFetch).toHaveBeenCalledTimes(1);
55+
const [url, options] = mockFetch.mock.calls[0];
56+
expect(url).toBe('http://trino.local:8080/v1/info');
57+
expect(options.method).toBe('GET');
58+
expect(options.headers).toMatchObject({
59+
'X-Trino-Source': 'cube',
60+
'X-Trino-Routing-Group': 'etl',
61+
'X-Trino-Client-Tags': 'user=alice@example.com',
62+
'X-Mozart-User-Token': 'abc.def.ghi',
63+
});
64+
});
65+
66+
it('forwards configured custom headers when useSelectTestConnection is enabled', async () => {
67+
const driver = new TrinoDriver({
68+
host: 'trino.local',
69+
port: '8080',
70+
useSelectTestConnection: true,
71+
headers: {
72+
'X-Trino-Source': 'cube',
73+
'X-Trino-Routing-Group': 'etl',
74+
},
75+
});
76+
77+
await driver.testConnection();
78+
79+
expect(mockFetch).not.toHaveBeenCalled();
80+
expect(mockExecute).toHaveBeenCalledTimes(1);
81+
const [executeOpts] = mockExecute.mock.calls[0];
82+
expect(executeOpts.query).toBe('SELECT 1');
83+
expect(executeOpts.headers).toEqual({
84+
'X-Trino-Source': 'cube',
85+
'X-Trino-Routing-Group': 'etl',
86+
});
87+
});
88+
});

0 commit comments

Comments
 (0)