Skip to content

Commit 7e66476

Browse files
authored
fix: intersection of list of struct (#3665)
Previously we only descended into children if the field was a struct. This is a problem when we have lists of struct because we might fail to project out child fields.
1 parent 64d3ecb commit 7e66476

3 files changed

Lines changed: 288 additions & 3 deletions

File tree

rust/lance-arrow/src/lib.rs

Lines changed: 259 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use arrow_array::{
1313
GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, UInt32Array,
1414
UInt8Array,
1515
};
16-
use arrow_array::{Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array};
16+
use arrow_array::{
17+
new_null_array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
18+
};
1719
use arrow_buffer::MutableBuffer;
1820
use arrow_data::ArrayDataBuilder;
1921
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema};
@@ -755,6 +757,150 @@ fn project(struct_array: &StructArray, fields: &Fields) -> Result<StructArray> {
755757
StructArray::try_new(fields.clone(), columns, None)
756758
}
757759

760+
fn lists_have_same_offsets_helper<T: OffsetSizeTrait>(left: &dyn Array, right: &dyn Array) -> bool {
761+
let left_list: &GenericListArray<T> = left.as_list();
762+
let right_list: &GenericListArray<T> = right.as_list();
763+
left_list.offsets().inner() == right_list.offsets().inner()
764+
}
765+
766+
fn merge_list_structs_helper<T: OffsetSizeTrait>(
767+
left: &dyn Array,
768+
right: &dyn Array,
769+
items_field_name: impl Into<String>,
770+
items_nullable: bool,
771+
) -> Arc<dyn Array> {
772+
let left_list: &GenericListArray<T> = left.as_list();
773+
let right_list: &GenericListArray<T> = right.as_list();
774+
let left_struct = left_list.values();
775+
let right_struct = right_list.values();
776+
let left_struct_arr = left_struct.as_struct();
777+
let right_struct_arr = right_struct.as_struct();
778+
let merged_items = Arc::new(merge(left_struct_arr, right_struct_arr));
779+
let items_field = Arc::new(Field::new(
780+
items_field_name,
781+
merged_items.data_type().clone(),
782+
items_nullable,
783+
));
784+
Arc::new(GenericListArray::<T>::new(
785+
items_field,
786+
left_list.offsets().clone(),
787+
merged_items,
788+
left_list.nulls().cloned(),
789+
))
790+
}
791+
792+
fn merge_list_struct_null_helper<T: OffsetSizeTrait>(
793+
left: &dyn Array,
794+
right: &dyn Array,
795+
not_null: &dyn Array,
796+
items_field_name: impl Into<String>,
797+
) -> Arc<dyn Array> {
798+
let left_list: &GenericListArray<T> = left.as_list::<T>();
799+
let not_null_list = not_null.as_list::<T>();
800+
let right_list = right.as_list::<T>();
801+
802+
let left_struct = left_list.values().as_struct();
803+
let not_null_struct: &StructArray = not_null_list.values().as_struct();
804+
let right_struct = right_list.values().as_struct();
805+
806+
let values_len = not_null_list.values().len();
807+
let mut merged_fields =
808+
Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns());
809+
let mut merged_columns =
810+
Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns());
811+
812+
for (_, field) in left_struct.columns().iter().zip(left_struct.fields()) {
813+
merged_fields.push(field.clone());
814+
if let Some(val) = not_null_struct.column_by_name(field.name()) {
815+
merged_columns.push(val.clone());
816+
} else {
817+
merged_columns.push(new_null_array(field.data_type(), values_len))
818+
}
819+
}
820+
for (_, field) in right_struct
821+
.columns()
822+
.iter()
823+
.zip(right_struct.fields())
824+
.filter(|(_, field)| left_struct.column_by_name(field.name()).is_none())
825+
{
826+
merged_fields.push(field.clone());
827+
if let Some(val) = not_null_struct.column_by_name(field.name()) {
828+
merged_columns.push(val.clone());
829+
} else {
830+
merged_columns.push(new_null_array(field.data_type(), values_len));
831+
}
832+
}
833+
834+
let merged_struct = Arc::new(StructArray::new(
835+
Fields::from(merged_fields),
836+
merged_columns,
837+
not_null_struct.nulls().cloned(),
838+
));
839+
let items_field = Arc::new(Field::new(
840+
items_field_name,
841+
merged_struct.data_type().clone(),
842+
true,
843+
));
844+
Arc::new(GenericListArray::<T>::new(
845+
items_field,
846+
not_null_list.offsets().clone(),
847+
merged_struct,
848+
not_null_list.nulls().cloned(),
849+
))
850+
}
851+
852+
fn merge_list_struct_null(
853+
left: &dyn Array,
854+
right: &dyn Array,
855+
not_null: &dyn Array,
856+
) -> Arc<dyn Array> {
857+
match left.data_type() {
858+
DataType::List(left_field) => {
859+
merge_list_struct_null_helper::<i32>(left, right, not_null, left_field.name())
860+
}
861+
DataType::LargeList(left_field) => {
862+
merge_list_struct_null_helper::<i64>(left, right, not_null, left_field.name())
863+
}
864+
_ => unreachable!(),
865+
}
866+
}
867+
868+
fn merge_list_struct(left: &dyn Array, right: &dyn Array) -> Arc<dyn Array> {
869+
// Merging fields into a list<struct<...>> is tricky and can only succeed
870+
// in two ways. First, if both lists have the same offsets. Second, if
871+
// one of the lists is all-null
872+
if left.null_count() == left.len() {
873+
return merge_list_struct_null(left, right, right);
874+
} else if right.null_count() == right.len() {
875+
return merge_list_struct_null(left, right, left);
876+
}
877+
match (left.data_type(), right.data_type()) {
878+
(DataType::List(left_field), DataType::List(_)) => {
879+
if !lists_have_same_offsets_helper::<i32>(left, right) {
880+
panic!("Attempt to merge list struct arrays which do not have same offsets");
881+
}
882+
merge_list_structs_helper::<i32>(
883+
left,
884+
right,
885+
left_field.name(),
886+
left_field.is_nullable(),
887+
)
888+
}
889+
(DataType::LargeList(left_field), DataType::LargeList(_)) => {
890+
if !lists_have_same_offsets_helper::<i64>(left, right) {
891+
panic!("Attempt to merge list struct arrays which do not have same offsets");
892+
}
893+
merge_list_structs_helper::<i64>(
894+
left,
895+
right,
896+
left_field.name(),
897+
left_field.is_nullable(),
898+
)
899+
}
900+
_ => unreachable!(),
901+
}
902+
}
903+
758904
fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray {
759905
let mut fields: Vec<Field> = vec![];
760906
let mut columns: Vec<ArrayRef> = vec![];
@@ -788,6 +934,27 @@ fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> S
788934
));
789935
columns.push(Arc::new(merged_sub_array) as ArrayRef);
790936
}
937+
(DataType::List(left_list), DataType::List(right_list))
938+
if left_list.data_type().is_struct()
939+
&& right_list.data_type().is_struct() =>
940+
{
941+
// If there is nothing to merge just use the left field
942+
if left_list.data_type() == right_list.data_type() {
943+
fields.push(left_field.as_ref().clone());
944+
columns.push(left_column.clone());
945+
}
946+
// If we have two List<Struct> and they have different sets of fields then
947+
// we can merge them if the offsets arrays are the same. Otherwise, we
948+
// have to consider it an error.
949+
let merged_sub_array = merge_list_struct(&left_column, &right_column);
950+
951+
fields.push(Field::new(
952+
left_field.name(),
953+
merged_sub_array.data_type().clone(),
954+
left_field.is_nullable(),
955+
));
956+
columns.push(merged_sub_array);
957+
}
791958
// otherwise, just use the field on the left hand side
792959
_ => {
793960
// TODO handle list-of-struct and other types
@@ -1004,7 +1171,8 @@ impl BufferExt for arrow_buffer::Buffer {
10041171
#[cfg(test)]
10051172
mod tests {
10061173
use super::*;
1007-
use arrow_array::{new_empty_array, Int32Array, StringArray};
1174+
use arrow_array::{new_empty_array, new_null_array, Int32Array, ListArray, StringArray};
1175+
use arrow_buffer::OffsetBuffer;
10081176

10091177
#[test]
10101178
fn test_merge_recursive() {
@@ -1134,6 +1302,95 @@ mod tests {
11341302
assert_eq!(merged.schema().as_ref(), &naive_schema);
11351303
}
11361304

1305+
#[test]
1306+
fn test_merge_list_struct() {
1307+
let x_field = Arc::new(Field::new("x", DataType::Int32, true));
1308+
let y_field = Arc::new(Field::new("y", DataType::Int32, true));
1309+
let x_struct_field = Arc::new(Field::new(
1310+
"item",
1311+
DataType::Struct(Fields::from(vec![x_field.clone()])),
1312+
true,
1313+
));
1314+
let y_struct_field = Arc::new(Field::new(
1315+
"item",
1316+
DataType::Struct(Fields::from(vec![y_field.clone()])),
1317+
true,
1318+
));
1319+
let both_struct_field = Arc::new(Field::new(
1320+
"item",
1321+
DataType::Struct(Fields::from(vec![x_field.clone(), y_field.clone()])),
1322+
true,
1323+
));
1324+
let left_schema = Schema::new(vec![Field::new(
1325+
"list_struct",
1326+
DataType::List(x_struct_field.clone()),
1327+
true,
1328+
)]);
1329+
let right_schema = Schema::new(vec![Field::new(
1330+
"list_struct",
1331+
DataType::List(y_struct_field.clone()),
1332+
true,
1333+
)]);
1334+
let both_schema = Schema::new(vec![Field::new(
1335+
"list_struct",
1336+
DataType::List(both_struct_field.clone()),
1337+
true,
1338+
)]);
1339+
1340+
let x = Arc::new(Int32Array::from(vec![1]));
1341+
let y = Arc::new(Int32Array::from(vec![2]));
1342+
let x_struct = Arc::new(StructArray::new(
1343+
Fields::from(vec![x_field.clone()]),
1344+
vec![x.clone()],
1345+
None,
1346+
));
1347+
let y_struct = Arc::new(StructArray::new(
1348+
Fields::from(vec![y_field.clone()]),
1349+
vec![y.clone()],
1350+
None,
1351+
));
1352+
let both_struct = Arc::new(StructArray::new(
1353+
Fields::from(vec![x_field.clone(), y_field.clone()]),
1354+
vec![x.clone(), y],
1355+
None,
1356+
));
1357+
let both_null_struct = Arc::new(StructArray::new(
1358+
Fields::from(vec![x_field, y_field]),
1359+
vec![x, Arc::new(new_null_array(&DataType::Int32, 1))],
1360+
None,
1361+
));
1362+
let offsets = OffsetBuffer::from_lengths([1]);
1363+
let x_s_list = ListArray::new(x_struct_field, offsets.clone(), x_struct, None);
1364+
let y_s_list = ListArray::new(y_struct_field, offsets.clone(), y_struct, None);
1365+
let both_list = ListArray::new(
1366+
both_struct_field.clone(),
1367+
offsets.clone(),
1368+
both_struct,
1369+
None,
1370+
);
1371+
let both_null_list = ListArray::new(both_struct_field, offsets, both_null_struct, None);
1372+
let x_batch =
1373+
RecordBatch::try_new(Arc::new(left_schema), vec![Arc::new(x_s_list)]).unwrap();
1374+
let y_batch = RecordBatch::try_new(
1375+
Arc::new(right_schema.clone()),
1376+
vec![Arc::new(y_s_list.clone())],
1377+
)
1378+
.unwrap();
1379+
let merged = x_batch.merge(&y_batch).unwrap();
1380+
let expected =
1381+
RecordBatch::try_new(Arc::new(both_schema.clone()), vec![Arc::new(both_list)]).unwrap();
1382+
assert_eq!(merged, expected);
1383+
1384+
let y_null_list = new_null_array(y_s_list.data_type(), 1);
1385+
let y_null_batch =
1386+
RecordBatch::try_new(Arc::new(right_schema), vec![Arc::new(y_null_list.clone())])
1387+
.unwrap();
1388+
let expected =
1389+
RecordBatch::try_new(Arc::new(both_schema), vec![Arc::new(both_null_list)]).unwrap();
1390+
let merged = x_batch.merge(&y_null_batch).unwrap();
1391+
assert_eq!(merged, expected);
1392+
}
1393+
11371394
#[test]
11381395
fn test_take_record_batch() {
11391396
let schema = Arc::new(Schema::new(vec![

rust/lance-core/src/datatypes/field.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,11 @@ impl Field {
674674
}
675675
let self_type = self.data_type();
676676
let other_type = other.data_type();
677-
if self_type.is_struct() && other_type.is_struct() {
677+
678+
if matches!(
679+
(&self_type, &other_type),
680+
(DataType::Struct(_), DataType::Struct(_)) | (DataType::List(_), DataType::List(_))
681+
) {
678682
let children = self
679683
.children
680684
.iter()

rust/lance-core/src/datatypes/schema.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,30 @@ mod tests {
13091309
ArrowField::new("c", DataType::Float64, false),
13101310
]);
13111311
assert_eq!(actual, expected);
1312+
1313+
let schema_with_list_struct = ArrowSchema::new(vec![ArrowField::new(
1314+
"struct_list",
1315+
DataType::List(Arc::new(ArrowField::new(
1316+
"item",
1317+
DataType::Struct(ArrowFields::from(vec![
1318+
ArrowField::new("f1", DataType::Utf8, true),
1319+
ArrowField::new("f2", DataType::Boolean, false),
1320+
])),
1321+
true,
1322+
))),
1323+
true,
1324+
)]);
1325+
let schema_with_list_struct = Schema::try_from(&schema_with_list_struct).unwrap();
1326+
1327+
let with_missing_field = schema_with_list_struct.project_by_ids(&[1, 3], false);
1328+
let intersection = schema_with_list_struct
1329+
.intersection_ignore_types(&with_missing_field)
1330+
.unwrap();
1331+
assert_eq!(intersection, with_missing_field);
1332+
let intersection = with_missing_field
1333+
.intersection_ignore_types(&schema_with_list_struct)
1334+
.unwrap();
1335+
assert_eq!(intersection, with_missing_field);
13121336
}
13131337

13141338
#[test]

0 commit comments

Comments
 (0)