Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 93 additions & 54 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
}

// slow path, we match the fields by name
to_fields.iter().all(|to_field| {
if to_fields.iter().all(|to_field| {
from_fields
.iter()
.find(|from_field| from_field.name() == to_field.name())
Expand All @@ -263,7 +263,15 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
// cast kernel will return error.
can_cast_types(from_field.data_type(), to_field.data_type())
})
})
}) {
return true;
}

// if we couldn't match by name, we try to see if they can be matched by position
from_fields
.iter()
.zip(to_fields.iter())
.all(|(f1, f2)| can_cast_types(f1.data_type(), f2.data_type()))
}
(Struct(_), _) => false,
(_, Struct(_)) => false,
Expand Down Expand Up @@ -1218,49 +1226,12 @@ pub fn cast_with_options(
cast_options,
)
}
(Struct(from_fields), Struct(to_fields)) => {
let array = array.as_struct();

// Fast path: if field names are in the same order, we can just zip and cast
let fields_match_order = from_fields.len() == to_fields.len()
&& from_fields
.iter()
.zip(to_fields.iter())
.all(|(f1, f2)| f1.name() == f2.name());

let fields = if fields_match_order {
// Fast path: cast columns in order
array
.columns()
.iter()
.zip(to_fields.iter())
.map(|(column, field)| {
cast_with_options(column, field.data_type(), cast_options)
})
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
} else {
// Slow path: match fields by name and reorder
to_fields
.iter()
.map(|to_field| {
let from_field_idx = from_fields
.iter()
.position(|from_field| from_field.name() == to_field.name())
.ok_or_else(|| {
ArrowError::CastError(format!(
"Field '{}' not found in source struct",
to_field.name()
))
})?;
let column = array.column(from_field_idx);
cast_with_options(column, to_field.data_type(), cast_options)
})
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
};

let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?;
Ok(Arc::new(array) as ArrayRef)
}
(Struct(from_fields), Struct(to_fields)) => cast_struct_to_struct(
array.as_struct(),
from_fields.clone(),
to_fields.clone(),
cast_options,
),
(Struct(_), _) => Err(ArrowError::CastError(format!(
"Casting from {from_type} to {to_type} not supported"
))),
Expand Down Expand Up @@ -2292,6 +2263,74 @@ pub fn cast_with_options(
}
}

/// Casts a [`StructArray`] into a struct array whose children are described by `to_fields`.
///
/// Source columns are paired with target fields using the first strategy that applies:
///
/// 1. Both field lists have the same length and identical names at every position:
///    columns are cast positionally.
/// 2. Every target field name occurs somewhere in `from_fields`: columns are looked up
///    by name, which also reorders them to the target order.
/// 3. Otherwise: columns are cast positionally, ignoring names.
///
/// The null buffer of `array` is cloned onto the result unchanged.
///
/// # Errors
///
/// Propagates the first error raised while casting a child column, and any error from
/// [`StructArray::try_new`] when assembling the result (for instance when the positional
/// fallback yields a different number of columns than `to_fields` has fields).
fn cast_struct_to_struct(
    array: &StructArray,
    from_fields: Fields,
    to_fields: Fields,
    cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
    // Strategy 1 check: same arity and the same name at every position.
    let names_line_up = from_fields.len() == to_fields.len()
        && from_fields
            .iter()
            .map(|source| source.name())
            .eq(to_fields.iter().map(|target| target.name()));

    // Strategy 2 is only considered when the names do not already line up, and only
    // when each target name can be resolved against the source fields.
    let resolve_by_name = !names_line_up
        && to_fields.iter().all(|target| {
            from_fields
                .iter()
                .any(|source| source.name() == target.name())
        });

    let columns = if resolve_by_name {
        cast_struct_fields_by_name(array, from_fields, to_fields.clone(), cast_options)?
    } else {
        // Covers both the fast path (names line up) and the positional fallback.
        cast_struct_fields_in_order(array, to_fields.clone(), cast_options)?
    };

    StructArray::try_new(to_fields, columns, array.nulls().cloned())
        .map(|casted| Arc::new(casted) as ArrayRef)
}

