Skip to content

Commit e2cb497

Browse files
committed
Check that the join expression has the right output_type
1 parent b48e2a6 commit e2cb497

1 file changed

Lines changed: 22 additions & 5 deletions

File tree

  • datafusion/substrait/src/logical_plan/producer/rel

datafusion/substrait/src/logical_plan/producer/rel/join.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,18 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
9696

9797
#[cfg(test)]
9898
mod tests {
99-
use crate::logical_plan::producer::{DefaultSubstraitProducer, SubstraitProducer};
99+
use crate::logical_plan::producer::{
100+
DefaultSubstraitProducer, SubstraitProducer, to_substrait_type,
101+
};
100102
use datafusion::arrow::datatypes::{DataType, Field, Schema};
101103
use datafusion::common::{JoinConstraint, JoinType, NullEquality};
102104
use datafusion::execution::SessionStateBuilder;
103105
use datafusion::logical_expr::utils::conjunction;
104106
use datafusion::logical_expr::{Join, col, table_scan};
105107
use std::sync::Arc;
108+
use substrait::proto::expression::{RexType, ScalarFunction};
106109
use substrait::proto::rel::RelType;
107-
use substrait::proto::{JoinRel, Rel, join_rel};
110+
use substrait::proto::{Expression, JoinRel, Rel, join_rel};
108111

109112
#[test]
110113
fn test_from_join() -> datafusion::common::Result<()> {
@@ -139,6 +142,9 @@ mod tests {
139142
col("t1.c").gt(col("t2.c")),
140143
])
141144
.unwrap();
145+
let expected_join_expression =
146+
producer.handle_expr(&expected_join_expr, &in_join_schema)?;
147+
142148
assert_eq!(
143149
join_expr,
144150
Box::new(Rel {
@@ -147,14 +153,25 @@ mod tests {
147153
left: Some(producer.handle_plan(&left_scan)?),
148154
right: Some(producer.handle_plan(&right_scan)?),
149155
r#type: join_rel::JoinType::Inner as i32,
150-
expression: Some(Box::new(
151-
producer.handle_expr(&expected_join_expr, &in_join_schema)?
152-
)),
156+
expression: Some(Box::new(expected_join_expression.clone())),
153157
post_join_filter: None,
154158
advanced_extension: None,
155159
})))
156160
})
157161
);
162+
163+
// Check that the join_expression has the expected output_type
164+
if let Expression {
165+
rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })),
166+
} = expected_join_expression
167+
{
168+
let expected_type =
169+
to_substrait_type(&mut producer, &DataType::Boolean, false)?;
170+
assert_eq!(output_type, Some(expected_type));
171+
} else {
172+
panic!("Substrait ScalarFunction expected")
173+
}
174+
158175
Ok(())
159176
}
160177
}

0 commit comments

Comments
 (0)