Skip to content

Commit 351c2b1

Browse files
committed
Port TryCastExpr proto serialization hooks
1 parent 0e17880 commit 351c2b1

3 files changed

Lines changed: 192 additions & 23 deletions

File tree

datafusion/physical-expr/src/expressions/try_cast.rs

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,56 @@ impl PhysicalExpr for TryCastExpr {
119119
self.expr.fmt_sql(f)?;
120120
write!(f, " AS {:?})", self.cast_type)
121121
}
122+
123+
#[cfg(feature = "proto")]
124+
fn try_to_proto(
125+
&self,
126+
ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>,
127+
) -> Result<Option<datafusion_proto_models::protobuf::PhysicalExprNode>> {
128+
use datafusion_proto_models::protobuf;
129+
130+
Ok(Some(protobuf::PhysicalExprNode {
131+
expr_id: None,
132+
expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new(
133+
protobuf::PhysicalTryCastNode {
134+
expr: Some(Box::new(ctx.encode_child(&self.expr)?)),
135+
arrow_type: Some(self.cast_type().try_into()?),
136+
},
137+
))),
138+
}))
139+
}
140+
}
141+
142+
#[cfg(feature = "proto")]
143+
impl TryCastExpr {
144+
/// Reconstruct a [`TryCastExpr`] from its protobuf representation.
145+
pub fn try_from_proto(
146+
node: &datafusion_proto_models::protobuf::PhysicalExprNode,
147+
ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>,
148+
) -> Result<Arc<dyn PhysicalExpr>> {
149+
use datafusion_physical_expr_common::expect_expr_variant;
150+
use datafusion_physical_expr_common::physical_expr::proto_decode::require_proto_field;
151+
use datafusion_proto_models::protobuf;
152+
153+
let try_cast = expect_expr_variant!(
154+
node,
155+
protobuf::physical_expr_node::ExprType::TryCast,
156+
"TryCastExpr",
157+
);
158+
let expr = ctx.decode_required_expression(
159+
try_cast.expr.as_deref(),
160+
"TryCastExpr",
161+
"expr",
162+
)?;
163+
let arrow_type = require_proto_field(
164+
try_cast.arrow_type.as_ref(),
165+
"TryCastExpr",
166+
"arrow_type",
167+
)?;
168+
let cast_type: DataType = arrow_type.try_into()?;
169+
170+
Ok(Arc::new(TryCastExpr::new(expr, cast_type)))
171+
}
122172
}
123173

