33
44//! SQL dialect support for the DataFusion unparser.
55//!
6- //! This module provides a [`SqlDialect`] enum for selecting which SQL dialect
7- //! to use when unparsing DataFusion logical plans to SQL strings, and includes
8- //! a [`SparkDialect`] implementation for Spark SQL .
6+ //! This module provides a Spark SQL dialect built using DataFusion's
7+ //! [`CustomDialectBuilder`], and a helper to build an [`Unparser`] for any
8+ //! supported [`SqlDialect`] .
99//!
1010//! Key Spark SQL differences from standard SQL:
1111//! - Backtick (`` ` ``) identifier quoting
1616//! - `LENGTH()` instead of `CHARACTER_LENGTH()`
1717//! - Subqueries in FROM require aliases
1818
19- use std:: sync:: Arc ;
20-
21- use arrow:: datatypes:: TimeUnit ;
22- use datafusion_common:: Result ;
23- use datafusion_expr:: Expr ;
2419use datafusion_sql:: unparser:: dialect:: {
25- CharacterLengthStyle , DateFieldExtractStyle , DefaultDialect , Dialect , IntervalStyle ,
26- MySqlDialect , PostgreSqlDialect , SqliteDialect ,
20+ CharacterLengthStyle , CustomDialect , CustomDialectBuilder , DateFieldExtractStyle ,
21+ DefaultDialect , MySqlDialect , PostgreSqlDialect , SqliteDialect ,
2722} ;
2823use datafusion_sql:: unparser:: Unparser ;
2924use datafusion_sql:: sqlparser:: ast:: { self , Ident , ObjectName , TimezoneInfo } ;
3025
31- /// SQL dialect to use when generating SQL from Cypher queries.
32- #[ derive( Debug , Clone , Copy , PartialEq , Eq , Default ) ]
33- pub enum SqlDialect {
34- /// Generic SQL (DataFusion default dialect)
35- #[ default]
36- Default ,
37- /// Spark SQL dialect (backtick quoting, STRING type, EXTRACT, etc.)
38- Spark ,
39- /// PostgreSQL dialect
40- PostgreSql ,
41- /// MySQL dialect
42- MySql ,
43- /// SQLite dialect
44- Sqlite ,
45- }
26+ use crate :: query:: SqlDialect ;
4627
47- impl SqlDialect {
48- /// Create a DataFusion `Unparser` configured for this dialect.
49- pub fn unparser ( & self ) -> DialectUnparser {
50- match self {
51- SqlDialect :: Default => DialectUnparser :: Default ( DefaultDialect { } ) ,
52- SqlDialect :: Spark => DialectUnparser :: Spark ( SparkDialect ) ,
53- SqlDialect :: PostgreSql => DialectUnparser :: PostgreSql ( PostgreSqlDialect { } ) ,
54- SqlDialect :: MySql => DialectUnparser :: MySql ( MySqlDialect { } ) ,
55- SqlDialect :: Sqlite => DialectUnparser :: Sqlite ( SqliteDialect { } ) ,
56- }
57- }
28+ /// Build a Spark SQL dialect using DataFusion's `CustomDialectBuilder`.
29+ pub fn build_spark_dialect ( ) -> CustomDialect {
30+ CustomDialectBuilder :: new ( )
31+ . with_identifier_quote_style ( '`' )
32+ . with_supports_nulls_first_in_sort ( true )
33+ . with_use_timestamp_for_date64 ( true )
34+ . with_utf8_cast_dtype ( ast:: DataType :: Custom (
35+ ObjectName :: from ( vec ! [ Ident :: new( "STRING" ) ] ) ,
36+ vec ! [ ] ,
37+ ) )
38+ . with_large_utf8_cast_dtype ( ast:: DataType :: Custom (
39+ ObjectName :: from ( vec ! [ Ident :: new( "STRING" ) ] ) ,
40+ vec ! [ ] ,
41+ ) )
42+ . with_date_field_extract_style ( DateFieldExtractStyle :: Extract )
43+ . with_character_length_style ( CharacterLengthStyle :: Length )
44+ . with_int64_cast_dtype ( ast:: DataType :: BigInt ( None ) )
45+ . with_int32_cast_dtype ( ast:: DataType :: Int ( None ) )
46+ . with_timestamp_cast_dtype (
47+ ast:: DataType :: Timestamp ( None , TimezoneInfo :: None ) ,
48+ ast:: DataType :: Timestamp ( None , TimezoneInfo :: None ) ,
49+ )
50+ . with_date32_cast_dtype ( ast:: DataType :: Date )
51+ . with_supports_column_alias_in_table_alias ( true )
52+ . with_requires_derived_table_alias ( true )
53+ . with_full_qualified_col ( false )
54+ . with_unnest_as_table_factor ( false )
55+ . with_float64_ast_dtype ( ast:: DataType :: Double ( ast:: ExactNumberInfo :: None ) )
56+ . build ( )
5857}
5958
6059/// Wrapper to hold the concrete dialect type and provide an `Unparser` reference.
6160pub enum DialectUnparser {
6261 Default ( DefaultDialect ) ,
63- Spark ( SparkDialect ) ,
62+ Spark ( CustomDialect ) ,
6463 PostgreSql ( PostgreSqlDialect ) ,
6564 MySql ( MySqlDialect ) ,
6665 Sqlite ( SqliteDialect ) ,
@@ -78,114 +77,34 @@ impl DialectUnparser {
7877 }
7978}
8079
81- /// A Spark SQL dialect for unparsing DataFusion logical plans to Spark-compatible SQL.
82- pub struct SparkDialect ;
83-
84- impl Dialect for SparkDialect {
85- fn identifier_quote_style ( & self , _identifier : & str ) -> Option < char > {
86- Some ( '`' )
87- }
88-
89- fn supports_nulls_first_in_sort ( & self ) -> bool {
90- true
91- }
92-
93- fn use_timestamp_for_date64 ( & self ) -> bool {
94- true
95- }
96-
97- fn interval_style ( & self ) -> IntervalStyle {
98- IntervalStyle :: SQLStandard
99- }
100-
101- fn float64_ast_dtype ( & self ) -> ast:: DataType {
102- ast:: DataType :: Double ( ast:: ExactNumberInfo :: None )
103- }
104-
105- fn utf8_cast_dtype ( & self ) -> ast:: DataType {
106- ast:: DataType :: Custom (
107- ObjectName :: from ( vec ! [ Ident :: new( "STRING" ) ] ) ,
108- vec ! [ ] ,
109- )
110- }
111-
112- fn large_utf8_cast_dtype ( & self ) -> ast:: DataType {
113- ast:: DataType :: Custom (
114- ObjectName :: from ( vec ! [ Ident :: new( "STRING" ) ] ) ,
115- vec ! [ ] ,
116- )
117- }
118-
119- fn date_field_extract_style ( & self ) -> DateFieldExtractStyle {
120- DateFieldExtractStyle :: Extract
121- }
122-
123- fn character_length_style ( & self ) -> CharacterLengthStyle {
124- CharacterLengthStyle :: Length
125- }
126-
127- fn int64_cast_dtype ( & self ) -> ast:: DataType {
128- ast:: DataType :: BigInt ( None )
129- }
130-
131- fn int32_cast_dtype ( & self ) -> ast:: DataType {
132- ast:: DataType :: Int ( None )
133- }
134-
135- fn timestamp_cast_dtype (
136- & self ,
137- _time_unit : & TimeUnit ,
138- _tz : & Option < Arc < str > > ,
139- ) -> ast:: DataType {
140- ast:: DataType :: Timestamp ( None , TimezoneInfo :: None )
141- }
142-
143- fn date32_cast_dtype ( & self ) -> ast:: DataType {
144- ast:: DataType :: Date
145- }
146-
147- fn supports_column_alias_in_table_alias ( & self ) -> bool {
148- true
149- }
150-
151- fn requires_derived_table_alias ( & self ) -> bool {
152- true
153- }
154-
155- fn full_qualified_col ( & self ) -> bool {
156- false
157- }
158-
159- fn unnest_as_table_factor ( & self ) -> bool {
160- false
161- }
162-
163- fn scalar_function_to_sql_overrides (
164- & self ,
165- _unparser : & Unparser ,
166- _func_name : & str ,
167- _args : & [ Expr ] ,
168- ) -> Result < Option < ast:: Expr > > {
169- // character_length -> length is handled by CharacterLengthStyle::Length
170- // Additional Spark-specific function mappings can be added here as needed
171- Ok ( None )
80+ impl SqlDialect {
81+ /// Create a `DialectUnparser` configured for this dialect.
82+ pub fn unparser ( & self ) -> DialectUnparser {
83+ match self {
84+ SqlDialect :: Default => DialectUnparser :: Default ( DefaultDialect { } ) ,
85+ SqlDialect :: Spark => DialectUnparser :: Spark ( build_spark_dialect ( ) ) ,
86+ SqlDialect :: PostgreSql => DialectUnparser :: PostgreSql ( PostgreSqlDialect { } ) ,
87+ SqlDialect :: MySql => DialectUnparser :: MySql ( MySqlDialect { } ) ,
88+ SqlDialect :: Sqlite => DialectUnparser :: Sqlite ( SqliteDialect { } ) ,
89+ }
17290 }
17391}
17492
17593#[ cfg( test) ]
17694mod tests {
17795 use super :: * ;
96+ use datafusion_sql:: unparser:: dialect:: Dialect ;
17897
17998 #[ test]
18099 fn test_spark_dialect_identifier_quoting ( ) {
181- let dialect = SparkDialect ;
100+ let dialect = build_spark_dialect ( ) ;
182101 assert_eq ! ( dialect. identifier_quote_style( "table_name" ) , Some ( '`' ) ) ;
183102 assert_eq ! ( dialect. identifier_quote_style( "column" ) , Some ( '`' ) ) ;
184103 }
185104
186105 #[ test]
187106 fn test_spark_dialect_type_mappings ( ) {
188- let dialect = SparkDialect ;
107+ let dialect = build_spark_dialect ( ) ;
189108 assert ! ( matches!( dialect. utf8_cast_dtype( ) , ast:: DataType :: Custom ( ..) ) ) ;
190109 assert ! ( matches!( dialect. int64_cast_dtype( ) , ast:: DataType :: BigInt ( None ) ) ) ;
191110 assert ! ( matches!( dialect. int32_cast_dtype( ) , ast:: DataType :: Int ( None ) ) ) ;
@@ -194,13 +113,13 @@ mod tests {
194113
195114 #[ test]
196115 fn test_spark_dialect_requires_derived_table_alias ( ) {
197- let dialect = SparkDialect ;
116+ let dialect = build_spark_dialect ( ) ;
198117 assert ! ( dialect. requires_derived_table_alias( ) ) ;
199118 }
200119
201120 #[ test]
202121 fn test_spark_dialect_extract_style ( ) {
203- let dialect = SparkDialect ;
122+ let dialect = build_spark_dialect ( ) ;
204123 assert ! ( matches!(
205124 dialect. date_field_extract_style( ) ,
206125 DateFieldExtractStyle :: Extract
@@ -209,7 +128,7 @@ mod tests {
209128
210129 #[ test]
211130 fn test_spark_dialect_character_length_style ( ) {
212- let dialect = SparkDialect ;
131+ let dialect = build_spark_dialect ( ) ;
213132 assert ! ( matches!(
214133 dialect. character_length_style( ) ,
215134 CharacterLengthStyle :: Length
0 commit comments