Skip to content

Commit 96147b4

Browse files
authored
fix: correctly specify schema in encoded substrait for expr (#3937)
This also switches to DataFusion's `to_substrait_extended_expr` instead of building a dummy plan as we did before (we should eventually do the same on the parse path too).
1 parent 8b96f1b commit 96147b4

1 file changed

Lines changed: 39 additions & 48 deletions

File tree

rust/lance-datafusion/src/substrait.rs

Lines changed: 39 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -29,58 +29,32 @@ use std::collections::HashMap;
2929
use std::sync::Arc;
3030

3131
/// Convert a DF Expr into a Substrait ExtendedExpressions message
32+
///
33+
/// The schema needs to contain all of the fields that are referenced in the expression.
34+
/// It is ok if the schema has more fields than are required. However, we cannot currently
35+
/// convert all field types (e.g. extension types, FSL) and if these fields are present then
36+
/// the conversion will fail.
37+
///
38+
/// As a result, it may be a good idea for now to remove those types from the schema before
39+
/// calling this function.
3240
pub fn encode_substrait(expr: Expr, schema: Arc<ArrowSchema>) -> Result<Vec<u8>> {
33-
use datafusion::logical_expr::{builder::LogicalTableSource, logical_plan, LogicalPlan};
34-
use datafusion_substrait::substrait::proto::{plan_rel, ExpressionReference, NamedStruct};
35-
36-
let table_source = Arc::new(LogicalTableSource::new(schema.clone()));
41+
use arrow_schema::Field;
42+
use datafusion::logical_expr::ExprSchemable;
43+
use datafusion_common::DFSchema;
3744

38-
// DF doesn't handle ExtendedExpressions and so we need to create
39-
// a dummy plan with a single filter node
40-
let plan = LogicalPlan::Filter(logical_plan::Filter::try_new(
41-
expr,
42-
Arc::new(LogicalPlan::TableScan(logical_plan::TableScan::try_new(
43-
"dummy",
44-
table_source,
45-
None,
46-
vec![],
47-
None,
48-
)?)),
49-
)?);
45+
let ctx = SessionContext::new();
5046

51-
let session_context = SessionContext::new();
52-
53-
let substrait_plan = datafusion_substrait::logical_plan::producer::to_substrait_plan(
54-
&plan,
55-
&session_context.state(),
47+
let df_schema = Arc::new(DFSchema::try_from(schema)?);
48+
let output_type = expr.get_type(&df_schema)?;
49+
// Nullability doesn't matter
50+
let output_field = Field::new("output", output_type, /*nullable=*/ true);
51+
let extended_expr = datafusion_substrait::logical_plan::producer::to_substrait_extended_expr(
52+
&[(&expr, &output_field)],
53+
&df_schema,
54+
&ctx.state(),
5655
)?;
5756

58-
if let Some(plan_rel::RelType::Root(root)) = &substrait_plan.relations[0].rel_type {
59-
if let Some(rel::RelType::Filter(filt)) = &root.input.as_ref().unwrap().rel_type {
60-
let expr = filt.condition.as_ref().unwrap().clone();
61-
let schema = NamedStruct {
62-
names: schema.fields().iter().map(|f| f.name().clone()).collect(),
63-
r#struct: None,
64-
};
65-
let envelope = ExtendedExpression {
66-
advanced_extensions: substrait_plan.advanced_extensions.clone(),
67-
base_schema: Some(schema),
68-
expected_type_urls: substrait_plan.expected_type_urls.clone(),
69-
extension_uris: substrait_plan.extension_uris.clone(),
70-
extensions: substrait_plan.extensions.clone(),
71-
referred_expr: vec![ExpressionReference {
72-
output_names: vec![],
73-
expr_type: Some(ExprType::Expression(*expr)),
74-
}],
75-
version: substrait_plan.version.clone(),
76-
};
77-
Ok(envelope.encode_to_vec())
78-
} else {
79-
unreachable!()
80-
}
81-
} else {
82-
unreachable!()
83-
}
57+
Ok(extended_expr.encode_to_vec())
8458
}
8559

8660
fn count_fields(dtype: &Type) -> usize {
@@ -425,7 +399,7 @@ mod tests {
425399
helpers::{literals::literal, schema::SchemaInfo},
426400
};
427401

428-
use crate::substrait::parse_substrait;
402+
use crate::substrait::{encode_substrait, parse_substrait};
429403

430404
#[tokio::test]
431405
async fn test_substrait_conversion() {
@@ -462,4 +436,21 @@ mod tests {
462436
});
463437
assert_eq!(df_expr, expected);
464438
}
439+
440+
#[tokio::test]
441+
async fn test_expr_substrait_roundtrip() {
442+
let schema = arrow_schema::Schema::new(vec![Field::new("x", DataType::Int32, true)]);
443+
let expr = Expr::BinaryExpr(BinaryExpr {
444+
left: Box::new(Expr::Column(Column::new_unqualified("x"))),
445+
op: Operator::Lt,
446+
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))),
447+
});
448+
449+
let bytes = encode_substrait(expr.clone(), Arc::new(schema.clone())).unwrap();
450+
451+
let decoded = parse_substrait(bytes.as_slice(), Arc::new(schema.clone()))
452+
.await
453+
.unwrap();
454+
assert_eq!(decoded, expr);
455+
}
465456
}

0 commit comments

Comments
 (0)