From 5ba2d053e7dea380fe1f271a67e6326136cf89dd Mon Sep 17 00:00:00 2001 From: Sean Doherty Date: Sun, 17 May 2026 02:10:58 -0500 Subject: [PATCH] Validate arrow_cast fixed-size binary length --- datafusion/functions/src/core/arrow_cast.rs | 49 ++++++++++++++++--- .../functions/src/core/arrow_try_cast.rs | 16 +++--- .../sqllogictest/test_files/arrow_typeof.slt | 9 ++++ 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 0b67883c17c87..903543abd0cbd 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -140,10 +140,9 @@ impl ScalarUDFImpl for ArrowCastFunc { self.name() ) }, - |casted_type| match casted_type.parse::() { - Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable).into()), - Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), - Err(e) => Err(arrow_datafusion_err!(e)), + |casted_type| { + let data_type = parse_arrow_cast_data_type(casted_type)?; + Ok(Field::new(self.name(), data_type, nullable).into()) }, ) } @@ -189,10 +188,48 @@ pub(crate) fn data_type_from_type_arg(name: &str, type_arg: &Expr) -> Result Result { + let data_type = casted_type.parse().map_err(|e| match e { // If the data type cannot be parsed, return a Plan error to signal an // error in the input rather than a more general ArrowError ArrowError::ParseError(e) => exec_datafusion_err!("{e}"), e => arrow_datafusion_err!(e), - }) + })?; + validate_arrow_cast_data_type(&data_type)?; + Ok(data_type) +} + +fn validate_arrow_cast_data_type(data_type: &DataType) -> Result<()> { + match data_type { + DataType::FixedSizeBinary(length) if *length < 0 => { + exec_err!("FixedSizeBinary element length must be non-negative, got {length}") + } + DataType::List(field) + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) + | DataType::FixedSizeList(field, _) + | DataType::Map(field, _) => validate_arrow_cast_data_type(field.data_type()), + DataType::Struct(fields) => { + for field in fields.iter() { + validate_arrow_cast_data_type(field.data_type())?; + } + Ok(()) + } + DataType::Union(fields, _) => { + for (_, field) in fields.iter() { + validate_arrow_cast_data_type(field.data_type())?; + } + Ok(()) + } + DataType::Dictionary(_, value_type) => validate_arrow_cast_data_type(value_type), + DataType::RunEndEncoded(run_ends, values) => { + validate_arrow_cast_data_type(run_ends.data_type())?; + validate_arrow_cast_data_type(values.data_type()) + } + _ => Ok(()), + } } diff --git a/datafusion/functions/src/core/arrow_try_cast.rs b/datafusion/functions/src/core/arrow_try_cast.rs index d27b29ba5736d..8d723fc17e77e 100644 --- a/datafusion/functions/src/core/arrow_try_cast.rs +++ b/datafusion/functions/src/core/arrow_try_cast.rs @@ -18,10 +18,9 @@ //! [`ArrowTryCastFunc`]: Implementation of the `arrow_try_cast` use arrow::datatypes::{DataType, Field, FieldRef}; -use arrow::error::ArrowError; use datafusion_common::{ - Result, arrow_datafusion_err, datatype::DataTypeExt, exec_datafusion_err, exec_err, - internal_err, types::logical_string, utils::take_function_args, + Result, datatype::DataTypeExt, exec_err, internal_err, types::logical_string, + utils::take_function_args, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; @@ -31,7 +30,7 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -use super::arrow_cast::data_type_from_type_arg; +use super::arrow_cast::{data_type_from_type_arg, parse_arrow_cast_data_type}; /// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring. /// @@ -111,12 +110,9 @@ impl ScalarUDFImpl for ArrowTryCastFunc { self.name() ) }, - |casted_type| match casted_type.parse::() { - Ok(data_type) => { - Ok(Field::new(self.name(), data_type, true).into()) - } - Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), - Err(e) => Err(arrow_datafusion_err!(e)), + |casted_type| { + let data_type = parse_arrow_cast_data_type(casted_type)?; + Ok(Field::new(self.name(), data_type, true).into()) }, ) } diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index e00909ad5fc59..7d594db8a0567 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -104,6 +104,15 @@ SELECT arrow_cast('1', arrow_cast('Utf8', 'Utf8')) query error DataFusion error: Execution error: Unsupported type 'unknown'\. Must be a supported arrow type name such as 'Int32' or 'Timestamp\(ns\)'\. Error unknown token: unknown SELECT arrow_cast('1', 'unknown') +query error DataFusion error: Execution error: FixedSizeBinary element length must be non-negative, got -1 +SELECT arrow_cast(NULL, 'FixedSizeBinary(-1)') + +query error DataFusion error: Execution error: FixedSizeBinary element length must be non-negative, got -1 +SELECT arrow_try_cast(NULL, 'FixedSizeBinary(-1)') + +query error DataFusion error: Execution error: FixedSizeBinary element length must be non-negative, got -1 +SELECT arrow_cast(NULL, 'List(FixedSizeBinary(-1))') + # Round Trip tests: query TTTTTTTTTTTTTTTTTTTTTTTTT SELECT