124174
/// Return a PhysicalExpression representing `expr` casted to
@@ -593,3 +643,143 @@ mod tests {
593643
Ok(())
594644
}
595645
}
646+
647+
#[cfg(all(test, feature = "proto"))]
648+
mod proto_tests {
649+
use super::*;
650+
use crate::expressions::{Column, col};
651+
use crate::proto_test_util::{
652+
StubDecoder, StubEncoder, UnreachableDecoder, column_node,
653+
};
654+
use arrow::datatypes::Field;
655+
use datafusion_common::DataFusionError;
656+
use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx;
657+
use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx;
658+
use datafusion_proto_models::datafusion_common::ArrowType;
659+
use datafusion_proto_models::protobuf::{
660+
PhysicalExprNode, PhysicalTryCastNode, physical_expr_node,
661+
};
662+
663+
fn try_cast_fixture() -> TryCastExpr {
664+
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
665+
TryCastExpr::new(col("a", &schema).unwrap(), DataType::Int32)
666+
}
667+
668+
fn int32_arrow_type() -> ArrowType {
669+
(&DataType::Int32).try_into().unwrap()
670+
}
671+
672+
fn try_cast_node(
673+
expr: Option<Box<PhysicalExprNode>>,
674+
arrow_type: Option<ArrowType>,
675+
) -> PhysicalExprNode {
676+
PhysicalExprNode {
677+
expr_id: None,
678+
expr_type: Some(physical_expr_node::ExprType::TryCast(Box::new(
679+
PhysicalTryCastNode { expr, arrow_type },
680+
))),
681+
}
682+
}
683+
684+
#[test]
685+
fn try_to_proto_encodes_try_cast_expr() {
686+
let try_cast = try_cast_fixture();
687+
let encoder = StubEncoder::ok();
688+
let ctx = PhysicalExprEncodeCtx::new(&encoder);
689+
690+
let node = try_cast
691+
.try_to_proto(&ctx)
692+
.unwrap()
693+
.expect("TryCastExpr should encode to Some(node)");
694+
695+
assert!(node.expr_id.is_none());
696+
let try_cast_node = match node.expr_type {
697+
Some(physical_expr_node::ExprType::TryCast(boxed)) => *boxed,
698+
other => panic!("expected a TryCastExpr node, got {other:?}"),
699+
};
700+
assert!(try_cast_node.expr.is_some());
701+
702+
let arrow_type = try_cast_node
703+
.arrow_type
704+
.as_ref()
705+
.expect("try cast type should be encoded");
706+
let data_type: DataType = arrow_type.try_into().unwrap();
707+
assert_eq!(data_type, DataType::Int32);
708+
}
709+
710+
#[test]
711+
fn try_to_proto_propagates_child_encode_error() {
712+
let try_cast = try_cast_fixture();
713+
let encoder = StubEncoder::failing_on(1);
714+
let ctx = PhysicalExprEncodeCtx::new(&encoder);
715+
let err = try_cast.try_to_proto(&ctx).unwrap_err();
716+
assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1")));
717+
}
718+
719+
#[test]
720+
fn try_from_proto_decodes_try_cast_expr() {
721+
let node =
722+
try_cast_node(Some(Box::new(column_node("a"))), Some(int32_arrow_type()));
723+
let schema = Schema::empty();
724+
let decoder = StubDecoder::ok();
725+
let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder);
726+
727+
let decoded = TryCastExpr::try_from_proto(&node, &ctx).unwrap();
728+
let try_cast = decoded
729+
.downcast_ref::<TryCastExpr>()
730+
.expect("decoded expr should be a TryCastExpr");
731+
732+
assert_eq!(try_cast.cast_type(), &DataType::Int32);
733+
assert!(try_cast.expr().downcast_ref::<Column>().is_some());
734+
}
735+
736+
#[test]
737+
fn try_from_proto_rejects_non_try_cast_node() {
738+
let node = column_node("a");
739+
let schema = Schema::empty();
740+
let decoder = UnreachableDecoder;
741+
let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder);
742+
743+
let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err();
744+
assert!(
745+
matches!(err, DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a TryCastExpr"))
746+
);
747+
}
748+
749+
#[test]
750+
fn try_from_proto_rejects_missing_expr() {
751+
let node = try_cast_node(None, Some(int32_arrow_type()));
752+
let schema = Schema::empty();
753+
let decoder = UnreachableDecoder;
754+
let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder);
755+
756+
let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err();
757+
assert!(
758+
matches!(err, DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'expr'"))
759+
);
760+
}
761+
762+
#[test]
763+
fn try_from_proto_rejects_missing_arrow_type() {
764+
let node = try_cast_node(Some(Box::new(column_node("a"))), None);
765+
let schema = Schema::empty();
766+
let decoder = StubDecoder::ok();
767+
let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder);
768+
769+
let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err();
770+
assert!(
771+
matches!(err, DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'arrow_type'"))
772+
);
773+
}
774+
775+
#[test]
776+
fn try_from_proto_propagates_child_decode_error() {
777+
let node =
778+
try_cast_node(Some(Box::new(column_node("a"))), Some(int32_arrow_type()));
779+
let schema = Schema::empty();
780+
let decoder = StubDecoder::failing_on(1);
781+
let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder);
782+
let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err();
783+
assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1")));
784+
}
785+
}

datafusion/proto/src/physical_plan/from_proto.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -347,16 +347,7 @@ pub fn parse_physical_expr_with_converter(
347347
.transpose()?,
348348
)?),
349349
ExprType::Cast(_) => CastExpr::try_from_proto(proto, &decode_ctx)?,
350-
ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
351-
parse_required_physical_expr(
352-
e.expr.as_deref(),
353-
ctx,
354-
"expr",
355-
input_schema,
356-
proto_converter,
357-
)?,
358-
convert_required!(e.arrow_type)?,
359-
)),
350+
ExprType::TryCast(_) => TryCastExpr::try_from_proto(proto, &decode_ctx)?,
360351
ExprType::ScalarUdf(e) => {
361352
let udf = match &e.fun_definition {
362353
Some(buf) => ctx.codec().try_decode_udf(&e.name, buf)?,

datafusion/proto/src/physical_plan/to_proto.rs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr;
3636
use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr};
3737
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
3838
use datafusion_physical_plan::expressions::{
39-
CaseExpr, DynamicFilterPhysicalExpr, IsNullExpr, Literal, TryCastExpr,
39+
CaseExpr, DynamicFilterPhysicalExpr, IsNullExpr, Literal,
4040
};
4141
use datafusion_physical_plan::udaf::AggregateFunctionExpr;
4242
use datafusion_physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr};
@@ -363,18 +363,6 @@ pub fn serialize_physical_expr_with_converter(
363363
lit.value().try_into()?,
364364
)),
365365
})
366-
} else if let Some(cast) = expr.downcast_ref::<TryCastExpr>() {
367-
Ok(protobuf::PhysicalExprNode {
368-
expr_id,
369-
expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new(
370-
protobuf::PhysicalTryCastNode {
371-
expr: Some(Box::new(
372-
proto_converter.physical_expr_to_proto(cast.expr(), codec)?,
373-
)),
374-
arrow_type: Some(cast.cast_type().try_into()?),
375-
},
376-
))),
377-
})
378366
} else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
379367
let mut buf = Vec::new();
380368
codec.try_encode_udf(expr.fun(), &mut buf)?;

0 commit comments

Comments
 (0)