Skip to content

Commit fe2fcd0

Browse files
committed
Port TryCastExpr proto serialization hooks
1 parent 11a79a6 commit fe2fcd0

5 files changed

Lines changed: 184 additions & 23 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/physical-expr/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ recursive_protection = ["dep:recursive"]
4646
# `PhysicalExpr::to_proto` and letting expressions in this crate implement it.
4747
proto = [
4848
"dep:datafusion-proto-models",
49+
"dep:datafusion-proto-common",
4950
"datafusion-physical-expr-common/proto",
5051
]
5152

@@ -56,6 +57,7 @@ datafusion-expr = { workspace = true }
5657
datafusion-expr-common = { workspace = true }
5758
datafusion-functions-aggregate-common = { workspace = true }
5859
datafusion-physical-expr-common = { workspace = true }
60+
datafusion-proto-common = { workspace = true, optional = true }
5961
datafusion-proto-models = { workspace = true, optional = true }
6062
hashbrown = { workspace = true }
6163
indexmap = { workspace = true }

datafusion/physical-expr/src/expressions/try_cast.rs

Lines changed: 179 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,67 @@ impl PhysicalExpr for TryCastExpr {
119119
self.expr.fmt_sql(f)?;
120120
write!(f, " AS {:?})", self.cast_type)
121121
}
122+
123+
#[cfg(feature = "proto")]
124+
fn try_to_proto(
125+
&self,
126+
ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>,
127+
) -> Result<Option<datafusion_proto_models::protobuf::PhysicalExprNode>> {
128+
use datafusion_proto_models::protobuf;
129+
130+
let arrow_type: datafusion_proto_common::protobuf_common::ArrowType =
131+
self.cast_type().try_into()?;
132+
133+
Ok(Some(protobuf::PhysicalExprNode {
134+
expr_id: None,
135+
expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new(
136+
protobuf::PhysicalTryCastNode {
137+
expr: Some(Box::new(ctx.encode_child(&self.expr)?)),
138+
arrow_type: Some(arrow_type),
139+
},
140+
))),
141+
}))
142+
}
143+
}
144+
145+
#[cfg(feature = "proto")]
146+
impl TryCastExpr {
147+
/// Reconstruct a [`TryCastExpr`] from its protobuf representation.
148+
///
149+
/// Takes the whole [`PhysicalExprNode`] so the decode signature matches
150+
/// other migrated expressions and can inspect outer-node metadata if
151+
/// needed in the future.
152+
///
153+
/// [`PhysicalExprNode`]: datafusion_proto_models::protobuf::PhysicalExprNode
154+
pub fn try_from_proto(
155+
node: &datafusion_proto_models::protobuf::PhysicalExprNode,
156+
ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>,
157+
) -> Result<Arc<dyn PhysicalExpr>> {
158+
use datafusion_common::{internal_datafusion_err, internal_err};
159+
use datafusion_proto_models::protobuf;
160+
161+
let try_cast = match &node.expr_type {
162+
Some(protobuf::physical_expr_node::ExprType::TryCast(try_cast)) => {
163+
try_cast.as_ref()
164+
}
165+
_ => return internal_err!("PhysicalExprNode is not a TryCastExpr"),
166+
};
167+
168+
let expr = ctx.decode_required_expression(
169+
try_cast.expr.as_deref(),
170+
"TryCastExpr",
171+
"expr",
172+
)?;
173+
let arrow_type: &datafusion_proto_common::protobuf_common::ArrowType =
174+
try_cast.arrow_type.as_ref().ok_or_else(|| {
175+
internal_datafusion_err!(
176+
"TryCastExpr is missing required field 'arrow_type'"
177+
)
178+
})?;
179+
let cast_type: DataType = arrow_type.try_into()?;
180+
181+
Ok(Arc::new(TryCastExpr::new(expr, cast_type)))
182+
}
122183
}
123184

