Skip to content

Commit b4b3be5

Browse files
yuchen-pipiYu Chenclaude
authored
feat: add SQL dialect support with Spark SQL dialect (lance-format#150)
## Summary - Add `SqlDialect` enum (`Default`, `Spark`, `PostgreSql`, `MySql`, `Sqlite`) and `SparkDialect` implementation using DataFusion's unparser `Dialect` trait - Refactor `to_sql()` to accept an optional `dialect` parameter instead of a separate method per dialect - Add Python API support: `query.to_sql(datasets, dialect="spark")` ### Spark SQL dialect differences - Backtick identifier quoting - `STRING` type instead of `VARCHAR` - `EXTRACT(field FROM expr)` for date parts - `LENGTH()` instead of `CHARACTER_LENGTH()` - `TIMESTAMP` without timezone info - Subqueries in FROM require aliases ### Usage **Rust:** ```rust use lance_graph::{CypherQuery, SqlDialect}; let sql = query.to_sql(datasets, Some(SqlDialect::Spark)).await?; ``` **Python:** ```python sql = query.to_sql(datasets, dialect="spark") ``` ## Test plan - [x] 7 new Spark SQL integration tests (backtick quoting, filters, relationships, complex queries, dialect comparison, PostgreSQL dialect) - [x] 5 unit tests for SparkDialect trait implementation - [x] 12 existing `to_sql` tests updated and passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Yu Chen <yu.chen@databricks.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f237088 commit b4b3be5

7 files changed

Lines changed: 532 additions & 26 deletions

File tree

crates/lance-graph-python/src/graph.rs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ use lance_graph::{
2727
ast::{DistanceMetric as RustDistanceMetric, GraphPattern, ReadingClause},
2828
CypherQuery as RustCypherQuery, ExecutionStrategy as RustExecutionStrategy,
2929
GraphConfig as RustGraphConfig, GraphError as RustGraphError, InMemoryCatalog,
30-
SqlQuery as RustSqlQuery, VectorSearch as RustVectorSearch,
30+
SqlDialect as RustSqlDialect, SqlQuery as RustSqlQuery,
31+
VectorSearch as RustVectorSearch,
3132
};
3233
use pyo3::{
3334
exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError},
@@ -59,6 +60,34 @@ impl From<ExecutionStrategy> for RustExecutionStrategy {
5960
}
6061
}
6162

