Skip to content

Commit 0287a68

Browse files
committed
First draft for using extension type registry in SLT tests
1 parent 7116509 commit 0287a68

5 files changed

Lines changed: 156 additions & 25 deletions

File tree

datafusion/sql/src/planner.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@ use std::vec;
2323

2424
use crate::utils::make_decimal_type;
2525
use arrow::datatypes::*;
26-
use datafusion_common::TableReference;
2726
use datafusion_common::config::SqlParserOptions;
2827
use datafusion_common::datatype::{DataTypeExt, FieldExt};
2928
use datafusion_common::error::add_possible_columns_to_diag;
30-
use datafusion_common::{DFSchema, DataFusionError, Result, not_impl_err, plan_err};
29+
use datafusion_common::TableReference;
3130
use datafusion_common::{
32-
DFSchemaRef, Diagnostic, SchemaError, field_not_found, internal_err,
33-
plan_datafusion_err,
31+
field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, Diagnostic,
32+
SchemaError,
3433
};
34+
use datafusion_common::{not_impl_err, plan_err, DFSchema, DataFusionError, Result};
3535
use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder};
3636
pub use datafusion_expr::planner::ContextProvider;
3737
use datafusion_expr::utils::find_column_exprs;
38-
use datafusion_expr::{Expr, col};
38+
use datafusion_expr::{col, Expr};
3939
use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo, TimezoneInfo};
4040
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
4141
use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias};

datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs

Lines changed: 128 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,24 @@ use super::super::conversion::*;
1919
use super::error::{DFSqlLogicTestError, Result};
2020
use crate::engines::output::DFColumnType;
2121
use arrow::array::{Array, AsArray};
22-
use arrow::datatypes::{Fields, Schema};
23-
use arrow::util::display::ArrayFormatter;
24-
use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch};
22+
use arrow::datatypes::{Field, Fields, Schema};
23+
use arrow::error::ArrowError;
24+
use arrow::util::display::{
25+
ArrayFormatter, ArrayFormatterFactory, DisplayIndex, FormatOptions, FormatResult,
26+
};
27+
use arrow::{array, datatypes::DataType, record_batch::RecordBatch};
28+
use datafusion::catalog::Session;
2529
use datafusion::common::internal_datafusion_err;
2630
use datafusion::config::ConfigField;
31+
use datafusion::logical_expr::extension_types::DFArrayFormatterFactory;
32+
use datafusion::prelude::SessionContext;
33+
use std::fmt::Write;
2734
use std::path::PathBuf;
28-
use std::sync::LazyLock;
35+
use std::sync::{Arc, LazyLock};
2936

3037
/// Converts `batches` to a result as expected by sqllogictest.
3138
pub fn convert_batches(
39+
ctx: &SessionContext,
3240
schema: &Schema,
3341
batches: Vec<RecordBatch>,
3442
is_spark_path: bool,
@@ -44,21 +52,51 @@ pub fn convert_batches(
4452
)));
4553
}
4654

55+
let state = ctx.state();
56+
let options = state.config().options().format.clone();
57+
let arrow_options: FormatOptions = (&options).try_into()?;
58+
59+
let registry = state.extension_type_registry();
60+
let plain_formatter_factory = DFArrayFormatterFactory::new(Arc::clone(registry));
61+
let formatter_factory =
62+
NormalizingArrayFormatterFactory::new(plain_formatter_factory, is_spark_path);
63+
64+
let arrow_options = arrow_options
65+
.with_formatter_factory(Some(&formatter_factory))
66+
.with_null("NULL");
67+
68+
let formatters = batch
69+
.columns()
70+
.iter()
71+
.zip(schema.fields())
72+
.map(|(col, field)| {
73+
let formatter = formatter_factory.create_array_formatter(
74+
col,
75+
&arrow_options,
76+
Some(field),
77+
)?;
78+
79+
match formatter {
80+
None => Ok(ArrayFormatter::try_new(col.as_ref(), &arrow_options)?),
81+
Some(formatter) => Ok(formatter),
82+
}
83+
})
84+
.collect::<std::result::Result<Vec<_>, ArrowError>>()?;
85+
4786
// Convert a single batch to a `Vec<Vec<String>>` for comparison, flatten expanded rows, and normalize each.
4887
let new_rows = (0..batch.num_rows())
4988
.map(|row| {
50-
batch
51-
.columns()
89+
formatters
5290
.iter()
53-
.map(|col| cell_to_string(col, row, is_spark_path))
54-
.collect::<Result<Vec<String>>>()
91+
.map(|f| f.value(row).to_string())
92+
.collect::<Vec<String>>()
5593
})
56-
.collect::<Result<Vec<Vec<String>>>>()?
57-
.into_iter()
5894
.flat_map(expand_row)
5995
.map(normalize_paths);
96+
6097
rows.extend(new_rows);
6198
}
99+
62100
Ok(rows)
63101
}
64102

