Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions datafusion/functions/src/core/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,9 @@ impl ScalarUDFImpl for ArrowCastFunc {
self.name()
)
},
|casted_type| match casted_type.parse::<DataType>() {
Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable).into()),
Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
Err(e) => Err(arrow_datafusion_err!(e)),
|casted_type| {
let data_type = parse_arrow_cast_data_type(casted_type)?;
Ok(Field::new(self.name(), data_type, nullable).into())
},
)
}
Expand Down Expand Up @@ -189,10 +188,48 @@ pub(crate) fn data_type_from_type_arg(name: &str, type_arg: &Expr) -> Result<Dat
);
};

val.parse().map_err(|e| match e {
parse_arrow_cast_data_type(val)
}

pub(crate) fn parse_arrow_cast_data_type(casted_type: &str) -> Result<DataType> {
let data_type = casted_type.parse().map_err(|e| match e {
// If the data type cannot be parsed, return a Plan error to signal an
// error in the input rather than a more general ArrowError
ArrowError::ParseError(e) => exec_datafusion_err!("{e}"),
e => arrow_datafusion_err!(e),
})
})?;
validate_arrow_cast_data_type(&data_type)?;
Ok(data_type)
}

fn validate_arrow_cast_data_type(data_type: &DataType) -> Result<()> {
match data_type {
DataType::FixedSizeBinary(length) if *length < 0 => {
exec_err!("FixedSizeBinary element length must be non-negative, got {length}")
}
DataType::List(field)
| DataType::LargeList(field)
| DataType::ListView(field)
| DataType::LargeListView(field)
| DataType::FixedSizeList(field, _)
| DataType::Map(field, _) => validate_arrow_cast_data_type(field.data_type()),
DataType::Struct(fields) => {
for field in fields.iter() {
validate_arrow_cast_data_type(field.data_type())?;
}
Ok(())
}
DataType::Union(fields, _) => {
for (_, field) in fields.iter() {
validate_arrow_cast_data_type(field.data_type())?;
}
Ok(())
}
DataType::Dictionary(_, value_type) => validate_arrow_cast_data_type(value_type),
DataType::RunEndEncoded(run_ends, values) => {
validate_arrow_cast_data_type(run_ends.data_type())?;
validate_arrow_cast_data_type(values.data_type())
}
_ => Ok(()),
}
}
16 changes: 6 additions & 10 deletions datafusion/functions/src/core/arrow_try_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
//! [`ArrowTryCastFunc`]: Implementation of the `arrow_try_cast`

use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::{
Result, arrow_datafusion_err, datatype::DataTypeExt, exec_datafusion_err, exec_err,
internal_err, types::logical_string, utils::take_function_args,
Result, datatype::DataTypeExt, exec_err, internal_err, types::logical_string,
utils::take_function_args,
};

use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
Expand All @@ -31,7 +30,7 @@ use datafusion_expr::{
};
use datafusion_macros::user_doc;

use super::arrow_cast::data_type_from_type_arg;
use super::arrow_cast::{data_type_from_type_arg, parse_arrow_cast_data_type};

/// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring.
///
Expand Down Expand Up @@ -111,12 +110,9 @@ impl ScalarUDFImpl for ArrowTryCastFunc {
self.name()
)
},
|casted_type| match casted_type.parse::<DataType>() {
Ok(data_type) => {
Ok(Field::new(self.name(), data_type, true).into())
}
Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
Err(e) => Err(arrow_datafusion_err!(e)),
|casted_type| {
let data_type = parse_arrow_cast_data_type(casted_type)?;
Ok(Field::new(self.name(), data_type, true).into())
},
)
}
Expand Down
9 changes: 9 additions & 0 deletions datafusion/sqllogictest/test_files/arrow_typeof.slt
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ SELECT arrow_cast('1', arrow_cast('Utf8', 'Utf8'))
query error DataFusion error: Execution error: Unsupported type 'unknown'\. Must be a supported arrow type name such as 'Int32' or 'Timestamp\(ns\)'\. Error unknown token: unknown
SELECT arrow_cast('1', 'unknown')

query error DataFusion error: Execution error: FixedSizeBinary element length must be non-negative, got -1
SELECT arrow_cast(NULL, 'FixedSizeBinary(-1)')

query error DataFusion error: Execution error: FixedSizeBinary element length must be non-negative, got -1
SELECT arrow_try_cast(NULL, 'FixedSizeBinary(-1)')

query error DataFusion error: Execution error: FixedSizeBinary element length must be non-negative, got -1
SELECT arrow_cast(NULL, 'List(FixedSizeBinary(-1))')

# Round Trip tests:
query TTTTTTTTTTTTTTTTTTTTTTTTT
SELECT
Expand Down