|
1 | 1 | import { asyncBufferFromFile, asyncBufferFromUrl, parquetMetadataAsync } from 'hyparquet' |
2 | 2 | import { compressors } from 'hyparquet-compressors' |
3 | | -import { collect, executeSql } from 'squirreling' |
| 3 | +import { collect, executeSql, parseSql, planSql } from 'squirreling' |
4 | 4 | import { parquetDataSource } from 'hyperparam' |
5 | 5 | import { markdownTable } from './markdownTable.js' |
6 | 6 |
|
// Maximum number of result rows rendered in the markdown table output.
const maxRows = 100
8 | 8 |
|
/**
 * Collect the table names referenced by all Scan/Count nodes in a query plan.
 * Traverses the plan with an explicit worklist instead of recursion.
 *
 * @param {import('squirreling').QueryPlan} plan
 * @returns {Set<string>}
 */
function scanTables(plan) {
  /** @type {Set<string>} */
  const found = new Set()
  /** @type {import('squirreling').QueryPlan[]} */
  const pending = [plan]
  while (pending.length > 0) {
    const node = pending.pop()
    if (!node) continue
    if (node.type === 'Scan' || node.type === 'Count') {
      found.add(node.table)
    } else if ('child' in node) {
      pending.push(node.child)
    }
    if ('left' in node) pending.push(node.left)
    if ('right' in node) pending.push(node.right)
  }
  return found
}
| 32 | + |
/**
 * Build an AsyncDataSource for a local file path or an http(s) URL.
 *
 * Remote files are fetched lazily over HTTP; local paths are read from disk.
 * In both cases the parquet footer metadata is parsed up front.
 *
 * @param {string} file
 * @returns {Promise<import('squirreling').AsyncDataSource>}
 */
async function fileToDataSource(file) {
  const isRemote = file.startsWith('http://') || file.startsWith('https://')
  const asyncBuffer = isRemote
    ? await asyncBufferFromUrl({ url: file })
    : await asyncBufferFromFile(file)
  const metadata = await parquetMetadataAsync(asyncBuffer)
  return parquetDataSource(asyncBuffer, metadata, compressors)
}
| 46 | + |
/**
 * Execute a SQL query: plan it, discover which tables it scans, load each as a
 * parquet data source, then run the query and format the result.
 *
 * @param {string} query - SQL text to execute.
 * @param {boolean} [truncate] - When true, limit each cell/table to fewer characters.
 * @returns {Promise<string>} Human-readable summary with a markdown result table.
 */
export async function runSqlQuery(query, truncate = true) {
  const startTime = performance.now()

  // Plan the query first so we know which tables it reads.
  const ast = parseSql({ query })
  const plan = planSql({ query: ast })

  // Load every referenced table in parallel as a parquet data source.
  const entries = await Promise.all(
    [...scanTables(plan)].map(async name => [name, await fileToDataSource(name)])
  )
  /** @type {Record<string, import('squirreling').AsyncDataSource>} */
  const tables = Object.fromEntries(entries)

  const results = await collect(executeSql({ tables, query }))
  const queryTime = (performance.now() - startTime) / 1000

  if (results.length === 0) {
    return `Query executed successfully but returned no results in ${queryTime.toFixed(1)} seconds.`
  }

  const rowCount = results.length
  const maxChars = truncate ? 1000 : 10000
  const parts = [
    `Query returned ${rowCount} row${rowCount === 1 ? '' : 's'} in ${queryTime.toFixed(1)} seconds.`,
    '',
    markdownTable(results.slice(0, maxRows), maxChars),
  ]
  if (rowCount > maxRows) {
    const extra = rowCount - maxRows
    parts.push('', `... and ${extra} more row${extra === 1 ? '' : 's'} (showing first ${maxRows} rows)`)
  }
  return parts.join('\n')
}
| 83 | + |
9 | 84 | /** |
10 | 85 | * @import { ToolHandler } from '../types.d.ts' |
11 | 86 | * @type {ToolHandler} |
|
0 commit comments