63+
/// SQL dialect for generating SQL from Cypher queries
64+
#[pyclass(name = "SqlDialect", module = "lance.graph")]
65+
#[derive(Clone, Copy)]
66+
pub enum SqlDialect {
67+
/// Generic SQL (DataFusion default dialect)
68+
Default,
69+
/// Spark SQL dialect
70+
Spark,
71+
/// PostgreSQL dialect
72+
PostgreSql,
73+
/// MySQL dialect
74+
MySql,
75+
/// SQLite dialect
76+
Sqlite,
77+
}
78+
79+
impl From<SqlDialect> for RustSqlDialect {
80+
fn from(dialect: SqlDialect) -> Self {
81+
match dialect {
82+
SqlDialect::Default => RustSqlDialect::Default,
83+
SqlDialect::Spark => RustSqlDialect::Spark,
84+
SqlDialect::PostgreSql => RustSqlDialect::PostgreSql,
85+
SqlDialect::MySql => RustSqlDialect::MySql,
86+
SqlDialect::Sqlite => RustSqlDialect::Sqlite,
87+
}
88+
}
89+
}
90+
6291
/// Distance metric for vector similarity search
6392
#[pyclass(name = "DistanceMetric", module = "lance.graph")]
6493
#[derive(Clone, Copy)]
@@ -494,6 +523,8 @@ impl CypherQuery {
494523
/// ----------
495524
/// datasets : dict
496525
/// Dictionary mapping table names to Lance datasets
526+
/// dialect : SqlDialect, optional
527+
/// SQL dialect to use. Defaults to SqlDialect.Default (generic DataFusion SQL).
497528
///
498529
/// Returns
499530
/// -------
@@ -504,7 +535,15 @@ impl CypherQuery {
504535
/// ------
505536
/// RuntimeError
506537
/// If SQL generation fails
507-
fn to_sql(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
538+
#[pyo3(signature = (datasets, dialect=None))]
539+
fn to_sql(
540+
&self,
541+
py: Python,
542+
datasets: &Bound<'_, PyDict>,
543+
dialect: Option<SqlDialect>,
544+
) -> PyResult<String> {
545+
let sql_dialect = dialect.map(|d| d.into());
546+
508547
// Convert datasets to Arrow RecordBatch map
509548
let arrow_datasets = python_datasets_to_batches(datasets)?;
510549

@@ -513,7 +552,7 @@ impl CypherQuery {
513552

514553
// Execute via runtime
515554
let sql = RT
516-
.block_on(Some(py), inner_query.to_sql(arrow_datasets))?
555+
.block_on(Some(py), inner_query.to_sql(arrow_datasets, sql_dialect))?
517556
.map_err(graph_error_to_pyerr)?;
518557

519558
Ok(sql)
@@ -1545,6 +1584,7 @@ pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) ->
15451584
let graph_module = PyModule::new(py, "graph")?;
15461585

15471586
graph_module.add_class::<ExecutionStrategy>()?;
1587+
graph_module.add_class::<SqlDialect>()?;
15481588
graph_module.add_class::<DistanceMetric>()?;
15491589
graph_module.add_class::<GraphConfig>()?;
15501590
graph_module.add_class::<GraphConfigBuilder>()?;

crates/lance-graph/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub mod parameter_substitution;
4747
pub mod parser;
4848
pub mod query;
4949
pub mod semantic;
50+
pub mod spark_dialect;
5051
pub mod sql_catalog;
5152
pub mod sql_query;
5253
pub mod table_readers;
@@ -67,7 +68,7 @@ pub use lance_graph_catalog::{
6768
#[cfg(feature = "unity-catalog")]
6869
pub use lance_graph_catalog::{UnityCatalogConfig, UnityCatalogProvider};
6970
pub use lance_vector_search::VectorSearch;
70-
pub use query::{CypherQuery, ExecutionStrategy};
71+
pub use query::{CypherQuery, ExecutionStrategy, SqlDialect};
7172
pub use sql_query::SqlQuery;
7273
#[cfg(feature = "delta")]
7374
pub use table_readers::DeltaTableReader;

crates/lance-graph/src/query.rs

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,68 @@ use crate::config::GraphConfig;
99
use crate::error::{GraphError, Result};
1010
use crate::logical_plan::LogicalPlanner;
1111
use crate::parser::parse_cypher_query;
12+
use crate::spark_dialect::build_spark_dialect;
1213
use arrow_array::RecordBatch;
1314
use arrow_schema::{Field, Schema, SchemaRef};
15+
use datafusion_sql::unparser::dialect::{
16+
CustomDialect, DefaultDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
17+
};
18+
use datafusion_sql::unparser::Unparser;
1419
use lance_graph_catalog::DirNamespace;
1520
use lance_namespace::models::DescribeTableRequest;
1621
use std::collections::{HashMap, HashSet};
1722
use std::sync::Arc;
1823

24+
/// SQL dialect to use when generating SQL from Cypher queries.
25+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
26+
pub enum SqlDialect {
27+
/// Generic SQL (DataFusion default dialect)
28+
#[default]
29+
Default,
30+
/// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.)
31+
Spark,
32+
/// PostgreSQL dialect
33+
PostgreSql,
34+
/// MySQL dialect
35+
MySql,
36+
/// SQLite dialect
37+
Sqlite,
38+
}
39+
40+
/// Wrapper to hold the concrete dialect type and provide an `Unparser` reference.
41+
pub enum DialectUnparser {
42+
Default(DefaultDialect),
43+
Spark(Box<CustomDialect>),
44+
PostgreSql(PostgreSqlDialect),
45+
MySql(MySqlDialect),
46+
Sqlite(SqliteDialect),
47+
}
48+
49+
impl DialectUnparser {
50+
pub fn as_unparser(&self) -> Unparser<'_> {
51+
match self {
52+
DialectUnparser::Default(d) => Unparser::new(d),
53+
DialectUnparser::Spark(d) => Unparser::new(d.as_ref()),
54+
DialectUnparser::PostgreSql(d) => Unparser::new(d),
55+
DialectUnparser::MySql(d) => Unparser::new(d),
56+
DialectUnparser::Sqlite(d) => Unparser::new(d),
57+
}
58+
}
59+
}
60+
61+
impl SqlDialect {
62+
/// Create a `DialectUnparser` configured for this dialect.
63+
pub fn unparser(&self) -> DialectUnparser {
64+
match self {
65+
SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}),
66+
SqlDialect::Spark => DialectUnparser::Spark(Box::new(build_spark_dialect())),
67+
SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}),
68+
SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}),
69+
SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}),
70+
}
71+
}
72+
}
73+
1974
/// Normalize an Arrow schema to have lowercase field names.
2075
///
2176
/// This ensures that column names in the dataset match the normalized
@@ -280,10 +335,10 @@ impl CypherQuery {
280335
self.explain_internal(Arc::new(catalog), ctx).await
281336
}
282337