fn cast_struct_fields_by_name(
array: &StructArray,
from_fields: Fields,
to_fields: Fields,
cast_options: &CastOptions,
) -> Result<Vec<ArrayRef>, ArrowError> {
to_fields
.iter()
.map(|to_field| {
let from_field_idx = from_fields
.iter()
.position(|from_field| from_field.name() == to_field.name())
.unwrap(); // safe because we checked above
let column = array.column(from_field_idx);
cast_with_options(column, to_field.data_type(), cast_options)
})
.collect::<Result<Vec<ArrayRef>, ArrowError>>()
}

/// Casts the children of `array` positionally: the i-th column is cast to the data type
/// of the i-th entry of `to_fields`. Field names play no role here.
///
/// Columns and fields are zipped, so when the two lengths differ only the common prefix
/// is cast and the result is as long as the shorter of the two.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by [`cast_with_options`].
fn cast_struct_fields_in_order(
    array: &StructArray,
    to_fields: Fields,
    cast_options: &CastOptions,
) -> Result<Vec<ArrayRef>, ArrowError> {
    let sources = array.columns();
    let mut casted = Vec::with_capacity(sources.len().min(to_fields.len()));
    for (source, target) in sources.iter().zip(to_fields.iter()) {
        casted.push(cast_with_options(source, target.data_type(), cast_options)?);
    }
    Ok(casted)
}

fn cast_from_decimal<D, F>(
array: &dyn Array,
base: D::Native,
Expand Down Expand Up @@ -10917,11 +10956,11 @@ mod tests {
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("a", DataType::Boolean, false)),
Arc::new(Field::new("b", DataType::Boolean, false)),
boolean.clone() as ArrayRef,
),
(
Arc::new(Field::new("b", DataType::Int32, false)),
Arc::new(Field::new("c", DataType::Int32, false)),
int.clone() as ArrayRef,
),
]);
Expand Down Expand Up @@ -10965,11 +11004,11 @@ mod tests {
let int = Arc::new(Int32Array::from(vec![Some(42), None, Some(19), None]));
let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("a", DataType::Boolean, false)),
Arc::new(Field::new("b", DataType::Boolean, false)),
boolean.clone() as ArrayRef,
),
(
Arc::new(Field::new("b", DataType::Int32, true)),
Arc::new(Field::new("c", DataType::Int32, true)),
int.clone() as ArrayRef,
),
]);
Expand Down Expand Up @@ -10999,11 +11038,11 @@ mod tests {
let int = Arc::new(Int32Array::from(vec![i32::MAX, 25, 1, 100]));
let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("a", DataType::Boolean, false)),
Arc::new(Field::new("b", DataType::Boolean, false)),
boolean.clone() as ArrayRef,
),
(
Arc::new(Field::new("b", DataType::Int32, false)),
Arc::new(Field::new("c", DataType::Int32, false)),
int.clone() as ArrayRef,
),
]);
Expand Down Expand Up @@ -11139,7 +11178,7 @@ mod tests {
assert!(result.is_err());
assert_eq!(
result.unwrap_err().to_string(),
"Cast error: Field 'b' not found in source struct"
"Invalid argument error: Incorrect number of arrays for StructArray fields, expected 2 got 1"
);
}

Expand Down Expand Up @@ -11196,7 +11235,7 @@ mod tests {
}

#[test]
fn test_can_cast_struct_with_missing_field() {
fn test_can_cast_struct_rename_field() {
// Test that can_cast_types returns true when a target field name is absent from the source: fields are then matched by position, i.e. treated as renamed
let from_type = DataType::Struct(
vec![
Expand All @@ -11214,7 +11253,7 @@ mod tests {
.into(),
);

assert!(!can_cast_types(&from_type, &to_type));
assert!(can_cast_types(&from_type, &to_type));
}

fn run_decimal_cast_test_case_between_multiple_types(t: DecimalCastTestConfig) {
Expand Down
Loading