@@ -19,13 +19,14 @@ use arrow::array::{
1919 Array , ArrayRef , AsArray , BooleanArray , GenericListArray , Int64Array , OffsetSizeTrait ,
2020} ;
2121use arrow:: datatypes:: {
22- DataType , Date32Type , Decimal128Type , Float32Type , Float64Type , Int16Type , Int32Type ,
23- Int64Type , Int8Type , TimestampMicrosecondType ,
22+ ArrowPrimitiveType , DataType , Date32Type , Decimal128Type , Float32Type , Float64Type , Int16Type ,
23+ Int32Type , Int64Type , Int8Type , TimestampMicrosecondType ,
2424} ;
2525use datafusion:: common:: { exec_err, DataFusionError , Result as DataFusionResult , ScalarValue } ;
2626use datafusion:: logical_expr:: {
2727 ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility ,
2828} ;
29+ use num:: Float ;
2930use std:: any:: Any ;
3031use std:: sync:: Arc ;
3132
@@ -36,7 +37,6 @@ fn spark_array_position(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFus
3637 return exec_err ! ( "array_position function takes exactly two arguments" ) ;
3738 }
3839
39- // Convert all arguments to arrays for consistent processing
4040 let len = args
4141 . iter ( )
4242 . fold ( Option :: < usize > :: None , |acc, arg| match arg {
@@ -68,144 +68,227 @@ fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError>
6868 }
6969}
7070
71- /// Find the 1-based position of `search_val` in a typed primitive array.
72- /// Returns 0 if not found.
73- macro_rules! find_position_primitive {
74- ( $list_items: expr, $element: expr, $row_index: expr, $arrow_type: ty) => { {
75- let items = $list_items. as_primitive:: <$arrow_type>( ) ;
76- let search = $element. as_primitive:: <$arrow_type>( ) ;
77- let search_val = search. value( $row_index) ;
71+ /// Searches for an element in a list array using the flat values buffer and offsets directly,
72+ /// avoiding per-row subarray allocation. Dispatches to typed fast paths by element data type.
73+ fn generic_array_position < O : OffsetSizeTrait > (
74+ array : & ArrayRef ,
75+ element : & ArrayRef ,
76+ ) -> Result < ArrayRef , DataFusionError > {
77+ let list_array = array
78+ . as_any ( )
79+ . downcast_ref :: < GenericListArray < O > > ( )
80+ . unwrap ( ) ;
81+
82+ let values = list_array. values ( ) ;
83+ let offsets = list_array. offsets ( ) ;
84+ let elem_type = values. data_type ( ) . clone ( ) ;
85+
86+ match & elem_type {
87+ DataType :: Boolean => {
88+ position_boolean :: < O > ( list_array, offsets, values, element)
89+ }
90+ DataType :: Int8 => position_primitive :: < O , Int8Type > ( list_array, offsets, values, element) ,
91+ DataType :: Int16 => position_primitive :: < O , Int16Type > ( list_array, offsets, values, element) ,
92+ DataType :: Int32 => position_primitive :: < O , Int32Type > ( list_array, offsets, values, element) ,
93+ DataType :: Int64 => position_primitive :: < O , Int64Type > ( list_array, offsets, values, element) ,
94+ DataType :: Float32 => {
95+ position_float :: < O , Float32Type > ( list_array, offsets, values, element)
96+ }
97+ DataType :: Float64 => {
98+ position_float :: < O , Float64Type > ( list_array, offsets, values, element)
99+ }
100+ DataType :: Decimal128 ( _, _) => {
101+ position_primitive :: < O , Decimal128Type > ( list_array, offsets, values, element)
102+ }
103+ DataType :: Date32 => {
104+ position_primitive :: < O , Date32Type > ( list_array, offsets, values, element)
105+ }
106+ DataType :: Timestamp ( arrow:: datatypes:: TimeUnit :: Microsecond , _) => {
107+ position_primitive :: < O , TimestampMicrosecondType > (
108+ list_array, offsets, values, element,
109+ )
110+ }
111+ DataType :: Utf8 => position_string :: < O , i32 > ( list_array, offsets, values, element) ,
112+ DataType :: LargeUtf8 => position_string :: < O , i64 > ( list_array, offsets, values, element) ,
113+ // Fallback to ScalarValue for complex types (nested arrays, etc.)
114+ _ => position_fallback :: < O > ( list_array, offsets, values, element) ,
115+ }
116+ }
117+
118+ /// Fast path for primitive types: downcast once, iterate using offsets into the flat buffer.
119+ fn position_primitive < O : OffsetSizeTrait , T : ArrowPrimitiveType > (
120+ list_array : & GenericListArray < O > ,
121+ offsets : & arrow:: buffer:: OffsetBuffer < O > ,
122+ values : & ArrayRef ,
123+ element : & ArrayRef ,
124+ ) -> Result < ArrayRef , DataFusionError >
125+ where
126+ T :: Native : PartialEq ,
127+ {
128+ let values_typed = values. as_primitive :: < T > ( ) ;
129+ let element_typed = element. as_primitive :: < T > ( ) ;
130+ let num_rows = list_array. len ( ) ;
131+ let mut result = Vec :: with_capacity ( num_rows) ;
132+
133+ for ( row_index, w) in offsets. windows ( 2 ) . enumerate ( ) {
134+ if list_array. is_null ( row_index) || element. is_null ( row_index) {
135+ result. push ( None ) ;
136+ continue ;
137+ }
138+ let start = w[ 0 ] . as_usize ( ) ;
139+ let end = w[ 1 ] . as_usize ( ) ;
140+ let search_val = element_typed. value ( row_index) ;
78141 let mut pos: i64 = 0 ;
79- for i in 0 ..items . len ( ) {
80- if !items . is_null( i) && items . value( i) == search_val {
81- pos = ( i + 1 ) as i64 ;
142+ for i in start..end {
143+ if !values_typed . is_null ( i) && values_typed . value ( i) == search_val {
144+ pos = ( i - start + 1 ) as i64 ;
82145 break ;
83146 }
84147 }
85- pos
86- } } ;
148+ result. push ( Some ( pos) ) ;
149+ }
150+
151+ Ok ( Arc :: new ( Int64Array :: from ( result) ) )
87152}
88153
89- /// Float-aware comparison that treats NaN == NaN (matching Spark's ordering.equiv() semantics).
90- macro_rules! find_position_float {
91- ( $list_items: expr, $element: expr, $row_index: expr, $arrow_type: ty) => { {
92- let items = $list_items. as_primitive:: <$arrow_type>( ) ;
93- let search = $element. as_primitive:: <$arrow_type>( ) ;
94- let search_val = search. value( $row_index) ;
154+ /// Float path: same as primitive but treats NaN == NaN (Spark's ordering.equiv() semantics).
155+ fn position_float < O : OffsetSizeTrait , T : ArrowPrimitiveType > (
156+ list_array : & GenericListArray < O > ,
157+ offsets : & arrow:: buffer:: OffsetBuffer < O > ,
158+ values : & ArrayRef ,
159+ element : & ArrayRef ,
160+ ) -> Result < ArrayRef , DataFusionError >
161+ where
162+ T :: Native : PartialEq + num:: Float ,
163+ {
164+ let values_typed = values. as_primitive :: < T > ( ) ;
165+ let element_typed = element. as_primitive :: < T > ( ) ;
166+ let num_rows = list_array. len ( ) ;
167+ let mut result = Vec :: with_capacity ( num_rows) ;
168+
169+ for ( row_index, w) in offsets. windows ( 2 ) . enumerate ( ) {
170+ if list_array. is_null ( row_index) || element. is_null ( row_index) {
171+ result. push ( None ) ;
172+ continue ;
173+ }
174+ let start = w[ 0 ] . as_usize ( ) ;
175+ let end = w[ 1 ] . as_usize ( ) ;
176+ let search_val = element_typed. value ( row_index) ;
95177 let search_is_nan = search_val. is_nan ( ) ;
96178 let mut pos: i64 = 0 ;
97- for i in 0 ..items . len ( ) {
98- if !items . is_null( i) {
99- let item_val = items . value( i) ;
100- if ( search_is_nan && item_val . is_nan( ) ) || item_val == search_val {
101- pos = ( i + 1 ) as i64 ;
179+ for i in start..end {
180+ if !values_typed . is_null ( i) {
181+ let v = values_typed . value ( i) ;
182+ if ( search_is_nan && v . is_nan ( ) ) || v == search_val {
183+ pos = ( i - start + 1 ) as i64 ;
102184 break ;
103185 }
104186 }
105187 }
106- pos
107- } } ;
188+ result. push ( Some ( pos) ) ;
189+ }
190+
191+ Ok ( Arc :: new ( Int64Array :: from ( result) ) )
108192}
109193
110- fn find_position_in_row (
111- list_items : & ArrayRef ,
194+ /// Boolean path.
195+ fn position_boolean < O : OffsetSizeTrait > (
196+ list_array : & GenericListArray < O > ,
197+ offsets : & arrow:: buffer:: OffsetBuffer < O > ,
198+ values : & ArrayRef ,
112199 element : & ArrayRef ,
113- row_index : usize ,
114- ) -> Result < i64 , DataFusionError > {
115- let pos = match list_items. data_type ( ) {
116- DataType :: Boolean => {
117- let items = list_items. as_any ( ) . downcast_ref :: < BooleanArray > ( ) . unwrap ( ) ;
118- let search = element. as_any ( ) . downcast_ref :: < BooleanArray > ( ) . unwrap ( ) ;
119- let search_val = search. value ( row_index) ;
120- let mut pos: i64 = 0 ;
121- for i in 0 ..items. len ( ) {
122- if !items. is_null ( i) && items. value ( i) == search_val {
123- pos = ( i + 1 ) as i64 ;
124- break ;
125- }
126- }
127- pos
128- }
129- DataType :: Int8 => find_position_primitive ! ( list_items, element, row_index, Int8Type ) ,
130- DataType :: Int16 => find_position_primitive ! ( list_items, element, row_index, Int16Type ) ,
131- DataType :: Int32 => find_position_primitive ! ( list_items, element, row_index, Int32Type ) ,
132- DataType :: Int64 => find_position_primitive ! ( list_items, element, row_index, Int64Type ) ,
133- DataType :: Float32 => find_position_float ! ( list_items, element, row_index, Float32Type ) ,
134- DataType :: Float64 => find_position_float ! ( list_items, element, row_index, Float64Type ) ,
135- DataType :: Decimal128 ( _, _) => {
136- find_position_primitive ! ( list_items, element, row_index, Decimal128Type )
137- }
138- DataType :: Date32 => {
139- find_position_primitive ! ( list_items, element, row_index, Date32Type )
200+ ) -> Result < ArrayRef , DataFusionError > {
201+ let values_typed = values. as_any ( ) . downcast_ref :: < BooleanArray > ( ) . unwrap ( ) ;
202+ let element_typed = element. as_any ( ) . downcast_ref :: < BooleanArray > ( ) . unwrap ( ) ;
203+ let num_rows = list_array. len ( ) ;
204+ let mut result = Vec :: with_capacity ( num_rows) ;
205+
206+ for ( row_index, w) in offsets. windows ( 2 ) . enumerate ( ) {
207+ if list_array. is_null ( row_index) || element. is_null ( row_index) {
208+ result. push ( None ) ;
209+ continue ;
140210 }
141- DataType :: Timestamp ( arrow:: datatypes:: TimeUnit :: Microsecond , _) => {
142- find_position_primitive ! ( list_items, element, row_index, TimestampMicrosecondType )
143- }
144- DataType :: Utf8 => {
145- let items = list_items. as_string :: < i32 > ( ) ;
146- let search = element. as_string :: < i32 > ( ) ;
147- let search_val = search. value ( row_index) ;
148- let mut pos: i64 = 0 ;
149- for i in 0 ..items. len ( ) {
150- if !items. is_null ( i) && items. value ( i) == search_val {
151- pos = ( i + 1 ) as i64 ;
152- break ;
153- }
154- }
155- pos
156- }
157- DataType :: LargeUtf8 => {
158- let items = list_items. as_string :: < i64 > ( ) ;
159- let search = element. as_string :: < i64 > ( ) ;
160- let search_val = search. value ( row_index) ;
161- let mut pos: i64 = 0 ;
162- for i in 0 ..items. len ( ) {
163- if !items. is_null ( i) && items. value ( i) == search_val {
164- pos = ( i + 1 ) as i64 ;
165- break ;
166- }
211+ let start = w[ 0 ] . as_usize ( ) ;
212+ let end = w[ 1 ] . as_usize ( ) ;
213+ let search_val = element_typed. value ( row_index) ;
214+ let mut pos: i64 = 0 ;
215+ for i in start..end {
216+ if !values_typed. is_null ( i) && values_typed. value ( i) == search_val {
217+ pos = ( i - start + 1 ) as i64 ;
218+ break ;
167219 }
168- pos
169220 }
170- // Fallback to ScalarValue for complex types (nested arrays, etc.)
171- _ => {
172- let element_scalar = ScalarValue :: try_from_array ( element, row_index) ?;
173- let mut pos: i64 = 0 ;
174- for i in 0 ..list_items. len ( ) {
175- let item_scalar = ScalarValue :: try_from_array ( list_items, i) ?;
176- if !item_scalar. is_null ( ) && element_scalar == item_scalar {
177- pos = ( i + 1 ) as i64 ;
178- break ;
179- }
221+ result. push ( Some ( pos) ) ;
222+ }
223+
224+ Ok ( Arc :: new ( Int64Array :: from ( result) ) )
225+ }
226+
227+ /// String path: downcast once, iterate using offsets into the flat string buffer.
228+ fn position_string < O : OffsetSizeTrait , S : OffsetSizeTrait > (
229+ list_array : & GenericListArray < O > ,
230+ offsets : & arrow:: buffer:: OffsetBuffer < O > ,
231+ values : & ArrayRef ,
232+ element : & ArrayRef ,
233+ ) -> Result < ArrayRef , DataFusionError > {
234+ let values_typed = values. as_string :: < S > ( ) ;
235+ let element_typed = element. as_string :: < S > ( ) ;
236+ let num_rows = list_array. len ( ) ;
237+ let mut result = Vec :: with_capacity ( num_rows) ;
238+
239+ for ( row_index, w) in offsets. windows ( 2 ) . enumerate ( ) {
240+ if list_array. is_null ( row_index) || element. is_null ( row_index) {
241+ result. push ( None ) ;
242+ continue ;
243+ }
244+ let start = w[ 0 ] . as_usize ( ) ;
245+ let end = w[ 1 ] . as_usize ( ) ;
246+ let search_val = element_typed. value ( row_index) ;
247+ let mut pos: i64 = 0 ;
248+ for i in start..end {
249+ if !values_typed. is_null ( i) && values_typed. value ( i) == search_val {
250+ pos = ( i - start + 1 ) as i64 ;
251+ break ;
180252 }
181- pos
182253 }
183- } ;
184- Ok ( pos)
254+ result. push ( Some ( pos) ) ;
255+ }
256+
257+ Ok ( Arc :: new ( Int64Array :: from ( result) ) )
185258}
186259
187- fn generic_array_position < O : OffsetSizeTrait > (
188- array : & ArrayRef ,
260+ /// Fallback for complex types (nested arrays, structs, etc.) using ScalarValue comparison.
261+ fn position_fallback < O : OffsetSizeTrait > (
262+ list_array : & GenericListArray < O > ,
263+ offsets : & arrow:: buffer:: OffsetBuffer < O > ,
264+ values : & ArrayRef ,
189265 element : & ArrayRef ,
190266) -> Result < ArrayRef , DataFusionError > {
191- let list_array = array
192- . as_any ( )
193- . downcast_ref :: < GenericListArray < O > > ( )
194- . unwrap ( ) ;
195-
196- let mut data = Vec :: with_capacity ( list_array. len ( ) ) ;
267+ let num_rows = list_array. len ( ) ;
268+ let mut result = Vec :: with_capacity ( num_rows) ;
197269
198- for row_index in 0 ..list_array . len ( ) {
270+ for ( row_index, w ) in offsets . windows ( 2 ) . enumerate ( ) {
199271 if list_array. is_null ( row_index) || element. is_null ( row_index) {
200- data. push ( None ) ;
201- } else {
202- let list_array_row = list_array. value ( row_index) ;
203- let position = find_position_in_row ( & list_array_row, element, row_index) ?;
204- data. push ( Some ( position) ) ;
272+ result. push ( None ) ;
273+ continue ;
274+ }
275+ let start = w[ 0 ] . as_usize ( ) ;
276+ let end = w[ 1 ] . as_usize ( ) ;
277+ let search_scalar = ScalarValue :: try_from_array ( element, row_index) ?;
278+ let mut pos: i64 = 0 ;
279+ for i in start..end {
280+ if !values. is_null ( i) {
281+ let item_scalar = ScalarValue :: try_from_array ( values, i) ?;
282+ if search_scalar == item_scalar {
283+ pos = ( i - start + 1 ) as i64 ;
284+ break ;
285+ }
286+ }
205287 }
288+ result. push ( Some ( pos) ) ;
206289 }
207290
208- Ok ( Arc :: new ( Int64Array :: from ( data ) ) )
291+ Ok ( Arc :: new ( Int64Array :: from ( result ) ) )
209292}
210293
211294#[ derive( Debug , Hash , Eq , PartialEq ) ]
0 commit comments