Skip to content

Commit bb84358

Browse files
committed
feat(physical-expr-common): add proto helpers for variant match and required scalar field
Expressions migrating to `try_to_proto` / `try_from_proto` (apache#22418) keep hand-rolling two shapes that don't fit the existing `encode_child` / `decode_required_expression` helpers from apache#22513: - the outer `match &node.expr_type { ... }` that opens every `try_from_proto`, and - the `ok_or_else(|| internal_datafusion_err!("X is missing required field 'Y'"))` for non-expression fields like `arrow_type` on `CastExpr` / `TryCastExpr`. This commit adds two thin helpers, both gated on `feature = "proto"`: - `expect_expr_variant!` macro (re-exported at crate root) — matches `Option<ExprType>` and returns the inner payload (auto-derefs through `Box`), or returns an `internal_err!` naming the expected variant. - `proto_decode::require_proto_field<T>` — mirrors `decode_required_expression` for non-`PhysicalExprNode` fields, keeping the "missing required field" message format in one place. No existing call sites are migrated in this commit; the follow-up ports the already-migrated expressions onto these helpers.
1 parent cab69a1 commit bb84358

1 file changed

Lines changed: 149 additions & 0 deletions

File tree

datafusion/physical-expr-common/src/physical_expr.rs

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,48 @@ pub mod proto_decode {
619619

620620
use super::PhysicalExpr;
621621

622+
/// Open the outer [`PhysicalExprNode`] and assert it carries the expected
623+
/// `ExprType` variant, returning the inner payload (auto-derefs through
624+
/// `Box`) or bailing with an `Internal` error.
625+
///
626+
/// Every `try_from_proto` starts with the same six-line `match`:
627+
///
628+
/// ```ignore
629+
/// let try_cast = match &node.expr_type {
630+
/// Some(protobuf::physical_expr_node::ExprType::TryCast(x)) => x.as_ref(),
631+
/// _ => return internal_err!("PhysicalExprNode is not a TryCastExpr"),
632+
/// };
633+
/// ```
634+
///
635+
/// With this macro that collapses to:
636+
///
637+
/// ```ignore
638+
/// let try_cast = expect_expr_variant!(
639+
/// node,
640+
/// protobuf::physical_expr_node::ExprType::TryCast,
641+
/// "TryCastExpr",
642+
/// );
643+
/// ```
644+
///
645+
/// Pass the variant as a `::` path so the macro stays agnostic to how
646+
/// the caller imports the proto types.
647+
#[macro_export]
648+
macro_rules! expect_expr_variant {
649+
($node:expr, $variant:path, $expr_name:literal $(,)?) => {{
650+
match &$node.expr_type {
651+
::core::option::Option::Some($variant(inner)) => inner,
652+
_ => {
653+
return ::datafusion_common::internal_err!(concat!(
654+
"PhysicalExprNode is not a ",
655+
$expr_name
656+
));
657+
}
658+
}
659+
}};
660+
}
661+
#[doc(inline)]
662+
pub use expect_expr_variant;
663+
622664
/// Decoder context handed to per-expression `try_from_proto` constructors.
623665
///
624666
/// Wraps an internal [`PhysicalExprDecode`] trait object plus a borrowed
@@ -693,6 +735,33 @@ pub mod proto_decode {
693735
}
694736
}
695737

738+
/// Unwrap a required non-expression proto field.
739+
///
740+
/// Mirrors [`PhysicalExprDecodeCtx::decode_required_expression`] for proto
741+
/// fields that aren't [`PhysicalExprNode`]s — e.g. the `arrow_type` of a
742+
/// `PhysicalCastNode` or the `scalar` of a `PhysicalLiteralNode`. Keeps
743+
/// the "missing required field" message format identical across
744+
/// expressions:
745+
///
746+
/// ```ignore
747+
/// let arrow_type = require_proto_field(
748+
/// cast_expr.arrow_type.as_ref(),
749+
/// "CastExpr",
750+
/// "arrow_type",
751+
/// )?;
752+
/// ```
753+
pub fn require_proto_field<T>(
754+
opt: Option<T>,
755+
expr_name: &str,
756+
field: &str,
757+
) -> Result<T> {
758+
opt.ok_or_else(|| {
759+
datafusion_common::internal_datafusion_err!(
760+
"{expr_name} is missing required field '{field}'"
761+
)
762+
})
763+
}
764+
696765
/// Internal dispatch trait. Implementors live in `datafusion-proto`.
697766
/// Expression authors should use [`PhysicalExprDecodeCtx`] instead of
698767
/// calling this directly.
@@ -1143,3 +1212,83 @@ mod test {
11431212
);
11441213
}
11451214
}
1215+
1216+
#[cfg(all(test, feature = "proto"))]
1217+
mod proto_helper_tests {
1218+
use datafusion_common::DataFusionError;
1219+
use datafusion_proto_models::protobuf::{
1220+
self, PhysicalColumn, PhysicalExprNode, physical_expr_node,
1221+
};
1222+
1223+
use crate::expect_expr_variant;
1224+
use crate::physical_expr::proto_decode::require_proto_field;
1225+
1226+
fn column_node() -> PhysicalExprNode {
1227+
PhysicalExprNode {
1228+
expr_id: None,
1229+
expr_type: Some(physical_expr_node::ExprType::Column(PhysicalColumn {
1230+
name: "a".to_string(),
1231+
index: 0,
1232+
})),
1233+
}
1234+
}
1235+
1236+
#[test]
1237+
fn require_proto_field_returns_inner() {
1238+
let v = require_proto_field(Some(7_u32), "FooExpr", "answer").unwrap();
1239+
assert_eq!(v, 7);
1240+
}
1241+
1242+
#[test]
1243+
fn require_proto_field_reports_missing() {
1244+
let err = require_proto_field::<u32>(None, "FooExpr", "answer").unwrap_err();
1245+
assert!(matches!(
1246+
err,
1247+
DataFusionError::Internal(msg)
1248+
if msg.contains("FooExpr is missing required field 'answer'")
1249+
));
1250+
}
1251+
1252+
fn expect_column(
1253+
node: &PhysicalExprNode,
1254+
) -> Result<&PhysicalColumn, DataFusionError> {
1255+
let inner =
1256+
expect_expr_variant!(node, physical_expr_node::ExprType::Column, "Column",);
1257+
Ok(inner)
1258+
}
1259+
1260+
#[test]
1261+
fn expect_expr_variant_returns_inner_payload() {
1262+
let node = column_node();
1263+
let col = expect_column(&node).unwrap();
1264+
assert_eq!(col.name, "a");
1265+
}
1266+
1267+
#[test]
1268+
fn expect_expr_variant_rejects_wrong_variant() {
1269+
let node = PhysicalExprNode {
1270+
expr_id: None,
1271+
expr_type: Some(physical_expr_node::ExprType::Negative(Box::new(
1272+
protobuf::PhysicalNegativeNode { expr: None },
1273+
))),
1274+
};
1275+
let err = expect_column(&node).unwrap_err();
1276+
assert!(matches!(
1277+
err,
1278+
DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a Column")
1279+
));
1280+
}
1281+
1282+
#[test]
1283+
fn expect_expr_variant_rejects_missing_expr_type() {
1284+
let node = PhysicalExprNode {
1285+
expr_id: None,
1286+
expr_type: None,
1287+
};
1288+
let err = expect_column(&node).unwrap_err();
1289+
assert!(matches!(
1290+
err,
1291+
DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a Column")
1292+
));
1293+
}
1294+
}

0 commit comments

Comments
 (0)