Skip to content

Commit 35a40aa

Browse files
Merge pull request #58 from pydantic/try-cast
support union comparison for try-cast
2 parents ac577e7 + 3c86e77 commit 35a40aa

4 files changed

Lines changed: 142 additions & 78 deletions

File tree

datafusion/core/tests/sql/union_comparison.rs

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,68 @@ async fn test_union_cast_compatible_variant() -> Result<()> {
433433
Ok(())
434434
}
435435

436+
/// Tests TRY_CAST from union type to scalar type.
/// TRY_CAST should extract the matching variant and return NULL for non-matching ones,
/// just like CAST, but also return NULL (instead of erroring) for cast failures.
#[tokio::test]
async fn test_union_try_cast() -> Result<()> {
    // Mixed sparse union: rows alternate between the `int` and `str` variants.
    // Note row 4 holds the string "42" — TRY_CAST to INT still yields NULL for it,
    // because variant extraction (not string parsing) drives the cast here.
    let union_array = create_sparse_union_array(vec![
        UnionValue::Int(Some(10)),
        UnionValue::Str(Some("hello")),
        UnionValue::Int(Some(30)),
        UnionValue::Str(Some("42")),
    ]);

    // Schema: an `id` column plus a nullable sparse-union column `val`
    // with two variants — type id 0 -> Int32 "int", type id 1 -> Utf8 "str".
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new(
            "val",
            DataType::Union(
                UnionFields::new(
                    vec![0, 1],
                    vec![
                        Field::new("int", DataType::Int32, true),
                        Field::new("str", DataType::Utf8, true),
                    ],
                ),
                UnionMode::Sparse,
            ),
            true,
        ),
    ]));

    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
            Arc::new(union_array),
        ],
    )?;

    let ctx = SessionContext::new();
    ctx.register_batch("test", batch)?;

    // TRY_CAST to INT should extract int variants, NULL for string variants
    let df = ctx
        .sql("SELECT id, TRY_CAST(val AS INT) as val_int FROM test")
        .await?;
    let results = df.collect().await?;

    // Rows 1 and 3 carry the int variant (10, 30); rows 2 and 4 carry the
    // str variant and therefore come back as NULL (rendered as blank cells).
    let expected = [
        "+----+---------+",
        "| id | val_int |",
        "+----+---------+",
        "| 1  | 10      |",
        "| 2  |         |",
        "| 3  | 30      |",
        "| 4  |         |",
        "+----+---------+",
    ];
    assert_batches_eq!(expected, &results);

    Ok(())
}
497+
436498
/// Tests union-to-union equality comparison (supported via arrow-ord).
437499
#[tokio::test]
438500
async fn test_union_eq_same_union() -> Result<()> {
@@ -481,9 +543,7 @@ async fn test_union_eq_same_union() -> Result<()> {
481543
ctx.register_batch("test", batch)?;
482544

483545
// Row 1: Int(10) = Int(10) -> true; Row 2: Str("hello") = Str("world") -> false
484-
let df = ctx
485-
.sql("SELECT id FROM test WHERE val1 = val2")
486-
.await?;
546+
let df = ctx.sql("SELECT id FROM test WHERE val1 = val2").await?;
487547
let results = df.collect().await?;
488548

489549
let expected = ["+----+", "| id |", "+----+", "| 1 |", "+----+"];

datafusion/datasource-parquet/src/row_filter.rs

Lines changed: 63 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter};
7878
use parquet::file::metadata::ParquetMetaData;
7979
use parquet::schema::types::SchemaDescriptor;
8080

81-
use datafusion_common::{Result, ScalarValue};
8281
use datafusion_common::cast::as_boolean_array;
8382
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
83+
use datafusion_common::{Result, ScalarValue};
8484
use datafusion_physical_expr::ScalarFunctionExpr;
8585
use datafusion_physical_expr::expressions::{Column, Literal};
8686
use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns};
@@ -398,7 +398,7 @@ impl<'schema> PushdownChecker<'schema> {
398398
/// See <https://github.com/datafusion-contrib/datafusion-variant> for the
399399
/// `datafusion-variant` crate that defines these UDFs.
400400
const VARIANT_UDF_NAMES: &[&str] = &[
401-
"variant_get", // variant_get, variant_get_str, variant_get_int, etc.
401+
"variant_get", // variant_get, variant_get_str, variant_get_int, etc.
402402
"is_variant_null",
403403
];
404404

@@ -513,76 +513,69 @@ impl TreeNodeVisitor<'_> for PushdownChecker<'_> {
513513
// - `metadata` — always needed (variant metadata dictionary)
514514
// - `value` — always needed (fallback for non-shredded values)
515515
// - `typed_value.<path...>` — the specific shredded field(s)
516-
if let Some(func_expr) = node.as_any().downcast_ref::<ScalarFunctionExpr>() {
517-
if is_variant_udf_name(func_expr.name()) {
518-
if let Some(column) = func_expr
519-
.args()
520-
.first()
521-
.and_then(|a| a.as_any().downcast_ref::<Column>())
522-
{
523-
let Ok(root_idx) = self.file_schema.index_of(column.name()) else {
524-
self.projected_columns = true;
525-
return Ok(TreeNodeRecursion::Jump);
526-
};
527-
528-
// Extract the variant path from the second argument.
529-
// It can be a string literal or a list of string literals.
530-
let variant_path: Option<Vec<String>> =
531-
func_expr.args().get(1).and_then(|arg| {
532-
let lit = arg.as_any().downcast_ref::<Literal>()?;
533-
match lit.value() {
534-
ScalarValue::Utf8(Some(s))
535-
| ScalarValue::Utf8View(Some(s))
536-
| ScalarValue::LargeUtf8(Some(s)) => {
537-
Some(vec![s.to_string()])
538-
}
539-
ScalarValue::List(arr) if !arr.is_null(0) => {
540-
let values = arr.value(0);
541-
let strings =
542-
values.as_any().downcast_ref::<StringArray>()?;
543-
let path: Vec<String> = (0..strings.len())
544-
.filter_map(|i| {
545-
strings
546-
.is_valid(i)
547-
.then(|| strings.value(i).to_string())
548-
})
549-
.collect();
550-
Some(path)
551-
}
552-
_ => None,
553-
}
554-
});
555-
556-
// Record struct field accesses for the variant sub-fields:
557-
// metadata, value, and typed_value.<path>
558-
self.struct_field_accesses.push(StructFieldAccess {
559-
root_index: root_idx,
560-
field_path: vec!["metadata".to_string()],
561-
});
562-
self.struct_field_accesses.push(StructFieldAccess {
563-
root_index: root_idx,
564-
field_path: vec!["value".to_string()],
565-
});
566-
567-
if let Some(path) = variant_path {
568-
// typed_value.<field1>.<field2>...
569-
let mut typed_value_path = vec!["typed_value".to_string()];
570-
typed_value_path.extend(path);
571-
self.struct_field_accesses.push(StructFieldAccess {
572-
root_index: root_idx,
573-
field_path: typed_value_path,
574-
});
575-
} else {
576-
// Can't determine path statically — read entire typed_value
577-
self.struct_field_accesses.push(StructFieldAccess {
578-
root_index: root_idx,
579-
field_path: vec!["typed_value".to_string()],
580-
});
581-
}
516+
if let Some(func_expr) = node.as_any().downcast_ref::<ScalarFunctionExpr>()
517+
&& is_variant_udf_name(func_expr.name())
518+
&& let Some(column) = func_expr
519+
.args()
520+
.first()
521+
.and_then(|a| a.as_any().downcast_ref::<Column>())
522+
{
523+
let Ok(root_idx) = self.file_schema.index_of(column.name()) else {
524+
self.projected_columns = true;
525+
return Ok(TreeNodeRecursion::Jump);
526+
};
582527

583-
return Ok(TreeNodeRecursion::Jump);
584-
}
528+
// Extract the variant path from the second argument.
529+
// It can be a string literal or a list of string literals.
530+
let variant_path: Option<Vec<String>> =
531+
func_expr.args().get(1).and_then(|arg| {
532+
let lit = arg.as_any().downcast_ref::<Literal>()?;
533+
match lit.value() {
534+
ScalarValue::Utf8(Some(s))
535+
| ScalarValue::Utf8View(Some(s))
536+
| ScalarValue::LargeUtf8(Some(s)) => Some(vec![s.to_string()]),
537+
ScalarValue::List(arr) if !arr.is_null(0) => {
538+
let values = arr.value(0);
539+
let strings =
540+
values.as_any().downcast_ref::<StringArray>()?;
541+
let path: Vec<String> = (0..strings.len())
542+
.filter(|&i| strings.is_valid(i))
543+
.map(|i| strings.value(i).to_string())
544+
.collect();
545+
Some(path)
546+
}
547+
_ => None,
548+
}
549+
});
550+
551+
// Record struct field accesses for the variant sub-fields:
552+
// metadata, value, and typed_value.<path>
553+
self.struct_field_accesses.push(StructFieldAccess {
554+
root_index: root_idx,
555+
field_path: vec!["metadata".to_string()],
556+
});
557+
self.struct_field_accesses.push(StructFieldAccess {
558+
root_index: root_idx,
559+
field_path: vec!["value".to_string()],
560+
});
561+
562+
if let Some(path) = variant_path {
563+
// typed_value.<field1>.<field2>...
564+
let mut typed_value_path = vec!["typed_value".to_string()];
565+
typed_value_path.extend(path);
566+
self.struct_field_accesses.push(StructFieldAccess {
567+
root_index: root_idx,
568+
field_path: typed_value_path,
569+
});
570+
} else {
571+
// Can't determine path statically — read entire typed_value
572+
self.struct_field_accesses.push(StructFieldAccess {
573+
root_index: root_idx,
574+
field_path: vec!["typed_value".to_string()],
575+
});
585576
}
577+
578+
return Ok(TreeNodeRecursion::Jump);
586579
}
587580

588581
if let Some(column) = node.as_any().downcast_ref::<Column>()

datafusion/expr-common/src/columnar_value.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use arrow::buffer::NullBuffer;
2121
use arrow::{
2222
array::{Array, ArrayRef, Date32Array, Date64Array, NullArray},
23-
compute::{can_cast_types, CastOptions, kernels, max, min},
23+
compute::{CastOptions, can_cast_types, kernels, max, min},
2424
datatypes::DataType,
2525
error::ArrowError,
2626
util::pretty::pretty_format_columns,
@@ -322,7 +322,8 @@ fn cast_array_by_name(
322322
_ if datafusion_common::nested_struct::requires_nested_struct_cast(
323323
array.data_type(),
324324
cast_type,
325-
) => {
325+
) =>
326+
{
326327
datafusion_common::nested_struct::cast_column(array, cast_type, cast_options)
327328
}
328329
_ => {
@@ -346,7 +347,7 @@ fn cast_union_array(
346347
to_type: &DataType,
347348
cast_options: &CastOptions,
348349
) -> Result<ArrayRef, ArrowError> {
349-
use arrow::array::{make_array, new_null_array, UInt32Array, UnionArray};
350+
use arrow::array::{UInt32Array, UnionArray, make_array, new_null_array};
350351
use arrow::compute::cast_with_options;
351352
use arrow::datatypes::UnionMode;
352353

datafusion/physical-expr/src/expressions/try_cast.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,18 @@ pub fn try_cast(
138138
) -> Result<Arc<dyn PhysicalExpr>> {
139139
let expr_type = expr.data_type(input_schema)?;
140140
if expr_type == cast_type {
141-
Ok(Arc::clone(&expr))
142-
} else if can_cast_types(&expr_type, &cast_type) {
141+
return Ok(Arc::clone(&expr));
142+
}
143+
144+
// Check if cast is supported, with special handling for union types.
145+
let can_cast = match &expr_type {
146+
DataType::Union(fields, _) => fields
147+
.iter()
148+
.any(|(_, f)| can_cast_types(f.data_type(), &cast_type)),
149+
_ => can_cast_types(&expr_type, &cast_type),
150+
};
151+
152+
if can_cast {
143153
Ok(Arc::new(TryCastExpr::new(expr, cast_type)))
144154
} else {
145155
not_impl_err!("Unsupported TRY_CAST from {expr_type} to {cast_type}")

0 commit comments

Comments
 (0)