Skip to content

Commit 0ed06b3

Browse files
feat(vortex-datafusion): struct scalar conversion + extension-over-struct scan (#8453)
<!-- Thank you for submitting a pull request! We appreciate your time and effort. Please make sure to provide enough information so that we can review your pull request. The Summary and Testing sections below contain guidance on what to include. --> ## Summary 1. DataFusion and Vortex can now exchange struct-shaped scalars. 2. Scan can resolve columns whose type is an extension over a struct. <!-- ## API Changes Uncomment this section if there are any user-facing changes. Consider whether the change affects users in one of the following ways: 1. Breaks public APIs in some way. 2. Changes the underlying behavior of one of the engine integrations. 3. Requires some documentation to be updated to reflect this change. If a public API is changed in a breaking manner, make sure to add the appropriate label. --> ## Testing <!-- Please describe how this change was tested. Here are some common categories for testing in Vortex: 1. Verifying existing behavior is maintained. 2. Verifying new behavior and functionality works correctly. 3. Serialization compatibility (backwards and forwards) should be maintained or explicitly broken. --> Added round-trip tests for null and non-null struct scalars. --------- Signed-off-by: Nemo Yu <zyu379@wisc.edu>
1 parent a45bb60 commit 0ed06b3

4 files changed

Lines changed: 128 additions & 3 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-datafusion/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ rust-version = { workspace = true }
1414
version = { workspace = true }
1515

1616
[dependencies]
17+
arrow-array = { workspace = true }
1718
arrow-schema = { workspace = true }
1819
async-trait = { workspace = true }
1920
datafusion-catalog = { workspace = true }

vortex-datafusion/src/convert/scalars.rs

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::sync::Arc;
5+
6+
use arrow_array::Array;
7+
use arrow_array::StructArray;
8+
use arrow_schema::Field;
9+
use arrow_schema::Fields;
410
use datafusion_common::ScalarValue;
511
use vortex::buffer::ByteBuffer;
612
use vortex::dtype::DType;
@@ -14,6 +20,8 @@ use vortex::dtype::i256;
1420
use vortex::error::VortexExpect;
1521
use vortex::error::VortexResult;
1622
use vortex::error::vortex_bail;
23+
use vortex::error::vortex_err;
24+
use vortex::error::vortex_panic;
1725
use vortex::extension::datetime::AnyTemporal;
1826
use vortex::extension::datetime::TemporalMetadata;
1927
use vortex::extension::datetime::TimeUnit;
@@ -113,7 +121,7 @@ impl TryToDataFusion<ScalarValue> for Scalar {
113121
),
114122
DType::List(..) => todo!("list scalar conversion"),
115123
DType::FixedSizeList(..) => todo!("fixed-size list scalar conversion"),
116-
DType::Struct(..) => todo!("struct scalar conversion"),
124+
DType::Struct(..) => struct_to_df(self)?,
117125
DType::Union(..) => todo!("union scalar conversion"),
118126
DType::Variant(_) => vortex_bail!("Variant scalars aren't supported with DF"),
119127
DType::Extension(ext) => {
@@ -288,11 +296,74 @@ impl FromDataFusion<ScalarValue> for Scalar {
288296
}
289297
}
290298
ScalarValue::Dictionary(_, v) => Scalar::from_df(v.as_ref()),
299+
ScalarValue::Struct(array) => struct_from_df(array),
291300
_ => unimplemented!("Can't convert {value:?} value to a Vortex scalar"),
292301
}
293302
}
294303
}
295304