283-
/// Convert the Cypher query to a DataFusion SQL string
338+
/// Convert the Cypher query to a SQL string in the specified dialect.
284339
///
285340
/// This method generates a SQL string that corresponds to the DataFusion logical plan
286-
/// derived from the Cypher query. It uses the `datafusion-sql` unparser.
341+
/// derived from the Cypher query, using the specified SQL dialect for unparsing.
287342
///
288343
/// **WARNING**: This method is experimental and the generated SQL dialect may change.
289344
///
@@ -293,16 +348,20 @@ impl CypherQuery {
293348
///
294349
/// # Arguments
295350
/// * `datasets` - HashMap of table name to RecordBatch (nodes and relationships)
351+
/// * `dialect` - The SQL dialect to use for generating the output SQL.
352+
/// Defaults to `SqlDialect::Default` (generic DataFusion SQL).
353+
/// Use `SqlDialect::Spark` for Spark SQL, `SqlDialect::PostgreSql`, etc.
296354
///
297355
/// # Returns
298-
/// A SQL string representing the query
356+
/// A SQL string representing the query in the specified dialect
299357
pub async fn to_sql(
300358
&self,
301359
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
360+
dialect: Option<SqlDialect>,
302361
) -> Result<String> {
303-
use datafusion_sql::unparser::plan_to_sql;
304362
use std::sync::Arc;
305363

364+
let dialect = dialect.unwrap_or_default();
306365
let _config = self.require_config()?;
307366

308367
// Build catalog and context from datasets using the helper
@@ -323,11 +382,15 @@ impl CypherQuery {
323382
location: snafu::Location::new(file!(), line!(), column!()),
324383
})?;
325384

326-
// Unparse to SQL
327-
let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError {
328-
message: format!("Failed to unparse plan to SQL: {}", e),
329-
location: snafu::Location::new(file!(), line!(), column!()),
330-
})?;
385+
// Unparse to SQL using the specified dialect
386+
let dialect_unparser = dialect.unparser();
387+
let unparser = dialect_unparser.as_unparser();
388+
let sql_ast = unparser
389+
.plan_to_sql(&optimized_plan)
390+
.map_err(|e| GraphError::PlanError {
391+
message: format!("Failed to unparse plan to SQL: {}", e),
392+
location: snafu::Location::new(file!(), line!(), column!()),
393+
})?;
331394