@@ -185,7 +223,11 @@ macro_rules! get_row_value {
185223
/// [NULL Values and empty strings]: https://duckdb.org/dev/sqllogictest/result_verification#null-values-and-empty-strings
186224
///
187225
/// Floating numbers are rounded to have a consistent representation with the Postgres runner.
188-
pub fn cell_to_string(col: &ArrayRef, row: usize, is_spark_path: bool) -> Result<String> {
226+
pub fn cell_to_string(
227+
col: &dyn Array,
228+
row: usize,
229+
is_spark_path: bool,
230+
) -> Result<String> {
189231
if col.is_null(row) {
190232
// represent any null value with the string "NULL"
191233
Ok(NULL_STR.to_string())
@@ -233,18 +275,18 @@ pub fn cell_to_string(col: &ArrayRef, row: usize, is_spark_path: bool) -> Result
233275
DataType::Dictionary(_, _) => {
234276
let dict = col.as_any_dictionary();
235277
let key = dict.normalized_keys()[row];
236-
Ok(cell_to_string(dict.values(), key, is_spark_path)?)
278+
Ok(cell_to_string(dict.values().as_ref(), key, is_spark_path)?)
237279
}
238280
_ => {
239281
let mut datafusion_format_options =
240282
datafusion::config::FormatOptions::default();
241283

242284
datafusion_format_options.set("null", "NULL").unwrap();
243285

244-
let arrow_format_options: arrow::util::display::FormatOptions =
286+
let arrow_format_options: FormatOptions =
245287
(&datafusion_format_options).try_into().unwrap();
246288

247-
let f = ArrayFormatter::try_new(col.as_ref(), &arrow_format_options)?;
289+
let f = ArrayFormatter::try_new(col, &arrow_format_options)?;
248290

249291
Ok(f.value(row).to_string())
250292
}
@@ -298,3 +340,75 @@ pub fn convert_schema_to_types(columns: &Fields) -> Vec<DFColumnType> {
298340
})
299341
.collect()
300342
}
343+
344+
/// Wraps a [`DFArrayFormatterFactory`] and intercepts formatting columns that must be normalized.
345+
#[derive(Debug)]
346+
pub struct NormalizingArrayFormatterFactory {
347+
/// The inner formatter factory from DataFusion.
348+
inner: DFArrayFormatterFactory,
349+
/// Whether the test is a Spark test.
350+
is_spark_path: bool,
351+
}
352+
353+
impl NormalizingArrayFormatterFactory {
354+
/// Creates a new [`NormalizingArrayFormatterFactory`].
355+
pub fn new(inner: DFArrayFormatterFactory, is_spark_path: bool) -> Self {
356+
Self {
357+
inner,
358+
is_spark_path,
359+
}
360+
}
361+
}
362+
363+
impl ArrayFormatterFactory for NormalizingArrayFormatterFactory {
364+
fn create_array_formatter<'formatter>(
365+
&self,
366+
array: &'formatter dyn Array,
367+
options: &FormatOptions<'formatter>,
368+
field: Option<&'formatter Field>,
369+
) -> std::result::Result<Option<ArrayFormatter<'formatter>>, ArrowError> {
370+
// Extension types are always formatted via DataFusion.
371+
if let Some(field) = field {
372+
if field.extension_type_name().is_some() {
373+
return self
374+
.inner
375+
.create_array_formatter(array, options, Some(field));
376+
}
377+
}
378+
379+
// Intercept normalizing formatting of columns that must be normalized.
380+
match array.data_type() {
381+
DataType::Boolean
382+
| DataType::Float16
383+
| DataType::Float32
384+
| DataType::Float64
385+
| DataType::Decimal128(_, _)
386+
| DataType::Decimal256(_, _)
387+
| DataType::Utf8
388+
| DataType::LargeUtf8
389+
| DataType::Utf8View => {
390+
let display = SLTDisplayIndex {
391+
array,
392+
is_spark_path: self.is_spark_path,
393+
};
394+
Ok(Some(ArrayFormatter::new(Box::new(display), options.safe())))
395+
}
396+
_ => self.inner.create_array_formatter(array, options, field),
397+
}
398+
}
399+
}
400+
401+
/// Implements [`DisplayIndex`] by normalizing the values of the array using [`cell_to_string`].
402+
struct SLTDisplayIndex<'a> {
403+
array: &'a dyn Array,
404+
is_spark_path: bool,
405+
}
406+
407+
impl DisplayIndex for SLTDisplayIndex<'_> {
408+
fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult {
409+
let s = cell_to_string(self.array, idx, self.is_spark_path)
410+
.map_err(|_| std::fmt::Error)?;
411+
write!(f, "{s}")?;
412+
Ok(())
413+
}
414+
}

datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ async fn run_query(
212212
let stream = execute_stream(plan, task_ctx)?;
213213
let types = normalize::convert_schema_to_types(stream.schema().fields());
214214
let results: Vec<RecordBatch> = collect(stream).await?;
215-
let rows = normalize::convert_batches(&schema, results, is_spark_path)?;
215+
let rows = normalize::convert_batches(ctx, &schema, results, is_spark_path)?;
216216

217217
if rows.is_empty() && types.is_empty() {
218218
Ok(DBOutput::StatementComplete(0))

datafusion/sqllogictest/src/test_context.rs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ use arrow::record_batch::RecordBatch;
3232
use datafusion::catalog::{
3333
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, Session,
3434
};
35-
use datafusion::common::{DataFusionError, Result, not_impl_err};
35+
use datafusion::common::{not_impl_err, DataFusionError, Result};
3636
use datafusion::functions::math::abs;
3737
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
3838
use datafusion::logical_expr::{
39-
ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
40-
Volatility, create_udf,
39+
create_udf, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
40+
Signature, Volatility,
4141
};
4242
use datafusion::physical_plan::ExecutionPlan;
4343
use datafusion::prelude::*;
@@ -50,8 +50,9 @@ use datafusion_spark::SessionStateBuilderSpark;
5050
use crate::is_spark_path;
5151
use async_trait::async_trait;
5252
use datafusion::common::cast::as_float64_array;
53-
use datafusion::execution::SessionStateBuilder;
5453
use datafusion::execution::runtime_env::RuntimeEnv;
54+
use datafusion::execution::SessionStateBuilder;
55+
use datafusion::logical_expr::registry::MemoryExtensionTypeRegistry;
5556
use log::info;
5657
use tempfile::TempDir;
5758

@@ -148,6 +149,10 @@ impl TestContext {
148149
info!("Registering dummy async udf");
149150
register_async_abs_udf(test_ctx.session_ctx())
150151
}
152+
"sql_extension_types.slt" | "cast_extension_types_metadata.slt" => {
153+
info!("Registering canonical extension types");
154+
register_canonical_extension_types(test_ctx.session_ctx());
155+
}
151156
_ => {
152157
info!("Using default SessionContext");
153158
}
@@ -586,3 +591,15 @@ fn register_async_abs_udf(ctx: &SessionContext) {
586591
let udf = AsyncScalarUDF::new(Arc::new(async_abs));
587592
ctx.register_udf(udf.into_scalar_udf());
588593
}
594+
595+
/// Registers the canonical extension types in the session context.
596+
fn register_canonical_extension_types(test_ctx: &SessionContext) {
597+
let state = test_ctx.state();
598+
let registry = state.extension_type_registry();
599+
registry
600+
.extend(
601+
&MemoryExtensionTypeRegistry::new_with_canonical_extension_types()
602+
.all_extension_types(),
603+
)
604+
.unwrap();
605+
}

datafusion/sqllogictest/test_files/sql_extension_types.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ query ?T
2727
SELECT CAST(arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)') AS UUID),
2828
arrow_metadata(CAST(arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)') AS UUID), 'ARROW:extension:name');
2929
----
30-
00010203040506070809000102030506 arrow.uuid
30+
00010203-0405-0607-0809-000102030506 arrow.uuid
3131

3232
# CREATE TABLE with UUID column preserves extension metadata through VALUES
3333
statement ok

0 commit comments

Comments
 (0)