Skip to content

Commit f7c5e20

Browse files
authored
feat: support vector search/similarity as UDFs (#80)
* feat: support vector search/similarity as udfs
* feat: optimize the value expression parsing
* refactor: remove the parameterized query
1 parent dd7cb99 commit f7c5e20

10 files changed

Lines changed: 2217 additions & 40 deletions

File tree

python/python/tests/test_to_sql.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -165,40 +165,6 @@ def test_collaborative_network_query(knowledge_graph_env):
165165
assert "ORDER BY" in sql_upper
166166

167167

168-
def test_parameterized_complex_query(knowledge_graph_env):
169-
"""Test complex query with multiple parameters.
170-
171-
Find authors from a specific country with papers above a citation threshold,
172-
published in recent years.
173-
"""
174-
config, datasets = knowledge_graph_env
175-
query = (
176-
CypherQuery(
177-
"""
178-
MATCH (a:Author)-[:AUTHORED]->(p:Paper)
179-
WHERE a.country = $country
180-
AND p.citations > $min_citations
181-
AND p.year >= $min_year
182-
RETURN a.name, a.h_index, p.title, p.citations
183-
ORDER BY p.citations DESC, a.h_index DESC
184-
"""
185-
)
186-
.with_config(config)
187-
.with_parameter("country", "USA")
188-
.with_parameter("min_citations", 300)
189-
.with_parameter("min_year", 2020)
190-
)
191-
192-
sql = query.to_sql(datasets)
193-
194-
assert isinstance(sql, str)
195-
sql_upper = sql.upper()
196-
assert "SELECT" in sql_upper
197-
assert "JOIN" in sql_upper
198-
assert "WHERE" in sql_upper
199-
assert "ORDER BY" in sql_upper
200-
201-
202168
def test_to_sql_without_config_raises_error(knowledge_graph_env):
203169
"""Test that to_sql fails gracefully without config."""
204170
_, datasets = knowledge_graph_env

rust/lance-graph/src/ast.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,18 @@ pub enum BooleanExpression {
249249
IsNotNull(ValueExpression),
250250
}
251251

252+
/// Distance metric for vector similarity
253+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
254+
pub enum DistanceMetric {
255+
/// Euclidean distance (L2)
256+
L2,
257+
/// Cosine distance (1 - cosine similarity)
258+
#[default]
259+
Cosine,
260+
/// Dot product
261+
Dot,
262+
}
263+
252264
/// Comparison operators
253265
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
254266
pub enum ComparisonOperator {
@@ -280,6 +292,22 @@ pub enum ValueExpression {
280292
operator: ArithmeticOperator,
281293
right: Box<ValueExpression>,
282294
},
295+
/// Vector distance function: vector_distance(left, right, metric)
296+
/// Returns the distance as a float (lower = more similar for L2/Cosine)
297+
VectorDistance {
298+
left: Box<ValueExpression>,
299+
right: Box<ValueExpression>,
300+
metric: DistanceMetric,
301+
},
302+
/// Vector similarity function: vector_similarity(left, right, metric)
303+
/// Returns the similarity score as a float (higher = more similar)
304+
VectorSimilarity {
305+
left: Box<ValueExpression>,
306+
right: Box<ValueExpression>,
307+
metric: DistanceMetric,
308+
},
309+
/// Parameter reference for query parameters (e.g., $query_vector)
310+
Parameter(String),
283311
}
284312

285313
/// Arithmetic operators

rust/lance-graph/src/datafusion_planner/expression.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//! Converts AST expressions to DataFusion expressions
77
88
use crate::ast::{BooleanExpression, PropertyValue, ValueExpression};
9+
use crate::datafusion_planner::udf;
910
use datafusion::logical_expr::{col, lit, BinaryExpr, Expr, Operator};
1011
use datafusion_functions_aggregate::average::avg;
1112
use datafusion_functions_aggregate::count::count;
@@ -212,6 +213,48 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr {
212213
right: Box::new(r),
213214
})
214215
}
216+
VE::VectorDistance {
217+
left,
218+
right,
219+
metric,
220+
} => {
221+
// Create UDF for vector distance computation
222+
let udf = udf::create_vector_distance_udf(metric);
223+
let left_expr = to_df_value_expr(left);
224+
let right_expr = to_df_value_expr(right);
225+
Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
226+
udf,
227+
vec![left_expr, right_expr],
228+
))
229+
}
230+
VE::VectorSimilarity {
231+
left,
232+
right,
233+
metric,
234+
} => {
235+
// Create UDF for vector similarity computation
236+
let udf = udf::create_vector_similarity_udf(metric);
237+
let left_expr = to_df_value_expr(left);
238+
let right_expr = to_df_value_expr(right);
239+
Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
240+
udf,
241+
vec![left_expr, right_expr],
242+
))
243+
}
244+
VE::Parameter(name) => {
245+
// TODO: Implement proper parameter resolution
246+
// Parameters ($param) should be resolved to literal values from the query's
247+
// parameter map (CypherQuery::parameters()) before or during planning.
248+
//
249+
// Current limitation: This creates a column reference as a placeholder,
250+
// which will fail at execution if the column doesn't exist.
251+
//
252+
// Proper fix requires one of:
253+
// 1. Resolve parameters during semantic analysis (substitute before planning)
254+
// 2. Pass parameter map to to_df_value_expr and resolve here
255+
// 3. Use DataFusion's parameter binding mechanism
256+
col(format!("${}", name))
257+
}
215258
}
216259
}
217260

@@ -229,6 +272,12 @@ pub(crate) fn contains_aggregate(expr: &ValueExpression) -> bool {
229272
is_aggregate || args.iter().any(contains_aggregate)
230273
}
231274
VE::Arithmetic { left, right, .. } => contains_aggregate(left) || contains_aggregate(right),
275+
VE::VectorDistance { left, right, .. } => {
276+
contains_aggregate(left) || contains_aggregate(right)
277+
}
278+
VE::VectorSimilarity { left, right, .. } => {
279+
contains_aggregate(left) || contains_aggregate(right)
280+
}
232281
_ => false,
233282
}
234283
}

rust/lance-graph/src/datafusion_planner/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ mod config_helpers;
2020
mod expression;
2121
mod join_ops;
2222
mod scan_ops;
23+
mod udf;
24+
mod vector_ops;
2325

2426
#[cfg(test)]
2527
mod test_fixtures;

0 commit comments

Comments
 (0)