Skip to content

Commit 8c05259

Browse files
committed
feat: add inner_product scalar function
1 parent 89e14f1 commit 8c05259

4 files changed

Lines changed: 409 additions & 0 deletions

File tree

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`ScalarUDFImpl`] definitions for inner_product function.
19+
20+
use crate::utils::make_scalar_function;
21+
use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
22+
use arrow::datatypes::{
23+
DataType,
24+
DataType::{FixedSizeList, LargeList, List, Null},
25+
Field,
26+
};
27+
use datafusion_common::cast::{as_float64_array, as_generic_list_array};
28+
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
29+
use datafusion_common::{
30+
Result, exec_err, internal_err, plan_err, utils::take_function_args,
31+
};
32+
use datafusion_expr::{
33+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34+
Volatility,
35+
};
36+
use datafusion_macros::user_doc;
37+
use std::sync::Arc;
38+
39+
// Declares the `inner_product` expression helper and the `inner_product_udf`
// accessor for `InnerProduct`; the string below is the helper's description.
make_udf_expr_and_func!(
    InnerProduct,
    inner_product,
    array1 array2,
    "returns the inner product (dot product) of two numeric arrays.",
    inner_product_udf
);
46+
47+
#[user_doc(
48+
doc_section(label = "Array Functions"),
49+
description = "Returns the inner product (dot product) of two input arrays of equal length, computed as `sum(array1[i] * array2[i])`. Returns NULL if either array is NULL or contains NULL elements. Returns 0.0 for two empty arrays.",
50+
syntax_example = "inner_product(array1, array2)",
51+
sql_example = r#"```sql
52+
> select inner_product([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
53+
+-------------------------------------------------------+
54+
| inner_product(List([1.0,2.0,3.0]),List([4.0,5.0,6.0])) |
55+
+-------------------------------------------------------+
56+
| 32.0 |
57+
+-------------------------------------------------------+
58+
```"#,
59+
argument(
60+
name = "array1",
61+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
62+
),
63+
argument(
64+
name = "array2",
65+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
66+
)
67+
)]
68+
#[derive(Debug, PartialEq, Eq, Hash)]
69+
pub struct InnerProduct {
70+
signature: Signature,
71+
}
72+
73+
impl Default for InnerProduct {
74+
fn default() -> Self {
75+
Self::new()
76+
}
77+
}
78+
79+
impl InnerProduct {
80+
pub fn new() -> Self {
81+
Self {
82+
signature: Signature::user_defined(Volatility::Immutable),
83+
}
84+
}
85+
}
86+
87+
impl ScalarUDFImpl for InnerProduct {
88+
fn name(&self) -> &str {
89+
"inner_product"
90+
}
91+
92+
fn signature(&self) -> &Signature {
93+
&self.signature
94+
}
95+
96+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
97+
Ok(DataType::Float64)
98+
}
99+
100+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
101+
let [_, _] = take_function_args(self.name(), arg_types)?;
102+
let coercion = Some(&ListCoercion::FixedSizedListToList);
103+
104+
for arg_type in arg_types {
105+
if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
106+
return plan_err!("{} does not support type {arg_type}", self.name());
107+
}
108+
}
109+
110+
// If any input is `LargeList`, both sides must be widened to `LargeList`
111+
// so the runtime dispatch in `inner_product_inner` sees a homogeneous
112+
// pair. Follows the pattern in `ArrayConcat::coerce_types`.
113+
let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
114+
115+
let coerced = arg_types
116+
.iter()
117+
.map(|arg_type| {
118+
if matches!(arg_type, Null) {
119+
let field = Arc::new(Field::new_list_field(DataType::Float64, true));
120+
return if any_large_list {
121+
LargeList(field)
122+
} else {
123+
List(field)
124+
};
125+
}
126+
let coerced = coerced_type_with_base_type_only(
127+
arg_type,
128+
&DataType::Float64,
129+
coercion,
130+
);
131+
match coerced {
132+
List(field) if any_large_list => LargeList(field),
133+
other => other,
134+
}
135+
})
136+
.collect();
137+
138+
Ok(coerced)
139+
}
140+
141+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
142+
make_scalar_function(inner_product_inner)(&args.args)
143+
}
144+
145+
fn documentation(&self) -> Option<&Documentation> {
146+
self.doc()
147+
}
148+
}
149+
150+
fn inner_product_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
151+
let [array1, array2] = take_function_args("inner_product", args)?;
152+
match (array1.data_type(), array2.data_type()) {
153+
(List(_), List(_)) => general_inner_product::<i32>(args),
154+
(LargeList(_), LargeList(_)) => general_inner_product::<i64>(args),
155+
(arg_type1, arg_type2) => internal_err!(
156+
"inner_product received unexpected types after coercion: {arg_type1} and {arg_type2}"
157+
),
158+
}
159+
}
160+
161+
fn general_inner_product<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
162+
let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
163+
let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
164+
165+
let values1 = as_float64_array(list_array1.values())?;
166+
let values2 = as_float64_array(list_array2.values())?;
167+
let offsets1 = list_array1.value_offsets();
168+
let offsets2 = list_array2.value_offsets();
169+
170+
let mut builder = Float64Array::builder(list_array1.len());
171+
for row in 0..list_array1.len() {
172+
if list_array1.is_null(row) || list_array2.is_null(row) {
173+
builder.append_null();
174+
continue;
175+
}
176+
177+
let start1 = offsets1[row].as_usize();
178+
let end1 = offsets1[row + 1].as_usize();
179+
let start2 = offsets2[row].as_usize();
180+
let end2 = offsets2[row + 1].as_usize();
181+
let len1 = end1 - start1;
182+
let len2 = end2 - start2;
183+
184+
if len1 != len2 {
185+
return exec_err!(
186+
"inner_product requires both list inputs to have the same length, got {len1} and {len2}"
187+
);
188+
}
189+
190+
let slice1 = values1.slice(start1, len1);
191+
let slice2 = values2.slice(start2, len2);
192+
if slice1.null_count() != 0 || slice2.null_count() != 0 {
193+
builder.append_null();
194+
continue;
195+
}
196+
197+
let vals1 = slice1.values();
198+
let vals2 = slice2.values();
199+
200+
let mut dot = 0.0;
201+
for i in 0..len1 {
202+
dot += vals1[i] * vals2[i];
203+
}
204+
builder.append_value(dot);
205+
}
206+
207+
Ok(Arc::new(builder.finish()) as ArrayRef)
208+
}

datafusion/functions-nested/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub mod except;
5050
pub mod expr_ext;
5151
pub mod extract;
5252
pub mod flatten;
53+
pub mod inner_product;
5354
pub mod length;
5455
pub mod make_array;
5556
pub mod map;
@@ -100,6 +101,7 @@ pub mod expr_fn {
100101
pub use super::extract::array_pop_front;
101102
pub use super::extract::array_slice;
102103
pub use super::flatten::flatten;
104+
pub use super::inner_product::inner_product;
103105
pub use super::length::array_length;
104106
pub use super::make_array::make_array;
105107
pub use super::map_entries::map_entries;
@@ -156,6 +158,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
156158
empty::array_empty_udf(),
157159
length::array_length_udf(),
158160
cosine_distance::cosine_distance_udf(),
161+
inner_product::inner_product_udf(),
159162
distance::array_distance_udf(),
160163
flatten::flatten_udf(),
161164
min_max::array_max_udf(),

0 commit comments

Comments
 (0)