Skip to content

Commit 1bc5623

Browse files
author
Yu Chen
committed
refactor: address PR review comments for SQL dialect support
- Use CustomDialectBuilder for Spark dialect instead of manual Dialect impl - Move SqlDialect enum from spark_dialect.rs to query.rs as general-purpose type - Expose SqlDialect as a Python enum instead of error-prone string parameter Co-authored-by: Isaac
1 parent 8306436 commit 1bc5623

5 files changed

Lines changed: 107 additions & 156 deletions

File tree

crates/lance-graph-python/src/graph.rs

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ use lance_graph::{
2727
ast::{DistanceMetric as RustDistanceMetric, GraphPattern, ReadingClause},
2828
CypherQuery as RustCypherQuery, ExecutionStrategy as RustExecutionStrategy,
2929
GraphConfig as RustGraphConfig, GraphError as RustGraphError, InMemoryCatalog,
30-
SqlQuery as RustSqlQuery, VectorSearch as RustVectorSearch,
30+
SqlDialect as RustSqlDialect, SqlQuery as RustSqlQuery,
31+
VectorSearch as RustVectorSearch,
3132
};
3233
use pyo3::{
3334
exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError},
@@ -59,6 +60,34 @@ impl From<ExecutionStrategy> for RustExecutionStrategy {
5960
}
6061
}
6162

