Skip to content

Commit 8e53743

Browse files
committed
test(rule): cover computed sort projections
1 parent da9fb6d commit 8e53743

2 files changed

Lines changed: 85 additions & 5 deletions

File tree

src/rule.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,10 @@ fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool {
367367
fn build_outer_projection(exprs: &[Expr]) -> Vec<Expr> {
368368
exprs
369369
.iter()
370-
.filter_map(|expr| match expr {
371-
Expr::Alias(a) => Some(col(a.name.as_str())),
372-
Expr::Column(c) => Some(Expr::Column(c.clone())),
373-
_ => None,
370+
.map(|expr| match expr {
371+
Expr::Alias(a) => col(a.name.as_str()),
372+
Expr::Column(c) => Expr::Column(c.clone()),
373+
other => col(other.schema_name().to_string()),
374374
})
375375
.collect()
376376
}

tests/execution.rs

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
use std::sync::Arc;
2020

2121
use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};
22-
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array};
22+
use arrow_array::{
23+
FixedSizeListArray, Float32Array, Int64Array, RecordBatch, StringArray, UInt64Array,
24+
};
2325
use arrow_schema::{DataType, Field, Schema};
2426
use datafusion::execution::session_state::SessionStateBuilder;
2527
use datafusion::prelude::SessionContext;
@@ -152,6 +154,60 @@ async fn collect_ids(ctx: &SessionContext, sql: &str) -> Vec<u64> {
152154
ids
153155
}
154156

157+
/// Collect a named integer column from a query result.
158+
async fn collect_i64_column(ctx: &SessionContext, sql: &str, column_name: &str) -> Vec<i64> {
159+
let df = ctx
160+
.sql(sql)
161+
.await
162+
.unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}"));
163+
let batches = df
164+
.collect()
165+
.await
166+
.unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}"));
167+
168+
let mut values: Vec<i64> = vec![];
169+
for batch in &batches {
170+
let col_idx = batch
171+
.schema()
172+
.index_of(column_name)
173+
.unwrap_or_else(|e| panic!("no '{column_name}' column in result: {e}\nSQL: {sql}"));
174+
let column = batch.column(col_idx);
175+
if let Some(arr) = column.as_any().downcast_ref::<UInt64Array>() {
176+
values.extend(arr.values().iter().map(|v| *v as i64));
177+
} else if let Some(arr) = column.as_any().downcast_ref::<Int64Array>() {
178+
values.extend(arr.values());
179+
} else {
180+
panic!("column '{column_name}' not Int64/UInt64\nSQL: {sql}");
181+
}
182+
}
183+
values
184+
}
185+
186+
/// Collect the first integer column from a query result.
187+
async fn collect_first_i64_column(ctx: &SessionContext, sql: &str) -> Vec<i64> {
188+
let df = ctx
189+
.sql(sql)
190+
.await
191+
.unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}"));
192+
let batches = df
193+
.collect()
194+
.await
195+
.unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}"));
196+
197+
let mut values: Vec<i64> = vec![];
198+
for batch in &batches {
199+
let column = batch.column(0);
200+
if let Some(arr) = column.as_any().downcast_ref::<UInt64Array>() {
201+
values.extend(arr.values().iter().map(|v| *v as i64));
202+
} else if let Some(arr) = column.as_any().downcast_ref::<Int64Array>() {
203+
values.extend(arr.values());
204+
} else {
205+
panic!("first result column not Int64/UInt64\nSQL: {sql}");
206+
}
207+
}
208+
values
209+
}
210+
155211
const Q: &str = "ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float]";
156212

157213
// ═══════════════════════════════════════════════════════════════════════════════
@@ -542,6 +598,30 @@ async fn exec_split_provider_order_by_udf_direct() {
542598
assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}");
543599
}
544600

601+
/// Direct ORDER BY UDF with an aliased computed projection must preserve the
602+
/// computed output through the rewrite.
603+
#[tokio::test]
604+
async fn exec_split_provider_order_by_udf_with_computed_alias() {
605+
let ctx = make_split_provider_ctx("items::vector").await;
606+
let sql = format!(
607+
"SELECT CAST(id + 1 AS BIGINT) AS id_plus FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2"
608+
);
609+
let values = collect_i64_column(&ctx, &sql, "id_plus").await;
610+
assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}");
611+
}
612+
613+
/// Direct ORDER BY UDF with an unaliased computed projection relies on the
614+
/// outer projection rebuilding by schema name rather than by raw expression.
615+
#[tokio::test]
616+
async fn exec_split_provider_order_by_udf_with_computed_expr() {
617+
let ctx = make_split_provider_ctx("items::vector").await;
618+
let sql = format!(
619+
"SELECT CAST(id + 1 AS BIGINT) FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2"
620+
);
621+
let values = collect_first_i64_column(&ctx, &sql).await;
622+
assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}");
623+
}
624+
545625
/// SELECT * with distance UDF — should fall back to UDF brute-force
546626
/// (since vector column is not in lookup provider schema).
547627
#[tokio::test]

0 commit comments

Comments
 (0)