Commit 86e5a38

fix: restore Stage plan during deserialization round-trip (#404)
Closes #402

### Problem

After #345, `parse_stage_proto` unconditionally sets `stage.plan = None` during deserialization, discarding children that DataFusion's framework correctly deserialized. This breaks any architecture where the distributed plan goes through a serialize/deserialize round-trip before `DistributedExec::prepare_plan()` runs, such as Arrow Flight SQL's two-phase `GetFlightInfo`/`DoGet` protocol.

### Fix

`children()` on network boundary nodes already encodes plan presence: it returns `[plan]` when `Some`, `[]` when `None`. DataFusion's serialization framework preserves this through the round-trip and passes the deserialized children as `inputs`. The fix infers plan presence from `inputs` rather than ignoring them:

```rust
plan: inputs.first().cloned(),
```

### Tests

- Updated two existing tests (`test_roundtrip_single_flight`, `test_roundtrip_single_flight_coalesce`) to pass correct inputs for stages with `plan: None`
- Added two new tests (`test_roundtrip_single_flight_with_plan`, `test_roundtrip_single_flight_coalesce_with_plan`) that validate the round-trip when a stage has `plan: Some(...)`

### Visibility changes

In a Flight SQL architecture, the coordinator server operates in two distinct phases separated by a network boundary:

1. `GetFlightInfo`: optimizes the query and serializes the full physical plan (including DFD's network boundary nodes) into a ticket
2. `DoGet`: deserializes the plan from the ticket and executes it

This means the server must be able to serialize and deserialize DFD's plan nodes independently of DFD's internal machinery. To do this, it needs `DistributedCodec` in its `PhysicalExtensionCodec` chain; otherwise the server has no way to round-trip `NetworkShuffleExec`, `NetworkCoalesceExec`, etc. through the ticket.

Similarly, a Flight SQL server that also acts as a DFD worker needs to implement the `WorkerService` gRPC trait. This requires `WorkerServiceServer` to register the service with tonic, `FlightAppMetadata` to read passthrough headers from requests, and `Worker::impl_execute_task` to delegate task execution. The error conversion functions (`datafusion_error_to_tonic_status`, etc.) are needed to translate between DFD's error types and gRPC status codes in the server's request handlers.

Currently these types are all `pub(crate)`, which forces downstream Flight SQL integrations to fork the crate just to access them. Making them public allows DFD to be used as a library in these architectures without modification.
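The codec-chain requirement described in the message can be illustrated with a toy model. This is only a sketch under stated assumptions: `NodeCodec`, `AppCodec`, `BoundaryCodec`, `decode_with_chain`, and the `app:`/`net:` byte prefixes are all hypothetical and much simpler than DataFusion's `PhysicalExtensionCodec`. It shows only why a chain that lacks the codec for network boundary nodes cannot decode a ticket that contains one.

```rust
/// Hypothetical, simplified codec interface (NOT DataFusion's
/// `PhysicalExtensionCodec`): a codec either recognizes the bytes or declines.
trait NodeCodec {
    fn try_decode(&self, bytes: &[u8]) -> Option<String>;
}

/// Stand-in for the server's own extension codec.
struct AppCodec;

/// Stand-in for a codec that understands network boundary nodes.
struct BoundaryCodec;

impl NodeCodec for AppCodec {
    fn try_decode(&self, bytes: &[u8]) -> Option<String> {
        bytes
            .strip_prefix(b"app:")
            .map(|rest| format!("AppNode({})", String::from_utf8_lossy(rest)))
    }
}

impl NodeCodec for BoundaryCodec {
    fn try_decode(&self, bytes: &[u8]) -> Option<String> {
        bytes
            .strip_prefix(b"net:")
            .map(|rest| format!("NetworkNode({})", String::from_utf8_lossy(rest)))
    }
}

/// Try each codec in order and return the first successful decode.
fn decode_with_chain(chain: &[&dyn NodeCodec], bytes: &[u8]) -> Result<String, String> {
    chain
        .iter()
        .find_map(|codec| codec.try_decode(bytes))
        .ok_or_else(|| "no codec in the chain recognizes this node".to_string())
}

fn main() {
    // Pretend this ticket was produced in the first phase and carries a boundary node.
    let ticket = b"net:shuffle";

    // Second phase with a chain that lacks the boundary codec: decoding fails.
    let without: [&dyn NodeCodec; 1] = [&AppCodec];
    assert!(decode_with_chain(&without, ticket).is_err());

    // With the boundary codec in the chain, the same ticket decodes.
    let with: [&dyn NodeCodec; 2] = [&AppCodec, &BoundaryCodec];
    assert_eq!(decode_with_chain(&with, ticket).unwrap(), "NetworkNode(shuffle)");
}
```

In the real crate the equivalent step would be composing `DistributedCodec` with the server's own codec, which is what the visibility change is meant to allow.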
1 parent d862428 commit 86e5a38

3 files changed

Lines changed: 57 additions & 4 deletions


src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ mod distributed_planner;
 mod networking;
 mod observability;
 mod protobuf;
+pub use protobuf::DistributedCodec;
 #[cfg(any(feature = "integration", test))]
 pub mod test_utils;
 

src/protobuf/distributed_codec.rs

Lines changed: 55 additions & 3 deletions
@@ -60,7 +60,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
 
 fn parse_stage_proto(
     proto: Option<StageProto>,
-    _inputs: &[Arc<dyn ExecutionPlan>],
+    inputs: &[Arc<dyn ExecutionPlan>],
 ) -> Result<Stage, DataFusionError> {
     let Some(proto) = proto else {
         return Err(proto_error("Empty StageProto"));
@@ -70,7 +70,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
         query_id: uuid::Uuid::from_slice(proto.query_id.as_ref())
             .map_err(|_| proto_error("Invalid query_id in StageProto"))?,
         num: proto.num as usize,
-        plan: None,
+        plan: inputs.first().cloned(),
         tasks: decode_tasks(proto.tasks)?,
     })
 }
@@ -556,6 +556,15 @@ mod tests {
         }
     }
 
+    fn dummy_stage_with_plan() -> Stage {
+        Stage {
+            query_id: Default::default(),
+            num: 0,
+            plan: Some(empty_exec()),
+            tasks: vec![],
+        }
+    }
+
     fn schema_i32(name: &str) -> Arc<Schema> {
         Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)]))
     }
@@ -581,7 +590,7 @@ mod tests {
         let mut buf = Vec::new();
         codec.try_encode(plan.clone(), &mut buf)?;
 
-        let decoded = codec.try_decode(&buf, &[empty_exec()], &ctx)?;
+        let decoded = codec.try_decode(&buf, &[], &ctx)?;
         assert_eq!(repr(&plan), repr(&decoded));
 
         Ok(())
@@ -686,6 +695,49 @@ mod tests {
         let mut buf = Vec::new();
         codec.try_encode(plan.clone(), &mut buf)?;
 
+        let decoded = codec.try_decode(&buf, &[], &ctx)?;
+        assert_eq!(repr(&plan), repr(&decoded));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_roundtrip_single_flight_with_plan() -> datafusion::common::Result<()> {
+        let codec = DistributedCodec;
+        let ctx = create_context();
+
+        let schema = schema_i32("a");
+        let part = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4);
+        let plan: Arc<dyn ExecutionPlan> = Arc::new(new_network_hash_shuffle_exec(
+            part,
+            schema,
+            dummy_stage_with_plan(),
+        ));
+
+        let mut buf = Vec::new();
+        codec.try_encode(plan.clone(), &mut buf)?;
+
+        let decoded = codec.try_decode(&buf, &[empty_exec()], &ctx)?;
+        assert_eq!(repr(&plan), repr(&decoded));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_roundtrip_single_flight_coalesce_with_plan() -> datafusion::common::Result<()> {
+        let codec = DistributedCodec;
+        let ctx = create_context();
+
+        let schema = schema_i32("e");
+        let plan: Arc<dyn ExecutionPlan> = Arc::new(new_network_coalesce_tasks_exec(
+            Partitioning::RoundRobinBatch(3),
+            schema,
+            dummy_stage_with_plan(),
+        ));
+
+        let mut buf = Vec::new();
+        codec.try_encode(plan.clone(), &mut buf)?;
+
         let decoded = codec.try_decode(&buf, &[empty_exec()], &ctx)?;
         assert_eq!(repr(&plan), repr(&decoded));
 
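The behavior change in `parse_stage_proto` can be reproduced with a self-contained toy model. This is a minimal sketch, not the crate's code: `Plan`, the two-field `Stage`, `decode_before_fix`, and `decode_after_fix` are hypothetical stand-ins, and the only assumption carried over from the commit message is that the framework hands the decoder whatever `children()` returned.

```rust
use std::sync::Arc;

/// Hypothetical stand-in for an execution plan node.
#[derive(Debug, PartialEq)]
struct Plan(&'static str);

/// Hypothetical, reduced version of a stage: only the fields needed here.
#[derive(Debug, PartialEq)]
struct Stage {
    num: usize,
    plan: Option<Arc<Plan>>,
}

impl Stage {
    /// Mirrors the convention from the commit message:
    /// `[plan]` when the plan is `Some`, `[]` when it is `None`.
    fn children(&self) -> Vec<Arc<Plan>> {
        self.plan.iter().cloned().collect()
    }
}

/// Models the old behavior: the deserialized inputs are ignored.
fn decode_before_fix(num: usize, _inputs: &[Arc<Plan>]) -> Stage {
    Stage { num, plan: None }
}

/// Models the fix: plan presence is inferred from the inputs.
fn decode_after_fix(num: usize, inputs: &[Arc<Plan>]) -> Stage {
    Stage { num, plan: inputs.first().cloned() }
}

fn main() {
    let with_plan = Stage { num: 0, plan: Some(Arc::new(Plan("scan"))) };
    let without_plan = Stage { num: 1, plan: None };

    // A stage with a plan loses it under the old decoder...
    let inputs = with_plan.children();
    assert_ne!(decode_before_fix(0, &inputs), with_plan);
    // ...and keeps it under the new one.
    assert_eq!(decode_after_fix(0, &inputs), with_plan);

    // A stage without a plan produces no inputs and still round-trips.
    let inputs = without_plan.children();
    assert!(inputs.is_empty());
    assert_eq!(decode_after_fix(1, &inputs), without_plan);
}
```

The same reasoning explains the test updates in this file: stages built with `plan: None` are now decoded with `&[]`, while the new `_with_plan` tests pass `&[empty_exec()]`.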

src/protobuf/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ mod distributed_codec;
 mod errors;
 mod user_codec;
 
-pub(crate) use distributed_codec::DistributedCodec;
+pub use distributed_codec::DistributedCodec;
 pub(crate) use errors::{
     datafusion_error_to_tonic_status, map_flight_to_datafusion_error,
     tonic_status_to_datafusion_error,
