|
18 | 18 | // |
19 | 19 | // Replacement: |
20 | 20 | // |
21 | | -// Sort(fetch=k) ← kept (sort order) |
22 | | -// Projection([col(a), col(b), col("_distance").alias("dist")]) |
23 | | -// USearchNode ← executes ANN |
| 21 | +// Projection([final output cols]) |
| 22 | +// Sort(fetch=k) |
| 23 | +// Projection([final output cols + optional hidden _distance]) |
| 24 | +// USearchNode |
24 | 25 |
|
25 | 26 | use std::collections::HashMap; |
26 | 27 | use std::sync::Arc; |
@@ -151,25 +152,33 @@ impl USearchRule { |
151 | 152 | node: Arc::new(node) as Arc<dyn UserDefinedLogicalNode>, |
152 | 153 | })); |
153 | 154 |
|
154 | | - // Build Projection over USearchNode matching the original output schema. |
| 155 | + // Build the final user-visible projection over USearchNode output. |
155 | 156 | let dist_alias_str = dist_alias.as_deref().unwrap_or("_distance"); |
156 | | - let new_proj_exprs = if proj_exprs_slice.is_empty() { |
| 157 | + let final_proj_exprs = if proj_exprs_slice.is_empty() { |
157 | 158 | passthrough_projection(&vsn_df_schema, &table_ref) |
158 | 159 | } else { |
159 | 160 | remap_projections(proj_exprs_slice, dist_alias_str, &table_ref) |
160 | 161 | }; |
161 | | - let new_proj = Projection::try_new(new_proj_exprs, node_plan).ok()?; |
162 | | - |
163 | | - // Keep the Sort node so DataFusion handles ordering by _distance / dist. |
164 | | - // USearch returns results in arbitrary (internal) order when the underlying |
165 | | - // data is fetched from the TableProvider. |
166 | | - Some(LogicalPlan::Sort( |
167 | | - datafusion::logical_expr::logical_plan::Sort { |
168 | | - expr: sort.expr.clone(), |
169 | | - input: Arc::new(LogicalPlan::Projection(new_proj)), |
170 | | - fetch: sort.fetch, |
171 | | - }, |
172 | | - )) |
| 162 | + let remapped_sort_exprs = remap_sort_exprs(&sort.expr, dist_alias.as_deref()); |
| 163 | + let needs_hidden_distance = remapped_sort_exprs.iter().any(|e| { |
| 164 | + matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance") |
| 165 | + }) && !projection_exposes_name(&final_proj_exprs, "_distance"); |
| 166 | + |
| 167 | + let mut sort_input_exprs = final_proj_exprs.clone(); |
| 168 | + if needs_hidden_distance { |
| 169 | + sort_input_exprs.push(col("_distance")); |
| 170 | + } |
| 171 | + |
| 172 | + let sort_input = Projection::try_new(sort_input_exprs, node_plan).ok()?; |
| 173 | + let sorted = LogicalPlan::Sort(datafusion::logical_expr::logical_plan::Sort { |
| 174 | + expr: remapped_sort_exprs, |
| 175 | + input: Arc::new(LogicalPlan::Projection(sort_input)), |
| 176 | + fetch: sort.fetch, |
| 177 | + }); |
| 178 | + |
| 179 | + let outer_proj_exprs = build_outer_projection(&final_proj_exprs); |
| 180 | + let outer_proj = Projection::try_new(outer_proj_exprs, Arc::new(sorted)).ok()?; |
| 181 | + Some(LogicalPlan::Projection(outer_proj)) |
173 | 182 | } |
174 | 183 | } |
175 | 184 |
|
@@ -283,7 +292,11 @@ fn dist_type_matches_metric(dist_type: &DistanceType, metric: MetricKind) -> boo |
283 | 292 | } |
284 | 293 |
|
285 | 294 | fn is_distance_expr(expr: &Expr) -> bool { |
286 | | - matches!(expr, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name())) |
| 295 | + let inner = match expr { |
| 296 | + Expr::Alias(a) => a.expr.as_ref(), |
| 297 | + other => other, |
| 298 | + }; |
| 299 | + matches!(inner, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name())) |
287 | 300 | } |
288 | 301 |
|
289 | 302 | fn try_extract_distance(expr: &Expr) -> Option<(String, String, Vec<f64>)> { |
@@ -322,6 +335,45 @@ fn remap_projections( |
322 | 335 | .collect() |
323 | 336 | } |
324 | 337 |
|
| 338 | +fn remap_sort_exprs( |
| 339 | + sort_exprs: &[datafusion::logical_expr::SortExpr], |
| 340 | + dist_alias_name: Option<&str>, |
| 341 | +) -> Vec<datafusion::logical_expr::SortExpr> { |
| 342 | + sort_exprs |
| 343 | + .iter() |
| 344 | + .map(|sort_expr| { |
| 345 | + let remapped_expr = match &sort_expr.expr { |
| 346 | + Expr::Column(c) if Some(c.name.as_str()) == dist_alias_name => col(c.name.as_str()), |
| 347 | + expr if is_distance_expr(expr) => col("_distance"), |
| 348 | + other => other.clone(), |
| 349 | + }; |
| 350 | + datafusion::logical_expr::SortExpr { |
| 351 | + expr: remapped_expr, |
| 352 | + asc: sort_expr.asc, |
| 353 | + nulls_first: sort_expr.nulls_first, |
| 354 | + } |
| 355 | + }) |
| 356 | + .collect() |
| 357 | +} |
| 358 | + |
| 359 | +fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool { |
| 360 | + exprs.iter().any(|expr| match expr { |
| 361 | + Expr::Alias(a) => a.name == name, |
| 362 | + Expr::Column(c) => c.name == name, |
| 363 | + _ => false, |
| 364 | + }) |
| 365 | +} |
| 366 | + |
| 367 | +fn build_outer_projection(exprs: &[Expr]) -> Vec<Expr> { |
| 368 | + exprs.iter() |
| 369 | + .filter_map(|expr| match expr { |
| 370 | + Expr::Alias(a) => Some(col(a.name.as_str())), |
| 371 | + Expr::Column(c) => Some(Expr::Column(c.clone())), |
| 372 | + _ => None, |
| 373 | + }) |
| 374 | + .collect() |
| 375 | +} |
| 376 | + |
325 | 377 | /// Build a passthrough Projection for SELECT * queries (no original Projection node). |
326 | 378 | /// Projects only the original table columns (not `_distance`) so the output schema |
327 | 379 | /// matches the original Sort schema. The Sort re-evaluates the distance UDF expression |
|
0 commit comments