305+
/// Converts a Vortex struct scalar to a DataFusion `ScalarValue::Struct`.
306+
fn struct_to_df(scalar: &Scalar) -> VortexResult<ScalarValue> {
307+
let scalar = scalar.as_struct();
308+
let struct_fields = scalar.struct_fields();
309+
let (fields, arrays): (Vec<Field>, Vec<_>) = struct_fields
310+
.names()
311+
.iter()
312+
.zip(struct_fields.fields())
313+
.enumerate()
314+
.map(|(idx, (name, field_dtype))| {
315+
let nullable = field_dtype.is_nullable();
316+
let child = if scalar.is_null() {
317+
Scalar::null(field_dtype)
318+
} else {
319+
scalar
320+
.field_by_idx(idx)
321+
.ok_or_else(|| vortex_err!("missing struct field {name}"))?
322+
};
323+
let array = child
324+
.try_to_df()?
325+
.to_array()
326+
.map_err(|e| vortex_err!("failed to build struct field array: {e}"))?;
327+
Ok((
328+
Field::new(name.as_ref(), array.data_type().clone(), nullable),
329+
array,
330+
))
331+
})
332+
.collect::<VortexResult<Vec<_>>>()?
333+
.into_iter()
334+
.unzip();
335+
336+
let fields = Fields::from(fields);
337+
let struct_array = if scalar.is_null() {
338+
StructArray::new_null(fields, 1)
339+
} else {
340+
StructArray::try_new(fields, arrays, None)
341+
.map_err(|e| vortex_err!("failed to build struct scalar array: {e}"))?
342+
};
343+
Ok(ScalarValue::Struct(Arc::new(struct_array)))
344+
}
345+
346+
/// Converts a DataFusion `ScalarValue::Struct` (a one-row struct array) to a Vortex struct scalar.
347+
fn struct_from_df(array: &StructArray) -> Scalar {
348+
let dtype = DType::from_arrow((array.data_type(), Nullability::Nullable));
349+
if array.is_null(0) {
350+
Scalar::null(dtype)
351+
} else {
352+
let children = array
353+
.columns()
354+
.iter()
355+
.map(|column| {
356+
Scalar::from_df(
357+
&ScalarValue::try_from_array(column.as_ref(), 0).unwrap_or_else(|e| {
358+
vortex_panic!("cannot convert struct field to a Vortex scalar: {e}")
359+
}),
360+
)
361+
})
362+
.collect::<Vec<_>>();
363+
Scalar::struct_(dtype, children)
364+
}
365+
}
366+
296367
#[cfg(test)]
297368
mod tests {
298369
use datafusion_common::ScalarValue;
@@ -301,8 +372,10 @@ mod tests {
301372
use vortex::buffer::ByteBuffer;
302373
use vortex::dtype::DType;
303374
use vortex::dtype::DecimalDType;
375+
use vortex::dtype::FieldNames;
304376
use vortex::dtype::Nullability;
305377
use vortex::dtype::PType;
378+
use vortex::dtype::StructFields;
306379
use vortex::dtype::i256;
307380
use vortex::scalar::DecimalValue;
308381
use vortex::scalar::Scalar;
@@ -691,4 +764,49 @@ mod tests {
691764
.into();
692765
assert_eq!(result_bytes, vec![1u8, 2, 3, 4, 5]);
693766
}
767+
768+
/// A non-null struct scalar converts to a DataFusion struct value, and converting that value
/// back to Vortex and out again reproduces the same DataFusion value.
#[test]
fn struct_scalar_round_trips() -> VortexResult<()> {
    let coordinate = DType::Primitive(PType::F64, Nullability::NonNullable);
    let point_fields = StructFields::new(
        FieldNames::from(["x", "y"]),
        vec![coordinate.clone(), coordinate],
    );
    let point_dtype = DType::Struct(point_fields, Nullability::NonNullable);

    let x = Scalar::from(-111.7610f64);
    let y = Scalar::from(34.8697f64);
    let original = Scalar::struct_(point_dtype, vec![x, y]);

    let as_df = original.try_to_df()?;
    assert!(matches!(as_df, ScalarValue::Struct(_)));

    // Vortex -> DataFusion -> Vortex -> DataFusion must land on the same DataFusion value.
    let second_pass = Scalar::from_df(&as_df).try_to_df()?;
    assert_eq!(second_pass, as_df);
    Ok(())
}
793+
794+
/// A null struct scalar converts to a DataFusion struct value that converts back to a null
/// Vortex scalar.
#[test]
fn null_struct_scalar_round_trips() -> VortexResult<()> {
    let coordinate = DType::Primitive(PType::F64, Nullability::Nullable);
    let point_fields = StructFields::new(
        FieldNames::from(["x", "y"]),
        vec![coordinate.clone(), coordinate],
    );
    let point_dtype = DType::Struct(point_fields, Nullability::Nullable);

    let null_point = Scalar::null(point_dtype);
    let as_df = null_point.try_to_df()?;
    assert!(matches!(as_df, ScalarValue::Struct(_)));

    let restored = Scalar::from_df(&as_df);
    assert!(restored.is_null());
    Ok(())
}
694812
}

vortex-datafusion/src/convert/schema.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,14 @@ fn calculate_physical_field_type(
8585
// RunEndEncoded loses its encoding
8686
DataType::RunEndEncoded(..) => logical_type.clone(),
8787

88-
// For struct types, recursively check each field
88+
// For struct types, recursively check each field.
8989
DataType::Struct(logical_fields) => {
90-
if let DType::Struct(struct_dtype, _) = dtype {
90+
// Walk through any extension layers to reach the underlying struct fields.
91+
let mut inner = dtype;
92+
while let DType::Extension(ext) = inner {
93+
inner = ext.storage_dtype();
94+
}
95+
if let DType::Struct(struct_dtype, _) = inner {
9196
let physical_fields: Vec<Field> = struct_dtype
9297
.names()
9398
.iter()

0 commit comments

Comments
 (0)