Skip to content

Commit 471fb2a

Browse files
andygroveclaude
andcommitted
perf: optimize struct field processing with field-major order
Optimize struct field processing in native shuffle by using field-major instead of row-major order. This moves type dispatch from O(rows × fields) to O(fields), eliminating per-row type matching overhead. Previously, for each row we iterated over all fields and called `append_field()` which did a type match for EVERY field in EVERY row. For a struct with N fields and M rows, that's N×M type matches. The new approach: 1. First pass: Loop over rows, build struct validity 2. Second pass: For each field, get typed builder once, then process all rows for that field This keeps type dispatch at O(fields) instead of O(rows × fields). For complex nested types (struct, list, map), falls back to existing `append_field` since they have their own recursive processing logic. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 57780bc commit 471fb2a

1 file changed

Lines changed: 210 additions & 21 deletions

File tree

  • native/core/src/execution/shuffle

native/core/src/execution/shuffle/row.rs

Lines changed: 210 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,205 @@ pub(crate) fn append_field(
439439
Ok(())
440440
}
441441

442+
/// Appends struct fields to the struct builder using field-major order.
443+
/// This processes one field at a time across all rows, which moves type dispatch
444+
/// outside the row loop (O(fields) dispatches instead of O(rows × fields)).
445+
#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)]
446+
fn append_struct_fields_field_major(
447+
row_addresses_ptr: *mut jlong,
448+
row_sizes_ptr: *mut jint,
449+
row_start: usize,
450+
row_end: usize,
451+
parent_row: &mut SparkUnsafeRow,
452+
column_idx: usize,
453+
struct_builder: &mut StructBuilder,
454+
fields: &arrow::datatypes::Fields,
455+
) -> Result<(), CometError> {
456+
let num_rows = row_end - row_start;
457+
let num_fields = fields.len();
458+
459+
// First pass: Build struct validity and collect which structs are null
460+
// We use a Vec<bool> for simplicity; could use a bitset for better memory
461+
let mut struct_is_null = Vec::with_capacity(num_rows);
462+
463+
for i in row_start..row_end {
464+
let row_addr = unsafe { *row_addresses_ptr.add(i) };
465+
let row_size = unsafe { *row_sizes_ptr.add(i) };
466+
parent_row.point_to(row_addr, row_size);
467+
468+
let is_null = parent_row.is_null_at(column_idx);
469+
struct_is_null.push(is_null);
470+
471+
if is_null {
472+
struct_builder.append_null();
473+
} else {
474+
struct_builder.append(true);
475+
}
476+
}
477+
478+
// Helper macro for processing primitive fields
479+
macro_rules! process_field {
480+
($builder_type:ty, $field_idx:expr, $get_value:expr) => {{
481+
let field_builder = struct_builder
482+
.field_builder::<$builder_type>($field_idx)
483+
.unwrap();
484+
485+
for (row_idx, i) in (row_start..row_end).enumerate() {
486+
if struct_is_null[row_idx] {
487+
// Struct is null, field is also null
488+
field_builder.append_null();
489+
} else {
490+
let row_addr = unsafe { *row_addresses_ptr.add(i) };
491+
let row_size = unsafe { *row_sizes_ptr.add(i) };
492+
parent_row.point_to(row_addr, row_size);
493+
let nested_row = parent_row.get_struct(column_idx, num_fields);
494+
495+
if nested_row.is_null_at($field_idx) {
496+
field_builder.append_null();
497+
} else {
498+
field_builder.append_value($get_value(&nested_row, $field_idx));
499+
}
500+
}
501+
}
502+
}};
503+
}
504+
505+
// Second pass: Process each field across all rows
506+
for (field_idx, field) in fields.iter().enumerate() {
507+
match field.data_type() {
508+
DataType::Boolean => {
509+
process_field!(BooleanBuilder, field_idx, |row: &SparkUnsafeRow, idx| row
510+
.get_boolean(idx));
511+
}
512+
DataType::Int8 => {
513+
process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow, idx| row
514+
.get_byte(idx));
515+
}
516+
DataType::Int16 => {
517+
process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow, idx| row
518+
.get_short(idx));
519+
}
520+
DataType::Int32 => {
521+
process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow, idx| row
522+
.get_int(idx));
523+
}
524+
DataType::Int64 => {
525+
process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow, idx| row
526+
.get_long(idx));
527+
}
528+
DataType::Float32 => {
529+
process_field!(Float32Builder, field_idx, |row: &SparkUnsafeRow, idx| row
530+
.get_float(idx));
531+
}
532+
DataType::Float64 => {
533+
process_field!(Float64Builder, field_idx, |row: &SparkUnsafeRow, idx| row
534+
.get_double(idx));
535+
}
536+
DataType::Date32 => {
537+
process_field!(Date32Builder, field_idx, |row: &SparkUnsafeRow, idx| row
538+
.get_date(idx));
539+
}
540+
DataType::Timestamp(TimeUnit::Microsecond, _) => {
541+
process_field!(
542+
TimestampMicrosecondBuilder,
543+
field_idx,
544+
|row: &SparkUnsafeRow, idx| row.get_timestamp(idx)
545+
);
546+
}
547+
DataType::Binary => {
548+
let field_builder = struct_builder
549+
.field_builder::<BinaryBuilder>(field_idx)
550+
.unwrap();
551+
552+
for (row_idx, i) in (row_start..row_end).enumerate() {
553+
if struct_is_null[row_idx] {
554+
field_builder.append_null();
555+
} else {
556+
let row_addr = unsafe { *row_addresses_ptr.add(i) };
557+
let row_size = unsafe { *row_sizes_ptr.add(i) };
558+
parent_row.point_to(row_addr, row_size);
559+
let nested_row = parent_row.get_struct(column_idx, num_fields);
560+
561+
if nested_row.is_null_at(field_idx) {
562+
field_builder.append_null();
563+
} else {
564+
field_builder.append_value(nested_row.get_binary(field_idx));
565+
}
566+
}
567+
}
568+
}
569+
DataType::Utf8 => {
570+
let field_builder = struct_builder
571+
.field_builder::<StringBuilder>(field_idx)
572+
.unwrap();
573+
574+
for (row_idx, i) in (row_start..row_end).enumerate() {
575+
if struct_is_null[row_idx] {
576+
field_builder.append_null();
577+
} else {
578+
let row_addr = unsafe { *row_addresses_ptr.add(i) };
579+
let row_size = unsafe { *row_sizes_ptr.add(i) };
580+
parent_row.point_to(row_addr, row_size);
581+
let nested_row = parent_row.get_struct(column_idx, num_fields);
582+
583+
if nested_row.is_null_at(field_idx) {
584+
field_builder.append_null();
585+
} else {
586+
field_builder.append_value(nested_row.get_string(field_idx));
587+
}
588+
}
589+
}
590+
}
591+
DataType::Decimal128(p, _) => {
592+
let p = *p;
593+
let field_builder = struct_builder
594+
.field_builder::<Decimal128Builder>(field_idx)
595+
.unwrap();
596+
597+
for (row_idx, i) in (row_start..row_end).enumerate() {
598+
if struct_is_null[row_idx] {
599+
field_builder.append_null();
600+
} else {
601+
let row_addr = unsafe { *row_addresses_ptr.add(i) };
602+
let row_size = unsafe { *row_sizes_ptr.add(i) };
603+
parent_row.point_to(row_addr, row_size);
604+
let nested_row = parent_row.get_struct(column_idx, num_fields);
605+
606+
if nested_row.is_null_at(field_idx) {
607+
field_builder.append_null();
608+
} else {
609+
field_builder.append_value(nested_row.get_decimal(field_idx, p));
610+
}
611+
}
612+
}
613+
}
614+
// For complex types (struct, list, map), fall back to append_field
615+
// since they have their own nested processing logic
616+
dt @ (DataType::Struct(_) | DataType::List(_) | DataType::Map(_, _)) => {
617+
for (row_idx, i) in (row_start..row_end).enumerate() {
618+
let nested_row = if struct_is_null[row_idx] {
619+
SparkUnsafeRow::default()
620+
} else {
621+
let row_addr = unsafe { *row_addresses_ptr.add(i) };
622+
let row_size = unsafe { *row_sizes_ptr.add(i) };
623+
parent_row.point_to(row_addr, row_size);
624+
parent_row.get_struct(column_idx, num_fields)
625+
};
626+
append_field(dt, struct_builder, &nested_row, field_idx)?;
627+
}
628+
}
629+
_ => {
630+
unreachable!(
631+
"Unsupported data type of struct field: {:?}",
632+
field.data_type()
633+
)
634+
}
635+
}
636+
}
637+
638+
Ok(())
639+
}
640+
442641
/// Appends column of top rows to the given array builder.
443642
#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)]
444643
pub(crate) fn append_columns(
@@ -637,27 +836,17 @@ pub(crate) fn append_columns(
637836
.expect("StructBuilder");
638837
let mut row = SparkUnsafeRow::new(schema);
639838

640-
for i in row_start..row_end {
641-
let row_addr = unsafe { *row_addresses_ptr.add(i) };
642-
let row_size = unsafe { *row_sizes_ptr.add(i) };
643-
row.point_to(row_addr, row_size);
644-
645-
let is_null = row.is_null_at(column_idx);
646-
647-
let nested_row = if is_null {
648-
// The struct is null.
649-
// Append a null value to the struct builder and field builders.
650-
struct_builder.append_null();
651-
SparkUnsafeRow::default()
652-
} else {
653-
struct_builder.append(true);
654-
row.get_struct(column_idx, fields.len())
655-
};
656-
657-
for (idx, field) in fields.into_iter().enumerate() {
658-
append_field(field.data_type(), struct_builder, &nested_row, idx)?;
659-
}
660-
}
839+
// Use field-major processing to avoid per-row type dispatch
840+
append_struct_fields_field_major(
841+
row_addresses_ptr,
842+
row_sizes_ptr,
843+
row_start,
844+
row_end,
845+
&mut row,
846+
column_idx,
847+
struct_builder,
848+
fields,
849+
)?;
661850
}
662851
_ => {
663852
unreachable!("Unsupported data type of column: {:?}", dt)

0 commit comments

Comments
 (0)