|
19 | 19 | use std::sync::Arc; |
20 | 20 |
|
21 | 21 | use arrow_array::builder::{FixedSizeListBuilder, Float32Builder}; |
22 | | -use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array}; |
| 22 | +use arrow_array::{ |
| 23 | + FixedSizeListArray, Float32Array, Int64Array, RecordBatch, StringArray, UInt64Array, |
| 24 | +}; |
23 | 25 | use arrow_schema::{DataType, Field, Schema}; |
24 | 26 | use datafusion::execution::session_state::SessionStateBuilder; |
25 | 27 | use datafusion::prelude::SessionContext; |
@@ -152,6 +154,60 @@ async fn collect_ids(ctx: &SessionContext, sql: &str) -> Vec<u64> { |
152 | 154 | ids |
153 | 155 | } |
154 | 156 |
|
| 157 | +/// Collect a named integer column from a query result. |
| 158 | +async fn collect_i64_column(ctx: &SessionContext, sql: &str, column_name: &str) -> Vec<i64> { |
| 159 | + let df = ctx |
| 160 | + .sql(sql) |
| 161 | + .await |
| 162 | + .unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}")); |
| 163 | + let batches = df |
| 164 | + .collect() |
| 165 | + .await |
| 166 | + .unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}")); |
| 167 | + |
| 168 | + let mut values: Vec<i64> = vec![]; |
| 169 | + for batch in &batches { |
| 170 | + let col_idx = batch |
| 171 | + .schema() |
| 172 | + .index_of(column_name) |
| 173 | + .unwrap_or_else(|e| panic!("no '{column_name}' column in result: {e}\nSQL: {sql}")); |
| 174 | + let column = batch.column(col_idx); |
| 175 | + if let Some(arr) = column.as_any().downcast_ref::<UInt64Array>() { |
| 176 | + values.extend(arr.values().iter().map(|v| *v as i64)); |
| 177 | + } else if let Some(arr) = column.as_any().downcast_ref::<Int64Array>() { |
| 178 | + values.extend(arr.values()); |
| 179 | + } else { |
| 180 | + panic!("column '{column_name}' not Int64/UInt64\nSQL: {sql}"); |
| 181 | + } |
| 182 | + } |
| 183 | + values |
| 184 | +} |
| 185 | + |
| 186 | +/// Collect the first integer column from a query result. |
| 187 | +async fn collect_first_i64_column(ctx: &SessionContext, sql: &str) -> Vec<i64> { |
| 188 | + let df = ctx |
| 189 | + .sql(sql) |
| 190 | + .await |
| 191 | + .unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}")); |
| 192 | + let batches = df |
| 193 | + .collect() |
| 194 | + .await |
| 195 | + .unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}")); |
| 196 | + |
| 197 | + let mut values: Vec<i64> = vec![]; |
| 198 | + for batch in &batches { |
| 199 | + let column = batch.column(0); |
| 200 | + if let Some(arr) = column.as_any().downcast_ref::<UInt64Array>() { |
| 201 | + values.extend(arr.values().iter().map(|v| *v as i64)); |
| 202 | + } else if let Some(arr) = column.as_any().downcast_ref::<Int64Array>() { |
| 203 | + values.extend(arr.values()); |
| 204 | + } else { |
| 205 | + panic!("first result column not Int64/UInt64\nSQL: {sql}"); |
| 206 | + } |
| 207 | + } |
| 208 | + values |
| 209 | +} |
| 210 | + |
155 | 211 | const Q: &str = "ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float]"; |
156 | 212 |
|
157 | 213 | // ═══════════════════════════════════════════════════════════════════════════════ |
@@ -542,6 +598,30 @@ async fn exec_split_provider_order_by_udf_direct() { |
542 | 598 | assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}"); |
543 | 599 | } |
544 | 600 |
|
| 601 | +/// Direct ORDER BY UDF with an aliased computed projection must preserve the |
| 602 | +/// computed output through the rewrite. |
| 603 | +#[tokio::test] |
| 604 | +async fn exec_split_provider_order_by_udf_with_computed_alias() { |
| 605 | + let ctx = make_split_provider_ctx("items::vector").await; |
| 606 | + let sql = format!( |
| 607 | + "SELECT CAST(id + 1 AS BIGINT) AS id_plus FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2" |
| 608 | + ); |
| 609 | + let values = collect_i64_column(&ctx, &sql, "id_plus").await; |
| 610 | + assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}"); |
| 611 | +} |
| 612 | + |
| 613 | +/// Direct ORDER BY UDF with an unaliased computed projection relies on the |
| 614 | +/// outer projection rebuilding by schema name rather than by raw expression. |
| 615 | +#[tokio::test] |
| 616 | +async fn exec_split_provider_order_by_udf_with_computed_expr() { |
| 617 | + let ctx = make_split_provider_ctx("items::vector").await; |
| 618 | + let sql = format!( |
| 619 | + "SELECT CAST(id + 1 AS BIGINT) FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2" |
| 620 | + ); |
| 621 | + let values = collect_first_i64_column(&ctx, &sql).await; |
| 622 | + assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}"); |
| 623 | +} |
| 624 | + |
545 | 625 | /// SELECT * with distance UDF — should fall back to UDF brute-force |
546 | 626 | /// (since vector column is not in lookup provider schema). |
547 | 627 | #[tokio::test] |
|
0 commit comments