332395
Ok(sql_ast.to_string())
333396
}
@@ -1852,7 +1915,7 @@ mod tests {
18521915
.unwrap()
18531916
.with_config(cfg);
18541917

1855-
let sql = query.to_sql(datasets).await.unwrap();
1918+
let sql = query.to_sql(datasets, None).await.unwrap();
18561919
println!("Generated SQL: {}", sql);
18571920

18581921
assert!(sql.contains("SELECT"));
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright The Lance Authors
3+
4+
//! Spark SQL dialect for the DataFusion unparser.
5+
//!
6+
//! This module provides a Spark SQL dialect built using DataFusion's
7+
//! [`CustomDialectBuilder`].
8+
//!
9+
//! Key Spark SQL differences from standard SQL:
10+
//! - Backtick (`` ` ``) identifier quoting
11+
//! - `EXTRACT(field FROM expr)` for date field extraction
12+
//! - `STRING` type for casting (not `VARCHAR`)
13+
//! - `BIGINT`/`INT` for integer types
14+
//! - `TIMESTAMP` for all timestamp types (no timezone info in cast)
15+
//! - `LENGTH()` instead of `CHARACTER_LENGTH()`
16+
//! - Subqueries in FROM require aliases
17+
18+
use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo};
19+
use datafusion_sql::unparser::dialect::{
20+
CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle,
21+
};
22+
23+
/// Build a Spark SQL dialect using DataFusion's `CustomDialectBuilder`.
24+
pub fn build_spark_dialect() -> CustomDialect {
25+
CustomDialectBuilder::new()
26+
.with_identifier_quote_style('`')
27+
.with_supports_nulls_first_in_sort(true)
28+
.with_use_timestamp_for_date64(true)
29+
.with_utf8_cast_dtype(ast::DataType::Custom(
30+
ObjectName::from(vec![Ident::new("STRING")]),
31+
vec![],
32+
))
33+
.with_large_utf8_cast_dtype(ast::DataType::Custom(
34+
ObjectName::from(vec![Ident::new("STRING")]),
35+
vec![],
36+
))
37+
.with_date_field_extract_style(DateFieldExtractStyle::Extract)
38+
.with_character_length_style(CharacterLengthStyle::Length)
39+
.with_int64_cast_dtype(ast::DataType::BigInt(None))
40+
.with_int32_cast_dtype(ast::DataType::Int(None))
41+
.with_timestamp_cast_dtype(
42+
ast::DataType::Timestamp(None, TimezoneInfo::None),
43+
ast::DataType::Timestamp(None, TimezoneInfo::None),
44+
)
45+
.with_date32_cast_dtype(ast::DataType::Date)
46+
.with_supports_column_alias_in_table_alias(true)
47+
.with_requires_derived_table_alias(true)
48+
.with_full_qualified_col(false)
49+
.with_unnest_as_table_factor(false)
50+
.with_float64_ast_dtype(ast::DataType::Double(ast::ExactNumberInfo::None))
51+
.build()
52+
}
53+
54+
#[cfg(test)]
55+
mod tests {
56+
use super::*;
57+
use datafusion_sql::unparser::dialect::Dialect;
58+
59+
#[test]
60+
fn test_spark_dialect_identifier_quoting() {
61+
let dialect = build_spark_dialect();
62+
assert_eq!(dialect.identifier_quote_style("table_name"), Some('`'));
63+
assert_eq!(dialect.identifier_quote_style("column"), Some('`'));
64+
}
65+
66+
#[test]
67+
fn test_spark_dialect_type_mappings() {
68+
let dialect = build_spark_dialect();
69+
assert!(matches!(
70+
dialect.utf8_cast_dtype(),
71+
ast::DataType::Custom(..)
72+
));
73+
assert!(matches!(
74+
dialect.int64_cast_dtype(),
75+
ast::DataType::BigInt(None)
76+
));
77+
assert!(matches!(
78+
dialect.int32_cast_dtype(),
79+
ast::DataType::Int(None)
80+
));
81+
assert!(matches!(dialect.date32_cast_dtype(), ast::DataType::Date));
82+
}
83+
84+
#[test]
85+
fn test_spark_dialect_requires_derived_table_alias() {
86+
let dialect = build_spark_dialect();
87+
assert!(dialect.requires_derived_table_alias());
88+
}
89+
90+
#[test]
91+
fn test_spark_dialect_extract_style() {
92+
let dialect = build_spark_dialect();
93+
assert!(matches!(
94+
dialect.date_field_extract_style(),
95+
DateFieldExtractStyle::Extract
96+
));
97+
}
98+
99+
#[test]
100+
fn test_spark_dialect_character_length_style() {
101+
let dialect = build_spark_dialect();
102+
assert!(matches!(
103+
dialect.character_length_style(),
104+
CharacterLengthStyle::Length
105+
));
106+
}
107+
}

0 commit comments

Comments
 (0)