63+
/// SQL dialect for generating SQL from Cypher queries
64+
#[pyclass(name = "SqlDialect", module = "lance.graph")]
65+
#[derive(Clone, Copy)]
66+
pub enum SqlDialect {
67+
/// Generic SQL (DataFusion default dialect)
68+
Default,
69+
/// Spark SQL dialect
70+
Spark,
71+
/// PostgreSQL dialect
72+
PostgreSql,
73+
/// MySQL dialect
74+
MySql,
75+
/// SQLite dialect
76+
Sqlite,
77+
}
78+
79+
impl From<SqlDialect> for RustSqlDialect {
80+
fn from(dialect: SqlDialect) -> Self {
81+
match dialect {
82+
SqlDialect::Default => RustSqlDialect::Default,
83+
SqlDialect::Spark => RustSqlDialect::Spark,
84+
SqlDialect::PostgreSql => RustSqlDialect::PostgreSql,
85+
SqlDialect::MySql => RustSqlDialect::MySql,
86+
SqlDialect::Sqlite => RustSqlDialect::Sqlite,
87+
}
88+
}
89+
}
90+
6291
/// Distance metric for vector similarity search
6392
#[pyclass(name = "DistanceMetric", module = "lance.graph")]
6493
#[derive(Clone, Copy)]
@@ -494,9 +523,8 @@ impl CypherQuery {
494523
/// ----------
495524
/// datasets : dict
496525
/// Dictionary mapping table names to Lance datasets
497-
/// dialect : str, optional
498-
/// SQL dialect to use. One of "default", "spark", "postgresql", "mysql", "sqlite".
499-
/// Defaults to "default" (generic DataFusion SQL).
526+
/// dialect : SqlDialect, optional
527+
/// SQL dialect to use. Defaults to SqlDialect.Default (generic DataFusion SQL).
500528
///
501529
/// Returns
502530
/// -------
@@ -507,28 +535,14 @@ impl CypherQuery {
507535
/// ------
508536
/// RuntimeError
509537
/// If SQL generation fails
510-
/// ValueError
511-
/// If an invalid dialect is specified
512538
#[pyo3(signature = (datasets, dialect=None))]
513539
fn to_sql(
514540
&self,
515541
py: Python,
516542
datasets: &Bound<'_, PyDict>,
517-
dialect: Option<&str>,
543+
dialect: Option<SqlDialect>,
518544
) -> PyResult<String> {
519-
let sql_dialect = match dialect {
520-
None | Some("default") => None,
521-
Some("spark") => Some(lance_graph::SqlDialect::Spark),
522-
Some("postgresql") | Some("postgres") => Some(lance_graph::SqlDialect::PostgreSql),
523-
Some("mysql") => Some(lance_graph::SqlDialect::MySql),
524-
Some("sqlite") => Some(lance_graph::SqlDialect::Sqlite),
525-
Some(other) => {
526-
return Err(PyValueError::new_err(format!(
527-
"Unknown SQL dialect: '{}'. Valid options: 'default', 'spark', 'postgresql', 'mysql', 'sqlite'",
528-
other
529-
)));
530-
}
531-
};
545+
let sql_dialect = dialect.map(|d| d.into());
532546

533547
// Convert datasets to Arrow RecordBatch map
534548
let arrow_datasets = python_datasets_to_batches(datasets)?;
@@ -1570,6 +1584,7 @@ pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) ->
15701584
let graph_module = PyModule::new(py, "graph")?;
15711585

15721586
graph_module.add_class::<ExecutionStrategy>()?;
1587+
graph_module.add_class::<SqlDialect>()?;
15731588
graph_module.add_class::<DistanceMetric>()?;
15741589
graph_module.add_class::<GraphConfig>()?;
15751590
graph_module.add_class::<GraphConfigBuilder>()?;

crates/lance-graph/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ pub use lance_graph_catalog::{
6868
#[cfg(feature = "unity-catalog")]
6969
pub use lance_graph_catalog::{UnityCatalogConfig, UnityCatalogProvider};
7070
pub use lance_vector_search::VectorSearch;
71-
pub use query::{CypherQuery, ExecutionStrategy};
72-
pub use spark_dialect::{SqlDialect, SparkDialect};
71+
pub use query::{CypherQuery, ExecutionStrategy, SqlDialect};
7372
pub use sql_query::SqlQuery;
7473
#[cfg(feature = "delta")]
7574
pub use table_readers::DeltaTableReader;

crates/lance-graph/src/query.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,22 @@ use lance_namespace::models::DescribeTableRequest;
1616
use std::collections::{HashMap, HashSet};
1717
use std::sync::Arc;
1818

19+
/// SQL dialect to use when generating SQL from Cypher queries.
20+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
21+
pub enum SqlDialect {
22+
/// Generic SQL (DataFusion default dialect)
23+
#[default]
24+
Default,
25+
/// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.)
26+
Spark,
27+
/// PostgreSQL dialect
28+
PostgreSql,
29+
/// MySQL dialect
30+
MySql,
31+
/// SQLite dialect
32+
Sqlite,
33+
}
34+
1935
/// Normalize an Arrow schema to have lowercase field names.
2036
///
2137
/// This ensures that column names in the dataset match the normalized
@@ -302,7 +318,7 @@ impl CypherQuery {
302318
pub async fn to_sql(
303319
&self,
304320
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
305-
dialect: Option<crate::spark_dialect::SqlDialect>,
321+
dialect: Option<SqlDialect>,
306322
) -> Result<String> {
307323
use std::sync::Arc;
308324

crates/lance-graph/src/spark_dialect.rs

Lines changed: 52 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
//! SQL dialect support for the DataFusion unparser.
55
//!
6-
//! This module provides a [`SqlDialect`] enum for selecting which SQL dialect
7-
//! to use when unparsing DataFusion logical plans to SQL strings, and includes
8-
//! a [`SparkDialect`] implementation for Spark SQL.
6+
//! This module provides a Spark SQL dialect built using DataFusion's
7+
//! [`CustomDialectBuilder`], and a helper to build an [`Unparser`] for any
8+
//! supported [`SqlDialect`].
99
//!
1010
//! Key Spark SQL differences from standard SQL:
1111
//! - Backtick (`` ` ``) identifier quoting
@@ -16,51 +16,50 @@
1616
//! - `LENGTH()` instead of `CHARACTER_LENGTH()`
1717
//! - Subqueries in FROM require aliases
1818
19-
use std::sync::Arc;
20-
21-
use arrow::datatypes::TimeUnit;
22-
use datafusion_common::Result;
23-
use datafusion_expr::Expr;
2419
use datafusion_sql::unparser::dialect::{
25-
CharacterLengthStyle, DateFieldExtractStyle, DefaultDialect, Dialect, IntervalStyle,
26-
MySqlDialect, PostgreSqlDialect, SqliteDialect,
20+
CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle,
21+
DefaultDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
2722
};
2823
use datafusion_sql::unparser::Unparser;
2924
use datafusion_sql::sqlparser::ast::{self, Ident, ObjectName, TimezoneInfo};
3025

31-
/// SQL dialect to use when generating SQL from Cypher queries.
32-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
33-
pub enum SqlDialect {
34-
/// Generic SQL (DataFusion default dialect)
35-
#[default]
36-
Default,
37-
/// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.)
38-
Spark,
39-
/// PostgreSQL dialect
40-
PostgreSql,
41-
/// MySQL dialect
42-
MySql,
43-
/// SQLite dialect
44-
Sqlite,
45-
}
26+
use crate::query::SqlDialect;
4627

47-
impl SqlDialect {
48-
/// Create a DataFusion `Unparser` configured for this dialect.
49-
pub fn unparser(&self) -> DialectUnparser {
50-
match self {
51-
SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}),
52-
SqlDialect::Spark => DialectUnparser::Spark(SparkDialect),
53-
SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}),
54-
SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}),
55-
SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}),
56-
}
57-
}
28+
/// Build a Spark SQL dialect using DataFusion's `CustomDialectBuilder`.
29+
pub fn build_spark_dialect() -> CustomDialect {
30+
CustomDialectBuilder::new()
31+
.with_identifier_quote_style('`')
32+
.with_supports_nulls_first_in_sort(true)
33+
.with_use_timestamp_for_date64(true)
34+
.with_utf8_cast_dtype(ast::DataType::Custom(
35+
ObjectName::from(vec![Ident::new("STRING")]),
36+
vec![],
37+
))
38+
.with_large_utf8_cast_dtype(ast::DataType::Custom(
39+
ObjectName::from(vec![Ident::new("STRING")]),
40+
vec![],
41+
))
42+
.with_date_field_extract_style(DateFieldExtractStyle::Extract)
43+
.with_character_length_style(CharacterLengthStyle::Length)
44+
.with_int64_cast_dtype(ast::DataType::BigInt(None))
45+
.with_int32_cast_dtype(ast::DataType::Int(None))
46+
.with_timestamp_cast_dtype(
47+
ast::DataType::Timestamp(None, TimezoneInfo::None),
48+
ast::DataType::Timestamp(None, TimezoneInfo::None),
49+
)
50+
.with_date32_cast_dtype(ast::DataType::Date)
51+
.with_supports_column_alias_in_table_alias(true)
52+
.with_requires_derived_table_alias(true)
53+
.with_full_qualified_col(false)
54+
.with_unnest_as_table_factor(false)
55+
.with_float64_ast_dtype(ast::DataType::Double(ast::ExactNumberInfo::None))
56+
.build()
5857
}
5958

