Skip to content

Commit 35e4d72

Browse files
authored
Misc upgrades to VortexSource (#8501)
## Summary 1. Add support for pushing down `byte_length()` 2. Better Display for `VortexSink` 3. Support disabling predicate pushdown --------- Signed-off-by: Adam Gutglick <adam@spiraldb.com>
1 parent 3a7b9a0 commit 35e4d72

7 files changed

Lines changed: 387 additions & 45 deletions

File tree

vortex-datafusion/src/convert/exprs.rs

Lines changed: 108 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use datafusion_common::tree_node::TreeNode;
1111
use datafusion_common::tree_node::TreeNodeRecursion;
1212
use datafusion_expr::Operator as DFOperator;
1313
use datafusion_functions::core::getfield::GetFieldFunc;
14+
use datafusion_functions::string::octet_length::OctetLengthFunc;
1415
use datafusion_physical_expr::PhysicalExpr;
1516
use datafusion_physical_expr::ScalarFunctionExpr;
1617
use datafusion_physical_expr::projection::ProjectionExpr;
@@ -24,6 +25,7 @@ use vortex::dtype::Nullability;
2425
use vortex::dtype::arrow::FromArrowType;
2526
use vortex::expr::Expression;
2627
use vortex::expr::and_collect;
28+
use vortex::expr::byte_length;
2729
use vortex::expr::cast;
2830
use vortex::expr::get_item;
2931
use vortex::expr::is_not_null;
@@ -111,8 +113,28 @@ pub trait ExpressionConvertor: Send + Sync {
111113
pub struct DefaultExpressionConvertor {}
112114

113115
impl DefaultExpressionConvertor {
116+
/// Attempts to convert DataFusion's `octet_length` function to Vortex `byte_length`.
117+
fn try_convert_octet_length(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
118+
let [input] = scalar_fn.args() else {
119+
return Err(exec_datafusion_err!(
120+
"octet_length requires exactly one argument"
121+
));
122+
};
123+
124+
let input = self.convert(input.as_ref())?;
125+
let return_dtype =
126+
DType::from_arrow((scalar_fn.return_type(), scalar_fn.nullable().into()));
127+
Ok(cast(byte_length(input), return_dtype))
128+
}
129+
114130
/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
115131
fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
132+
if let Some(octet_length_fn) =
133+
ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(scalar_fn)
134+
{
135+
return self.try_convert_octet_length(octet_length_fn);
136+
}
137+
116138
if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
117139
{
118140
// DataFusion's GetFieldFunc flattens nested field access into a single call
@@ -289,7 +311,7 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
289311
let r = projection_expr.expr.apply(|node| {
290312
// We only pull column children of scalar functions that we can't push into the scan.
291313
if let Some(scalar_fn_expr) = node.downcast_ref::<ScalarFunctionExpr>()
292-
&& !can_scalar_fn_be_pushed_down(scalar_fn_expr)
314+
&& !can_scalar_fn_be_pushed_down(scalar_fn_expr, input_schema)
293315
{
294316
scan_projection.extend(
295317
collect_columns(node)
@@ -305,8 +327,8 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
305327
// Vortex expects a perfect match so we don't push it down.
306328
if let Some(binary_expr) = node.downcast_ref::<df_expr::BinaryExpr>()
307329
&& binary_expr.op().is_numerical_operators()
308-
&& (is_decimal(&binary_expr.left().data_type(input_schema)?)
309-
&& is_decimal(&binary_expr.right().data_type(input_schema)?))
330+
&& binary_expr.left().data_type(input_schema)?.is_decimal()
331+
&& binary_expr.right().data_type(input_schema)?.is_decimal()
310332
{
311333
scan_projection.extend(
312334
collect_columns(node)
@@ -430,7 +452,7 @@ fn can_be_pushed_down_impl(expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> boo
430452
.iter()
431453
.all(|e| can_be_pushed_down_impl(e, schema))
432454
} else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
433-
can_scalar_fn_be_pushed_down(scalar_fn)
455+
can_scalar_fn_be_pushed_down(scalar_fn, schema)
434456
} else if let Some(case_expr) = expr.downcast_ref::<df_expr::CaseExpr>() {
435457
can_case_be_pushed_down(case_expr, schema)
436458
} else {
@@ -454,9 +476,10 @@ fn is_convertible_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
454476
|| expr.downcast_ref::<df_expr::IsNullExpr>().is_some()
455477
|| expr.downcast_ref::<df_expr::IsNotNullExpr>().is_some()
456478
|| expr.downcast_ref::<df_expr::InListExpr>().is_some()
457-
|| expr
458-
.downcast_ref::<ScalarFunctionExpr>()
459-
.is_some_and(|sf| ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some())
479+
|| expr.downcast_ref::<ScalarFunctionExpr>().is_some_and(|sf| {
480+
ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some()
481+
|| ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(sf).is_some()
482+
})
460483
}
461484

462485
fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool {
@@ -502,20 +525,11 @@ fn supported_data_types(dt: &DataType) -> bool {
502525

503526
let is_supported = dt.is_null()
504527
|| dt.is_numeric()
528+
|| dt.is_binary()
529+
|| dt.is_string()
505530
|| matches!(
506531
dt,
507-
Boolean
508-
| Utf8
509-
| LargeUtf8
510-
| Utf8View
511-
| Binary
512-
| LargeBinary
513-
| BinaryView
514-
| Date32
515-
| Date64
516-
| Timestamp(_, _)
517-
| Time32(_)
518-
| Time64(_)
532+
Boolean | Date32 | Date64 | Timestamp(_, _) | Time32(_) | Time64(_)
519533
);
520534

521535
if !is_supported {
@@ -526,20 +540,30 @@ fn supported_data_types(dt: &DataType) -> bool {
526540
}
527541

528542
/// Checks if a scalar function can be pushed down.
529-
/// Currently only GetFieldFunc is supported.
530-
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool {
531-
ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
543+
/// Currently GetFieldFunc and OctetLengthFunc are supported.
544+
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
545+
if ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some() {
546+
return true;
547+
}
548+
549+
ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(scalar_fn)
550+
.is_some_and(|octet_length| can_octet_length_be_pushed_down(octet_length, schema))
532551
}
533552

534-
// TODO(adam): Replace with `DataType::is_decimal` once its released.
535-
fn is_decimal(dt: &DataType) -> bool {
536-
matches!(
537-
dt,
538-
DataType::Decimal32(_, _)
539-
| DataType::Decimal64(_, _)
540-
| DataType::Decimal128(_, _)
541-
| DataType::Decimal256(_, _)
542-
)
553+
fn can_octet_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
554+
let [input] = scalar_fn.args() else {
555+
return false;
556+
};
557+
558+
input.data_type(schema).as_ref().is_ok_and(|data_type| {
559+
let dt = if let DataType::Dictionary(_, value_type) = data_type {
560+
value_type.as_ref()
561+
} else {
562+
data_type
563+
};
564+
565+
dt.is_binary() || dt.is_string()
566+
}) && can_be_pushed_down_impl(input, schema)
543567
}
544568

545569
#[cfg(test)]
@@ -553,7 +577,9 @@ mod tests {
553577
use datafusion::arrow::array::AsArray;
554578
use datafusion::arrow::datatypes::Int32Type;
555579
use datafusion_common::ScalarValue;
580+
use datafusion_common::config::ConfigOptions;
556581
use datafusion_expr::Operator as DFOperator;
582+
use datafusion_expr::ScalarUDF;
557583
use datafusion_physical_expr::PhysicalExpr;
558584
use datafusion_physical_plan::expressions as df_expr;
559585
use insta::assert_snapshot;
@@ -582,6 +608,18 @@ mod tests {
582608
])
583609
}
584610

611+
fn octet_length_expr(input: Arc<dyn PhysicalExpr>, schema: &Schema) -> Arc<dyn PhysicalExpr> {
612+
Arc::new(
613+
ScalarFunctionExpr::try_new(
614+
Arc::new(ScalarUDF::from(OctetLengthFunc::new())),
615+
vec![input],
616+
schema,
617+
Arc::new(ConfigOptions::new()),
618+
)
619+
.unwrap(),
620+
)
621+
}
622+
585623
#[test]
586624
fn test_make_vortex_predicate_empty() {
587625
let expr_convertor = DefaultExpressionConvertor::default();
@@ -711,6 +749,23 @@ mod tests {
711749
);
712750
}
713751

752+
#[rstest]
753+
fn test_expr_from_df_octet_length(test_schema: Schema) {
754+
let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
755+
let octet_length = octet_length_expr(expr, &test_schema);
756+
757+
let result = DefaultExpressionConvertor::default()
758+
.convert(octet_length.as_ref())
759+
.unwrap();
760+
761+
assert_snapshot!(result.display_tree().to_string(), @r"
762+
vortex.cast(i32?)
763+
└── input: vortex.byte_length()
764+
└── input: vortex.get_item(name)
765+
└── input: vortex.root()
766+
");
767+
}
768+
714769
#[rstest]
715770
// Supported types
716771
#[case::null(DataType::Null, true)]
@@ -865,6 +920,28 @@ mod tests {
865920
assert!(!can_be_pushed_down_impl(&like_expr, &test_schema));
866921
}
867922

923+
#[rstest]
924+
fn test_can_be_pushed_down_octet_length_supported(test_schema: Schema) {
925+
let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
926+
let octet_length = octet_length_expr(expr, &test_schema);
927+
928+
assert!(can_be_pushed_down_impl(&octet_length, &test_schema));
929+
}
930+
931+
#[rstest]
932+
fn test_can_be_pushed_down_octet_length_unsupported_operand(test_schema: Schema) {
933+
let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
934+
let octet_length = Arc::new(ScalarFunctionExpr::new(
935+
"octet_length",
936+
Arc::new(ScalarUDF::from(OctetLengthFunc::new())),
937+
vec![expr],
938+
Arc::new(Field::new("octet_length", DataType::Int32, true)),
939+
Arc::new(ConfigOptions::new()),
940+
)) as Arc<dyn PhysicalExpr>;
941+
942+
assert!(!can_be_pushed_down_impl(&octet_length, &test_schema));
943+
}
944+
868945
// https://github.com/vortex-data/vortex/issues/6211
869946
#[tokio::test]
870947
async fn test_cast_int_to_string() -> anyhow::Result<()> {

vortex-datafusion/src/persistent/format.rs

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ config_namespace! {
146146
///
147147
/// let factory = VortexFormatFactory::new().with_options(VortexTableOptions {
148148
/// projection_pushdown: true,
149+
/// predicate_pushdown: true,
149150
/// scan_concurrency: Some(8),
150151
/// ..Default::default()
151152
/// });
@@ -165,6 +166,12 @@ config_namespace! {
165166
/// the scan. When disabled, Vortex reads only the referenced columns and
166167
/// all expressions are evaluated after the scan.
167168
pub projection_pushdown: bool, default = false
169+
/// Whether to enable predicate pushdown into the underlying Vortex scan.
170+
///
171+
/// When enabled, supported filters are evaluated during the scan. When
172+
/// disabled, DataFusion evaluates filters after the scan, while
173+
/// `VortexSource` can still use the full predicate for file pruning.
174+
pub predicate_pushdown: bool, default = true
168175
/// The intra-partition scan concurrency, controlling the number of row splits to process
169176
/// concurrently per-thread within each file.
170177
///
@@ -198,6 +205,7 @@ impl Eq for VortexTableOptions {}
198205
///
199206
/// let factory = Arc::new(VortexFormatFactory::new().with_options(VortexTableOptions {
200207
/// projection_pushdown: true,
208+
/// predicate_pushdown: true,
201209
/// ..Default::default()
202210
/// }));
203211
///
@@ -263,6 +271,7 @@ impl VortexFormatFactory {
263271
///
264272
/// let factory = VortexFormatFactory::new().with_options(VortexTableOptions {
265273
/// projection_pushdown: true,
274+
/// predicate_pushdown: true,
266275
/// ..Default::default()
267276
/// });
268277
/// # let _ = factory;
@@ -617,14 +626,9 @@ impl FileFormat for VortexFormat {
617626
}
618627

619628
fn file_source(&self, table_schema: TableSchema) -> Arc<dyn FileSource> {
620-
let mut source = VortexSource::new(table_schema, self.session.clone())
621-
.with_projection_pushdown(self.opts.projection_pushdown);
622-
623-
if let Some(scan_concurrency) = self.opts.scan_concurrency {
624-
source = source.with_scan_concurrency(scan_concurrency);
625-
}
626-
627-
Arc::new(source) as _
629+
Arc::new(
630+
VortexSource::new(table_schema, self.session.clone()).with_options(self.opts.clone()),
631+
) as _
628632
}
629633
}
630634

@@ -682,7 +686,7 @@ mod tests {
682686
(c1 VARCHAR NOT NULL, c2 INT NOT NULL) \
683687
STORED AS vortex \
684688
LOCATION 'table/' \
685-
OPTIONS( footer_initial_read_size_bytes '12345', scan_concurrency '3' );",
689+
OPTIONS( footer_initial_read_size_bytes '12345', predicate_pushdown 'false', scan_concurrency '3' );",
686690
)
687691
.await?
688692
.collect()
@@ -699,4 +703,24 @@ mod tests {
699703
let format = VortexFormat::new_with_options(VortexSession::default(), opts);
700704
assert_eq!(format.options().footer_initial_read_size_bytes, 12345);
701705
}
706+
707+
#[test]
708+
fn format_plumbs_source_options() -> anyhow::Result<()> {
709+
let opts = VortexTableOptions {
710+
projection_pushdown: true,
711+
predicate_pushdown: false,
712+
scan_concurrency: Some(3),
713+
..Default::default()
714+
};
715+
let format = VortexFormat::new_with_options(VortexSession::default(), opts.clone());
716+
let table_schema = TableSchema::from_file_schema(Arc::new(Schema::empty()));
717+
718+
let source = format.file_source(table_schema);
719+
let source = source
720+
.downcast_ref::<VortexSource>()
721+
.ok_or_else(|| anyhow::anyhow!("expected VortexSource"))?;
722+
723+
assert_eq!(source.options(), &opts);
724+
Ok(())
725+
}
702726
}

0 commit comments

Comments
 (0)