@@ -19,16 +19,24 @@ use super::super::conversion::*;
1919use super :: error:: { DFSqlLogicTestError , Result } ;
2020use crate :: engines:: output:: DFColumnType ;
2121use arrow:: array:: { Array , AsArray } ;
22- use arrow:: datatypes:: { Fields , Schema } ;
23- use arrow:: util:: display:: ArrayFormatter ;
24- use arrow:: { array, array:: ArrayRef , datatypes:: DataType , record_batch:: RecordBatch } ;
22+ use arrow:: datatypes:: { Field , Fields , Schema } ;
23+ use arrow:: error:: ArrowError ;
24+ use arrow:: util:: display:: {
25+ ArrayFormatter , ArrayFormatterFactory , DisplayIndex , FormatOptions , FormatResult ,
26+ } ;
27+ use arrow:: { array, datatypes:: DataType , record_batch:: RecordBatch } ;
28+ use datafusion:: catalog:: Session ;
2529use datafusion:: common:: internal_datafusion_err;
2630use datafusion:: config:: ConfigField ;
31+ use datafusion:: logical_expr:: extension_types:: DFArrayFormatterFactory ;
32+ use datafusion:: prelude:: SessionContext ;
33+ use std:: fmt:: Write ;
2734use std:: path:: PathBuf ;
28- use std:: sync:: LazyLock ;
35+ use std:: sync:: { Arc , LazyLock } ;
2936
3037/// Converts `batches` to a result as expected by sqllogictest.
3138pub fn convert_batches (
39+ ctx : & SessionContext ,
3240 schema : & Schema ,
3341 batches : Vec < RecordBatch > ,
3442 is_spark_path : bool ,
@@ -44,21 +52,51 @@ pub fn convert_batches(
4452 ) ) ) ;
4553 }
4654
55+ let state = ctx. state ( ) ;
56+ let options = state. config ( ) . options ( ) . format . clone ( ) ;
57+ let arrow_options: FormatOptions = ( & options) . try_into ( ) ?;
58+
59+ let registry = state. extension_type_registry ( ) ;
60+ let plain_formatter_factory = DFArrayFormatterFactory :: new ( Arc :: clone ( registry) ) ;
61+ let formatter_factory =
62+ NormalizingArrayFormatterFactory :: new ( plain_formatter_factory, is_spark_path) ;
63+
64+ let arrow_options = arrow_options
65+ . with_formatter_factory ( Some ( & formatter_factory) )
66+ . with_null ( "NULL" ) ;
67+
68+ let formatters = batch
69+ . columns ( )
70+ . iter ( )
71+ . zip ( schema. fields ( ) )
72+ . map ( |( col, field) | {
73+ let formatter = formatter_factory. create_array_formatter (
74+ col,
75+ & arrow_options,
76+ Some ( field) ,
77+ ) ?;
78+
79+ match formatter {
80+ None => Ok ( ArrayFormatter :: try_new ( col. as_ref ( ) , & arrow_options) ?) ,
81+ Some ( formatter) => Ok ( formatter) ,
82+ }
83+ } )
84+ . collect :: < std:: result:: Result < Vec < _ > , ArrowError > > ( ) ?;
85+
4786 // Convert a single batch to a `Vec<Vec<String>>` for comparison, flatten expanded rows, and normalize each.
4887 let new_rows = ( 0 ..batch. num_rows ( ) )
4988 . map ( |row| {
50- batch
51- . columns ( )
89+ formatters
5290 . iter ( )
53- . map ( |col| cell_to_string ( col , row, is_spark_path ) )
54- . collect :: < Result < Vec < String > > > ( )
91+ . map ( |f| f . value ( row) . to_string ( ) )
92+ . collect :: < Vec < String > > ( )
5593 } )
56- . collect :: < Result < Vec < Vec < String > > > > ( ) ?
57- . into_iter ( )
5894 . flat_map ( expand_row)
5995 . map ( normalize_paths) ;
96+
6097 rows. extend ( new_rows) ;
6198 }
99+
62100 Ok ( rows)
63101}
64102
@@ -185,7 +223,11 @@ macro_rules! get_row_value {
185223/// [NULL Values and empty strings]: https://duckdb.org/dev/sqllogictest/result_verification#null-values-and-empty-strings
186224///
187225/// Floating numbers are rounded to have a consistent representation with the Postgres runner.
188- pub fn cell_to_string ( col : & ArrayRef , row : usize , is_spark_path : bool ) -> Result < String > {
226+ pub fn cell_to_string (
227+ col : & dyn Array ,
228+ row : usize ,
229+ is_spark_path : bool ,
230+ ) -> Result < String > {
189231 if col. is_null ( row) {
190232 // represent any null value with the string "NULL"
191233 Ok ( NULL_STR . to_string ( ) )
@@ -233,18 +275,18 @@ pub fn cell_to_string(col: &ArrayRef, row: usize, is_spark_path: bool) -> Result
233275 DataType :: Dictionary ( _, _) => {
234276 let dict = col. as_any_dictionary ( ) ;
235277 let key = dict. normalized_keys ( ) [ row] ;
236- Ok ( cell_to_string ( dict. values ( ) , key, is_spark_path) ?)
278+ Ok ( cell_to_string ( dict. values ( ) . as_ref ( ) , key, is_spark_path) ?)
237279 }
238280 _ => {
239281 let mut datafusion_format_options =
240282 datafusion:: config:: FormatOptions :: default ( ) ;
241283
242284 datafusion_format_options. set ( "null" , "NULL" ) . unwrap ( ) ;
243285
244- let arrow_format_options: arrow :: util :: display :: FormatOptions =
286+ let arrow_format_options: FormatOptions =
245287 ( & datafusion_format_options) . try_into ( ) . unwrap ( ) ;
246288
247- let f = ArrayFormatter :: try_new ( col. as_ref ( ) , & arrow_format_options) ?;
289+ let f = ArrayFormatter :: try_new ( col, & arrow_format_options) ?;
248290
249291 Ok ( f. value ( row) . to_string ( ) )
250292 }
@@ -298,3 +340,75 @@ pub fn convert_schema_to_types(columns: &Fields) -> Vec<DFColumnType> {
298340 } )
299341 . collect ( )
300342}
343+
344+ /// Wraps a [`DFArrayFormatterFactory`] and intercepts formatting columns that must be normalized.
345+ #[ derive( Debug ) ]
346+ pub struct NormalizingArrayFormatterFactory {
347+ /// The inner formatter factory from DataFusion.
348+ inner : DFArrayFormatterFactory ,
349+ /// Whether the test is a Spark test.
350+ is_spark_path : bool ,
351+ }
352+
353+ impl NormalizingArrayFormatterFactory {
354+ /// Creates a new [`NormalizingArrayFormatterFactory`].
355+ pub fn new ( inner : DFArrayFormatterFactory , is_spark_path : bool ) -> Self {
356+ Self {
357+ inner,
358+ is_spark_path,
359+ }
360+ }
361+ }
362+
363+ impl ArrayFormatterFactory for NormalizingArrayFormatterFactory {
364+ fn create_array_formatter < ' formatter > (
365+ & self ,
366+ array : & ' formatter dyn Array ,
367+ options : & FormatOptions < ' formatter > ,
368+ field : Option < & ' formatter Field > ,
369+ ) -> std:: result:: Result < Option < ArrayFormatter < ' formatter > > , ArrowError > {
370+ // Extension types are always formatted via DataFusion.
371+ if let Some ( field) = field {
372+ if field. extension_type_name ( ) . is_some ( ) {
373+ return self
374+ . inner
375+ . create_array_formatter ( array, options, Some ( field) ) ;
376+ }
377+ }
378+
379+ // Intercept normalizing formatting of columns that must be normalized.
380+ match array. data_type ( ) {
381+ DataType :: Boolean
382+ | DataType :: Float16
383+ | DataType :: Float32
384+ | DataType :: Float64
385+ | DataType :: Decimal128 ( _, _)
386+ | DataType :: Decimal256 ( _, _)
387+ | DataType :: Utf8
388+ | DataType :: LargeUtf8
389+ | DataType :: Utf8View => {
390+ let display = SLTDisplayIndex {
391+ array,
392+ is_spark_path : self . is_spark_path ,
393+ } ;
394+ Ok ( Some ( ArrayFormatter :: new ( Box :: new ( display) , options. safe ( ) ) ) )
395+ }
396+ _ => self . inner . create_array_formatter ( array, options, field) ,
397+ }
398+ }
399+ }
400+
401+ /// Implements [`DisplayIndex`] by normalizing the values of the array using [`cell_to_string`].
402+ struct SLTDisplayIndex < ' a > {
403+ array : & ' a dyn Array ,
404+ is_spark_path : bool ,
405+ }
406+
407+ impl DisplayIndex for SLTDisplayIndex < ' _ > {
408+ fn write ( & self , idx : usize , f : & mut dyn Write ) -> FormatResult {
409+ let s = cell_to_string ( self . array , idx, self . is_spark_path )
410+ . map_err ( |_| std:: fmt:: Error ) ?;
411+ write ! ( f, "{s}" ) ?;
412+ Ok ( ( ) )
413+ }
414+ }
0 commit comments