6059
/// Wrapper to hold the concrete dialect type and provide an `Unparser` reference.
6160
pub enum DialectUnparser {
6261
Default(DefaultDialect),
63-
Spark(SparkDialect),
62+
Spark(CustomDialect),
6463
PostgreSql(PostgreSqlDialect),
6564
MySql(MySqlDialect),
6665
Sqlite(SqliteDialect),
@@ -78,114 +77,34 @@ impl DialectUnparser {
7877
}
7978
}
8079

81-
/// A Spark SQL dialect for unparsing DataFusion logical plans to Spark-compatible SQL.
82-
pub struct SparkDialect;
83-
84-
impl Dialect for SparkDialect {
85-
fn identifier_quote_style(&self, _identifier: &str) -> Option<char> {
86-
Some('`')
87-
}
88-
89-
fn supports_nulls_first_in_sort(&self) -> bool {
90-
true
91-
}
92-
93-
fn use_timestamp_for_date64(&self) -> bool {
94-
true
95-
}
96-
97-
fn interval_style(&self) -> IntervalStyle {
98-
IntervalStyle::SQLStandard
99-
}
100-
101-
fn float64_ast_dtype(&self) -> ast::DataType {
102-
ast::DataType::Double(ast::ExactNumberInfo::None)
103-
}
104-
105-
fn utf8_cast_dtype(&self) -> ast::DataType {
106-
ast::DataType::Custom(
107-
ObjectName::from(vec![Ident::new("STRING")]),
108-
vec![],
109-
)
110-
}
111-
112-
fn large_utf8_cast_dtype(&self) -> ast::DataType {
113-
ast::DataType::Custom(
114-
ObjectName::from(vec![Ident::new("STRING")]),
115-
vec![],
116-
)
117-
}
118-
119-
fn date_field_extract_style(&self) -> DateFieldExtractStyle {
120-
DateFieldExtractStyle::Extract
121-
}
122-
123-
fn character_length_style(&self) -> CharacterLengthStyle {
124-
CharacterLengthStyle::Length
125-
}
126-
127-
fn int64_cast_dtype(&self) -> ast::DataType {
128-
ast::DataType::BigInt(None)
129-
}
130-
131-
fn int32_cast_dtype(&self) -> ast::DataType {
132-
ast::DataType::Int(None)
133-
}
134-
135-
fn timestamp_cast_dtype(
136-
&self,
137-
_time_unit: &TimeUnit,
138-
_tz: &Option<Arc<str>>,
139-
) -> ast::DataType {
140-
ast::DataType::Timestamp(None, TimezoneInfo::None)
141-
}
142-
143-
fn date32_cast_dtype(&self) -> ast::DataType {
144-
ast::DataType::Date
145-
}
146-
147-
fn supports_column_alias_in_table_alias(&self) -> bool {
148-
true
149-
}
150-
151-
fn requires_derived_table_alias(&self) -> bool {
152-
true
153-
}
154-
155-
fn full_qualified_col(&self) -> bool {
156-
false
157-
}
158-
159-
fn unnest_as_table_factor(&self) -> bool {
160-
false
161-
}
162-
163-
fn scalar_function_to_sql_overrides(
164-
&self,
165-
_unparser: &Unparser,
166-
_func_name: &str,
167-
_args: &[Expr],
168-
) -> Result<Option<ast::Expr>> {
169-
// character_length -> length is handled by CharacterLengthStyle::Length
170-
// Additional Spark-specific function mappings can be added here as needed
171-
Ok(None)
80+
impl SqlDialect {
81+
/// Create a `DialectUnparser` configured for this dialect.
82+
pub fn unparser(&self) -> DialectUnparser {
83+
match self {
84+
SqlDialect::Default => DialectUnparser::Default(DefaultDialect {}),
85+
SqlDialect::Spark => DialectUnparser::Spark(build_spark_dialect()),
86+
SqlDialect::PostgreSql => DialectUnparser::PostgreSql(PostgreSqlDialect {}),
87+
SqlDialect::MySql => DialectUnparser::MySql(MySqlDialect {}),
88+
SqlDialect::Sqlite => DialectUnparser::Sqlite(SqliteDialect {}),
89+
}
17290
}
17391
}
17492

17593
#[cfg(test)]
17694
mod tests {
17795
use super::*;
96+
use datafusion_sql::unparser::dialect::Dialect;
17897

17998
#[test]
18099
fn test_spark_dialect_identifier_quoting() {
181-
let dialect = SparkDialect;
100+
let dialect = build_spark_dialect();
182101
assert_eq!(dialect.identifier_quote_style("table_name"), Some('`'));
183102
assert_eq!(dialect.identifier_quote_style("column"), Some('`'));
184103
}
185104

186105
#[test]
187106
fn test_spark_dialect_type_mappings() {
188-
let dialect = SparkDialect;
107+
let dialect = build_spark_dialect();
189108
assert!(matches!(dialect.utf8_cast_dtype(), ast::DataType::Custom(..)));
190109
assert!(matches!(dialect.int64_cast_dtype(), ast::DataType::BigInt(None)));
191110
assert!(matches!(dialect.int32_cast_dtype(), ast::DataType::Int(None)));
@@ -194,13 +113,13 @@ mod tests {
194113

195114
#[test]
196115
fn test_spark_dialect_requires_derived_table_alias() {
197-
let dialect = SparkDialect;
116+
let dialect = build_spark_dialect();
198117
assert!(dialect.requires_derived_table_alias());
199118
}
200119

201120
#[test]
202121
fn test_spark_dialect_extract_style() {
203-
let dialect = SparkDialect;
122+
let dialect = build_spark_dialect();
204123
assert!(matches!(
205124
dialect.date_field_extract_style(),
206125
DateFieldExtractStyle::Extract
@@ -209,7 +128,7 @@ mod tests {
209128

210129
#[test]
211130
fn test_spark_dialect_character_length_style() {
212-
let dialect = SparkDialect;
131+
let dialect = build_spark_dialect();
213132
assert!(matches!(
214133
dialect.character_length_style(),
215134
CharacterLengthStyle::Length

0 commit comments

Comments
 (0)