124185
/// Return a PhysicalExpression representing `expr` casted to
@@ -143,6 +204,8 @@ pub fn try_cast(
143204
#[cfg(test)]
144205
mod tests {
145206
use super::*;
207+
#[cfg(feature = "proto")]
208+
use crate::expressions::Column;
146209
use crate::expressions::col;
147210
use arrow::array::{
148211
Decimal128Array, Decimal128Builder, StringArray, Time64NanosecondArray,
@@ -154,7 +217,20 @@ mod tests {
154217
},
155218
datatypes::*,
156219
};
220+
#[cfg(feature = "proto")]
221+
use datafusion_common::DataFusionError;
157222
use datafusion_physical_expr_common::physical_expr::fmt_sql;
223+
#[cfg(feature = "proto")]
224+
use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx;
225+
#[cfg(feature = "proto")]
226+
use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx;
227+
#[cfg(feature = "proto")]
228+
use datafusion_proto_models::protobuf::{self, physical_expr_node};
229+
230+
#[cfg(feature = "proto")]
231+
use crate::proto_test_util::{
232+
StubDecoder, StubEncoder, UnreachableDecoder, column_node,
233+
};
158234

159235
// runs an end-to-end test of physical type cast
160236
// 1. construct a record batch with a column "a" of type A
@@ -592,4 +668,107 @@ mod tests {
592668

593669
Ok(())
594670
}
671+
672+
#[cfg(feature = "proto")]
673+
fn try_cast_node(
674+
expr: Option<Box<protobuf::PhysicalExprNode>>,
675+
cast_type: Option<DataType>,
676+
) -> protobuf::PhysicalExprNode {
677+
protobuf::PhysicalExprNode {
678+
expr_id: None,
679+
expr_type: Some(physical_expr_node::ExprType::TryCast(Box::new(
680+
protobuf::PhysicalTryCastNode {
681+
expr,
682+
arrow_type: cast_type.map(|cast_type| {
683+
let arrow_type: datafusion_proto_common::protobuf_common::ArrowType =
684+
(&cast_type).try_into().unwrap();
685+
arrow_type
686+
}),
687+
},
688+
))),
689+
}
690+
}
691+
692+
#[cfg(feature = "proto")]
693+
#[test]
694+
fn try_to_proto_encodes_try_cast_expr() -> Result<()> {
695+
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
696+
let expr = TryCastExpr::new(col("a", &schema)?, DataType::Int32);
697+
let encoder = StubEncoder::ok();
698+
let ctx = PhysicalExprEncodeCtx::new(&encoder);
699+
700+
let node = expr.try_to_proto(&ctx)?.expect("TryCastExpr proto");
701+
let try_cast = match node.expr_type {
702+
Some(physical_expr_node::ExprType::TryCast(try_cast)) => try_cast,
703+
other => panic!("expected TryCast proto, got {other:?}"),
704+
};
705+
706+
assert!(try_cast.expr.is_some());
707+
let cast_type: DataType = try_cast.arrow_type.as_ref().unwrap().try_into()?;
708+
assert_eq!(cast_type, DataType::Int32);
709+
710+
Ok(())
711+
}
712+
713+
#[cfg(feature = "proto")]
714+
#[test]
715+
fn try_from_proto_decodes_try_cast_expr() {
716+
let node = try_cast_node(Some(Box::new(column_node("a"))), Some(DataType::Int64));
717+
let schema = Schema::empty();
718+
let decoder = StubDecoder::ok();
719+
let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder);
720+
721+
let decoded = TryCastExpr::try_from_proto(&node, &ctx).unwrap();
722+
let try_cast = decoded
723+
.downcast_ref::<TryCastExpr>()
724+
.expect("decoded expr should be a TryCastExpr");
725+
726+
assert!(try_cast.expr().downcast_ref::<Column>().is_some());
727+
assert_eq!(try_cast.cast_type(), &DataType::Int64);
728+
}
729+
730+
#[cfg(feature = "proto")]
731+
#[test]
732+
fn try_from_proto_rejects_non_try_cast_node() {
733+
let node = column_node("a");
734+
let schema = Schema::empty();
735+
let decoder = UnreachableDecoder;
736+
let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder);
737+
738+
let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err();
739+
assert!(matches!(
740+
err,
741+
DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a TryCastExpr")
742+
));
743+
}
744+
745+
#[cfg(feature = "proto")]
746+
#[test]
747+
fn try_from_proto_rejects_missing_expr() {
748+
let node = try_cast_node(None, Some(DataType::Int32));
749+
let schema = Schema::empty();
750+
let decoder = UnreachableDecoder;
751+
let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder);
752+
753+
let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err();
754+
assert!(matches!(
755+
err,
756+
DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'expr'")
757+
));
758+
}
759+
760+
#[cfg(feature = "proto")]
761+
#[test]
762+
fn try_from_proto_rejects_missing_arrow_type() {
763+
let node = try_cast_node(Some(Box::new(column_node("a"))), None);
764+
let schema = Schema::empty();
765+
let decoder = StubDecoder::ok();
766+
let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder);
767+
768+
let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err();
769+
assert!(matches!(
770+
err,
771+
DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'arrow_type'")
772+
));
773+
}
595774
}

datafusion/proto/src/physical_plan/from_proto.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -371,16 +371,7 @@ pub fn parse_physical_expr_with_converter(
371371
convert_required!(e.arrow_type)?,
372372
None,
373373
)),
374-
ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
375-
parse_required_physical_expr(
376-
e.expr.as_deref(),
377-
ctx,
378-
"expr",
379-
input_schema,
380-
proto_converter,
381-
)?,
382-
convert_required!(e.arrow_type)?,
383-
)),
374+
ExprType::TryCast(_) => TryCastExpr::try_from_proto(proto, &decode_ctx)?,
384375
ExprType::ScalarUdf(e) => {
385376
let udf = match &e.fun_definition {
386377
Some(buf) => ctx.codec().try_decode_udf(&e.name, buf)?,

datafusion/proto/src/physical_plan/to_proto.rs

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindo
3737
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
3838
use datafusion_physical_plan::expressions::{
3939
CaseExpr, CastExpr, DynamicFilterPhysicalExpr, IsNotNullExpr, IsNullExpr, Literal,
40-
NotExpr, TryCastExpr, UnKnownColumn,
40+
NotExpr, UnKnownColumn,
4141
};
4242
use datafusion_physical_plan::joins::HashExpr;
4343
use datafusion_physical_plan::udaf::AggregateFunctionExpr;
@@ -408,18 +408,6 @@ pub fn serialize_physical_expr_with_converter(
408408
},
409409
))),
410410
})
411-
} else if let Some(cast) = expr.downcast_ref::<TryCastExpr>() {
412-
Ok(protobuf::PhysicalExprNode {
413-
expr_id,
414-
expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new(
415-
protobuf::PhysicalTryCastNode {
416-
expr: Some(Box::new(
417-
proto_converter.physical_expr_to_proto(cast.expr(), codec)?,
418-
)),
419-
arrow_type: Some(cast.cast_type().try_into()?),
420-
},
421-
))),
422-
})
423411
} else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
424412
let mut buf = Vec::new();
425413
codec.try_encode_udf(expr.fun(), &mut buf)?;

0 commit comments

Comments
 (0)