@@ -9,13 +9,68 @@ use crate::config::GraphConfig;
99use crate :: error:: { GraphError , Result } ;
1010use crate :: logical_plan:: LogicalPlanner ;
1111use crate :: parser:: parse_cypher_query;
12+ use crate :: spark_dialect:: build_spark_dialect;
1213use arrow_array:: RecordBatch ;
1314use arrow_schema:: { Field , Schema , SchemaRef } ;
15+ use datafusion_sql:: unparser:: dialect:: {
16+ CustomDialect , DefaultDialect , MySqlDialect , PostgreSqlDialect , SqliteDialect ,
17+ } ;
18+ use datafusion_sql:: unparser:: Unparser ;
1419use lance_graph_catalog:: DirNamespace ;
1520use lance_namespace:: models:: DescribeTableRequest ;
1621use std:: collections:: { HashMap , HashSet } ;
1722use std:: sync:: Arc ;
1823
24+ /// SQL dialect to use when generating SQL from Cypher queries.
25+ #[ derive( Debug , Clone , Copy , PartialEq , Eq , Default ) ]
26+ pub enum SqlDialect {
27+ /// Generic SQL (DataFusion default dialect)
28+ #[ default]
29+ Default ,
30+ /// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.)
31+ Spark ,
32+ /// PostgreSQL dialect
33+ PostgreSql ,
34+ /// MySQL dialect
35+ MySql ,
36+ /// SQLite dialect
37+ Sqlite ,
38+ }
39+
40+ /// Wrapper to hold the concrete dialect type and provide an `Unparser` reference.
41+ pub enum DialectUnparser {
42+ Default ( DefaultDialect ) ,
43+ Spark ( Box < CustomDialect > ) ,
44+ PostgreSql ( PostgreSqlDialect ) ,
45+ MySql ( MySqlDialect ) ,
46+ Sqlite ( SqliteDialect ) ,
47+ }
48+
49+ impl DialectUnparser {
50+ pub fn as_unparser ( & self ) -> Unparser < ' _ > {
51+ match self {
52+ DialectUnparser :: Default ( d) => Unparser :: new ( d) ,
53+ DialectUnparser :: Spark ( d) => Unparser :: new ( d. as_ref ( ) ) ,
54+ DialectUnparser :: PostgreSql ( d) => Unparser :: new ( d) ,
55+ DialectUnparser :: MySql ( d) => Unparser :: new ( d) ,
56+ DialectUnparser :: Sqlite ( d) => Unparser :: new ( d) ,
57+ }
58+ }
59+ }
60+
61+ impl SqlDialect {
62+ /// Create a `DialectUnparser` configured for this dialect.
63+ pub fn unparser ( & self ) -> DialectUnparser {
64+ match self {
65+ SqlDialect :: Default => DialectUnparser :: Default ( DefaultDialect { } ) ,
66+ SqlDialect :: Spark => DialectUnparser :: Spark ( Box :: new ( build_spark_dialect ( ) ) ) ,
67+ SqlDialect :: PostgreSql => DialectUnparser :: PostgreSql ( PostgreSqlDialect { } ) ,
68+ SqlDialect :: MySql => DialectUnparser :: MySql ( MySqlDialect { } ) ,
69+ SqlDialect :: Sqlite => DialectUnparser :: Sqlite ( SqliteDialect { } ) ,
70+ }
71+ }
72+ }
73+
1974/// Normalize an Arrow schema to have lowercase field names.
2075///
2176/// This ensures that column names in the dataset match the normalized
@@ -280,10 +335,10 @@ impl CypherQuery {
280335 self . explain_internal ( Arc :: new ( catalog) , ctx) . await
281336 }
282337
283- /// Convert the Cypher query to a DataFusion SQL string
338+ /// Convert the Cypher query to a SQL string in the specified dialect.
284339 ///
285340 /// This method generates a SQL string that corresponds to the DataFusion logical plan
286- /// derived from the Cypher query. It uses the `datafusion-sql` unparser .
341+ /// derived from the Cypher query, using the specified SQL dialect for unparsing .
287342 ///
288343 /// **WARNING**: This method is experimental and the generated SQL dialect may change.
289344 ///
@@ -293,16 +348,20 @@ impl CypherQuery {
293348 ///
294349 /// # Arguments
295350 /// * `datasets` - HashMap of table name to RecordBatch (nodes and relationships)
351+ /// * `dialect` - The SQL dialect to use for generating the output SQL.
352+ /// Defaults to `SqlDialect::Default` (generic DataFusion SQL).
353+ /// Use `SqlDialect::Spark` for Spark SQL, `SqlDialect::PostgreSql`, etc.
296354 ///
297355 /// # Returns
298- /// A SQL string representing the query
356+ /// A SQL string representing the query in the specified dialect
299357 pub async fn to_sql (
300358 & self ,
301359 datasets : HashMap < String , arrow:: record_batch:: RecordBatch > ,
360+ dialect : Option < SqlDialect > ,
302361 ) -> Result < String > {
303- use datafusion_sql:: unparser:: plan_to_sql;
304362 use std:: sync:: Arc ;
305363
364+ let dialect = dialect. unwrap_or_default ( ) ;
306365 let _config = self . require_config ( ) ?;
307366
308367 // Build catalog and context from datasets using the helper
@@ -323,11 +382,15 @@ impl CypherQuery {
323382 location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
324383 } ) ?;
325384
326- // Unparse to SQL
327- let sql_ast = plan_to_sql ( & optimized_plan) . map_err ( |e| GraphError :: PlanError {
328- message : format ! ( "Failed to unparse plan to SQL: {}" , e) ,
329- location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
330- } ) ?;
385+ // Unparse to SQL using the specified dialect
386+ let dialect_unparser = dialect. unparser ( ) ;
387+ let unparser = dialect_unparser. as_unparser ( ) ;
388+ let sql_ast = unparser
389+ . plan_to_sql ( & optimized_plan)
390+ . map_err ( |e| GraphError :: PlanError {
391+ message : format ! ( "Failed to unparse plan to SQL: {}" , e) ,
392+ location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
393+ } ) ?;
331394
332395 Ok ( sql_ast. to_string ( ) )
333396 }
@@ -1852,7 +1915,7 @@ mod tests {
18521915 . unwrap ( )
18531916 . with_config ( cfg) ;
18541917
1855- let sql = query. to_sql ( datasets) . await . unwrap ( ) ;
1918+ let sql = query. to_sql ( datasets, None ) . await . unwrap ( ) ;
18561919 println ! ( "Generated SQL: {}" , sql) ;
18571920
18581921 assert ! ( sql. contains( "SELECT" ) ) ;
0 commit comments