11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- //! Planar distance between the paired points of two columns .
4+ //! Planar distance from a point column to a constant query point .
55
66use vortex_array:: ArrayRef ;
77use vortex_array:: ExecutionCtx ;
@@ -19,8 +19,10 @@ use vortex_array::scalar_fn::ScalarFnId;
1919use vortex_array:: scalar_fn:: ScalarFnVTable ;
2020use vortex_array:: scalar_fn:: TypedScalarFnInstance ;
2121use vortex_error:: VortexResult ;
22+ use vortex_error:: vortex_bail;
2223use vortex_session:: VortexSession ;
2324
25+ use crate :: extension:: coordinate_from_scalar;
2426use crate :: extension:: xy_columns;
2527
2628/// Planar Euclidean distance between `(ax, ay)` and `(bx, by)`.
@@ -30,14 +32,17 @@ fn euclidean_distance(ax: f64, ay: f64, bx: f64, by: f64) -> f64 {
3032 ( dx * dx + dy * dy) . sqrt ( )
3133}
3234
/// Scalar function yielding the planar distance from every point of a point column to one
/// constant query point. Operand 0 is the point column and operand 1 is the constant query
/// point; column-to-column distance is not supported.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct GeoDistance;
3740
3841impl GeoDistance {
39- /// A lazy `ScalarFnArray` computing the distance between each row of `a` and `b`.
40- pub fn try_new_array ( a : ArrayRef , b : ArrayRef , len : usize ) -> VortexResult < ScalarFnArray > {
42+ /// A lazy `ScalarFnArray` computing the distance from each row of the point column `a` to the
43+ /// constant query point `b`. The output length is taken from `a`.
44+ pub fn try_new_array ( a : ArrayRef , b : ArrayRef ) -> VortexResult < ScalarFnArray > {
45+ let len = a. len ( ) ;
4146 ScalarFnArray :: try_new (
4247 TypedScalarFnInstance :: new ( GeoDistance , EmptyOptions ) . erased ( ) ,
4348 vec ! [ a, b] ,
@@ -83,21 +88,26 @@ impl ScalarFnVTable for GeoDistance {
8388 args : & dyn ExecutionArgs ,
8489 ctx : & mut ExecutionCtx ,
8590 ) -> VortexResult < ArrayRef > {
86- // Bulk path: one tight loop over the flat x/y slices, straight into the output buffer.
87- let ( ax, ay) = xy_columns ( & args. get ( 0 ) ?, ctx) ?;
88- let ( bx, by) = xy_columns ( & args. get ( 1 ) ?, ctx) ?;
89- let a = ax. as_slice :: < f64 > ( ) . iter ( ) . zip ( ay. as_slice :: < f64 > ( ) ) ;
90- let b = bx. as_slice :: < f64 > ( ) . iter ( ) . zip ( by. as_slice :: < f64 > ( ) ) ;
91- let distances = a
92- . zip ( b)
93- . map ( |( ( & ax, & ay) , ( & bx, & by) ) | euclidean_distance ( ax, ay, bx, by) ) ;
91+ // `a` is the point column; `b` is the constant query point, decoded once and broadcast.
92+ let points = args. get ( 0 ) ?;
93+ let Some ( query) = args. get ( 1 ) ?. as_constant ( ) else {
94+ vortex_bail ! ( "GeoDistance requires a constant query point as its second operand" ) ;
95+ } ;
96+ let query = coordinate_from_scalar ( & query) ?;
97+ let ( xs, ys) = xy_columns ( & points, ctx) ?;
98+ let distances = xs
99+ . as_slice :: < f64 > ( )
100+ . iter ( )
101+ . zip ( ys. as_slice :: < f64 > ( ) )
102+ . map ( |( & x, & y) | euclidean_distance ( x, y, query. x , query. y ) ) ;
94103 Ok ( PrimitiveArray :: from_iter ( distances) . into_array ( ) )
95104 }
96105}
97106
98107#[ cfg( test) ]
99108mod tests {
100109 use vortex_array:: ArrayRef ;
110+ use vortex_array:: Canonical ;
101111 use vortex_array:: ExecutionCtx ;
102112 use vortex_array:: IntoArray ;
103113 use vortex_array:: VortexSessionExecute ;
@@ -140,6 +150,15 @@ mod tests {
140150 Ok ( ConstantArray :: new ( single, len) . into_array ( ) )
141151 }
142152
153+ /// Execute a `GeoDistance` array and read back its per-row `f64` distances.
154+ fn distances ( distance : ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < Vec < f64 > > {
155+ Ok ( distance
156+ . execute :: < Canonical > ( ctx) ?
157+ . into_primitive ( )
158+ . as_slice :: < f64 > ( )
159+ . to_vec ( ) )
160+ }
161+
143162 /// The kernel computes planar Euclidean distance (the 3–4–5 triangle).
144163 #[ test]
145164 fn euclidean_distance_is_planar ( ) {
@@ -156,12 +175,38 @@ mod tests {
156175
157176 let a = point_column ( vec ! [ 0.0 , 3.0 , 0.0 , 3.0 ] , vec ! [ 0.0 , 0.0 , 4.0 , 4.0 ] ) ?;
158177 let b = point_constant ( 0.0 , 0.0 , 4 , & mut ctx) ?;
159- let distance = GeoDistance :: try_new_array ( a, b, 4 ) ?. into_array ( ) ;
178+ let distance = GeoDistance :: try_new_array ( a, b) ?. into_array ( ) ;
179+
180+ assert_eq ! ( distances( distance, & mut ctx) ?, vec![ 0.0 , 3.0 , 4.0 , 5.0 ] ) ;
181+ Ok ( ( ) )
182+ }
183+
/// Both operands are plain point columns, so the second operand is not a constant query point.
/// Column-to-column distance is unsupported: executing the array must return an error rather
/// than a column of distances.
#[test]
fn distance_requires_constant_query_point() -> VortexResult<()> {
    let session = VortexSession::empty().with::<ArraySession>();
    let mut ctx = session.create_execution_ctx();

    let lhs = point_column(vec![0.0, 1.0], vec![0.0, 1.0])?;
    let rhs = point_column(vec![3.0, 1.0], vec![4.0, 1.0])?;
    let lazy = GeoDistance::try_new_array(lhs, rhs)?.into_array();

    let outcome = lazy.execute::<Canonical>(&mut ctx);
    assert!(outcome.is_err());
    Ok(())
}
198+
/// Both operands are constant points: each of the three rows gets the same 3–4–5 distance.
#[test]
fn distance_between_two_constants() -> VortexResult<()> {
    let session = VortexSession::empty().with::<ArraySession>();
    let mut ctx = session.create_execution_ctx();

    let origin = point_constant(0.0, 0.0, 3, &mut ctx)?;
    let query = point_constant(3.0, 4.0, 3, &mut ctx)?;
    let lazy = GeoDistance::try_new_array(origin, query)?.into_array();

    let got = distances(lazy, &mut ctx)?;
    assert_eq!(got, vec![5.0, 5.0, 5.0]);
    Ok(())
}
167212}
0 commit comments