diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 8534ddb8d104f..f7f20378cf3f0 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -22,8 +22,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::ColumnarValue; use std::hash::Hash; use std::sync::Arc; @@ -102,6 +101,60 @@ impl PhysicalExpr for IsNullExpr { self.arg.fmt_sql(f)?; write!(f, " IS NULL") } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { + expr: Some(Box::new(ctx.encode_child(&self.arg)?)), + }), + )), + })) + } +} + +#[cfg(feature = "proto")] +impl IsNullExpr { + /// Reconstruct an [`IsNullExpr`] from its protobuf representation. + /// + /// Takes the whole [`PhysicalExprNode`] — the exact inverse of what + /// [`PhysicalExpr::try_to_proto`] produces — so every expression's + /// `try_from_proto` shares one signature. + /// + /// [`PhysicalExprNode`]: datafusion_proto_models::protobuf::PhysicalExprNode + /// [`PhysicalExpr::try_to_proto`]: datafusion_physical_expr_common::physical_expr::PhysicalExpr::try_to_proto + /// [`PhysicalExprDecodeCtx::decode`]: datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx::decode + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + let node = match &node.expr_type { + Some(protobuf::physical_expr_node::ExprType::IsNullExpr(node)) => { + node.as_ref() + } + _ => { + return datafusion_common::internal_err!( + "PhysicalExprNode is not an IsNullExpr" + ); + } + }; + let expr = node.expr.as_deref().ok_or_else(|| { + datafusion_common::DataFusionError::Internal( + "IsNullExpr is missing required field 'expr'".to_string(), + ) + })?; + + Ok(Arc::new(IsNullExpr::new(ctx.decode(expr)?))) + } } /// Create an IS NULL expression @@ -112,6 +165,8 @@ pub fn is_null(arg: Arc) -> Result> { #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "proto")] + use crate::expressions::Column; use crate::expressions::col; use arrow::array::{ Array, BooleanArray, Float64Array, Int32Array, StringArray, UnionArray, @@ -121,6 +176,54 @@ mod tests { use datafusion_common::cast::as_boolean_array; use datafusion_physical_expr_common::physical_expr::fmt_sql; + #[cfg(feature = "proto")] + use datafusion_physical_expr_common::physical_expr::{ + proto_decode::{PhysicalExprDecode, PhysicalExprDecodeCtx}, + proto_encode::{PhysicalExprEncode, PhysicalExprEncodeCtx}, + }; + #[cfg(feature = "proto")] + use datafusion_proto_models::protobuf; + + #[cfg(feature = "proto")] + struct TestProtoCodec; + + #[cfg(feature = "proto")] + impl PhysicalExprEncode for TestProtoCodec { + fn encode( + &self, + expr: &Arc, + ) -> Result { + let ctx = PhysicalExprEncodeCtx::new(self); + expr.try_to_proto(&ctx)?.ok_or_else(|| { + datafusion_common::DataFusionError::Internal( + "Expression did not serialize in test codec".to_string(), + ) + }) + } + } + + #[cfg(feature = "proto")] + impl PhysicalExprDecode for TestProtoCodec { + fn decode( + &self, + node: &protobuf::PhysicalExprNode, + schema: &Schema, + ) -> Result> { + let ctx = PhysicalExprDecodeCtx::new(schema, self); + match &node.expr_type { + Some(protobuf::physical_expr_node::ExprType::Column(_)) => { + Column::try_from_proto(node, &ctx) + } + Some(protobuf::physical_expr_node::ExprType::IsNullExpr(_)) => { + IsNullExpr::try_from_proto(node, &ctx) + } + _ => datafusion_common::internal_err!( + "Unsupported expression in test decoder" + ), + } + } + } + #[test] fn is_null_op() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); @@ -223,4 +326,72 @@ mod tests { Ok(()) } + + #[cfg(feature = "proto")] + #[test] + fn is_null_proto_hook_roundtrip() -> Result<()> { + let arg = Arc::new(Column::new("a", 0)) as Arc; + let expr = Arc::new(IsNullExpr::new(arg)) as Arc; + + let codec = TestProtoCodec; + let proto = codec.encode(&expr)?; + + let is_null = match &proto.expr_type { + Some(protobuf::physical_expr_node::ExprType::IsNullExpr(is_null)) => is_null, + other => panic!("Expected IsNullExpr proto, got {other:?}"), + }; + assert!(matches!( + is_null + .expr + .as_deref() + .and_then(|expr| expr.expr_type.as_ref()), + Some(protobuf::physical_expr_node::ExprType::Column(_)) + )); + + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let decode_ctx = PhysicalExprDecodeCtx::new(&schema, &codec); + let decoded = IsNullExpr::try_from_proto(&proto, &decode_ctx)?; + + assert_eq!(decoded.to_string(), "a@0 IS NULL"); + Ok(()) + } + + #[cfg(feature = "proto")] + #[test] + fn is_null_try_from_proto_errors() { + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let codec = TestProtoCodec; + let decode_ctx = PhysicalExprDecodeCtx::new(&schema, &codec); + + let wrong_variant = protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: "a".to_string(), + index: 0, + }, + )), + }; + let error = IsNullExpr::try_from_proto(&wrong_variant, &decode_ctx) + .expect_err("wrong variant should error") + .strip_backtrace(); + assert!( + error.contains("PhysicalExprNode is not an IsNullExpr"), + "{error}" + ); + + let missing_child = protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { expr: None }), + )), + }; + let error = IsNullExpr::try_from_proto(&missing_child, &decode_ctx) + .expect_err("missing child should error") + .strip_backtrace(); + assert!( + error.contains("IsNullExpr is missing required field 'expr'"), + "{error}" + ); + } } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 96144b11e9d3a..aec05b5b9690c 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -294,15 +294,7 @@ pub fn parse_physical_expr_with_converter( ExprType::Sort(_) => { return not_impl_err!("Cannot convert sort expr node to physical expression"); } - ExprType::IsNullExpr(e) => { - Arc::new(IsNullExpr::new(parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - proto_converter, - )?)) - } + ExprType::IsNullExpr(_) => IsNullExpr::try_from_proto(proto, &decode_ctx)?, ExprType::IsNotNullExpr(e) => { Arc::new(IsNotNullExpr::new(parse_required_physical_expr( e.expr.as_deref(), diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 5dd643c84ba21..8806a0b9168ed 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -36,8 +36,8 @@ use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::expressions::{ - CaseExpr, CastExpr, DynamicFilterPhysicalExpr, IsNotNullExpr, IsNullExpr, Literal, - NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + CaseExpr, CastExpr, DynamicFilterPhysicalExpr, IsNotNullExpr, Literal, NegativeExpr, + NotExpr, TryCastExpr, UnKnownColumn, }; use datafusion_physical_plan::joins::{HashExpr, HashTableLookupExpr}; use datafusion_physical_plan::udaf::AggregateFunctionExpr; @@ -390,17 +390,6 @@ pub fn serialize_physical_expr_with_converter( }, ))), }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_id, - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( - Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new( - proto_converter.physical_expr_to_proto(expr.arg(), codec)?, - )), - }), - )), - }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id,