Skip to content

Commit e3e579d

Browse files
authored
feat: support vector literals (#81)
* feat: support vector literals
1 parent f7c5e20 commit e3e579d

6 files changed

Lines changed: 292 additions & 4 deletions

File tree

rust/lance-graph/src/ast.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,9 @@ pub enum ValueExpression {
308308
},
309309
/// Parameter reference for query parameters (e.g., $query_vector)
310310
Parameter(String),
311+
/// Vector literal: [0.1, 0.2, 0.3]
312+
/// Represents an inline vector for similarity search
313+
VectorLiteral(Vec<f32>),
311314
}
312315

313316
/// Arithmetic operators

rust/lance-graph/src/datafusion_planner/expression.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,25 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr {
241241
vec![left_expr, right_expr],
242242
))
243243
}
244+
VE::VectorLiteral(values) => {
245+
// Convert Vec<f32> to DataFusion scalar FixedSizeList
246+
use arrow::array::{FixedSizeListArray, Float32Array};
247+
use arrow::datatypes::{DataType, Field};
248+
use datafusion::scalar::ScalarValue;
249+
use std::sync::Arc;
250+
251+
let dim = values.len() as i32;
252+
let field = Arc::new(Field::new("item", DataType::Float32, true));
253+
let float_array = Arc::new(Float32Array::from(values.clone()));
254+
255+
let list_array = FixedSizeListArray::try_new(field.clone(), dim, float_array, None)
256+
.expect("Failed to create FixedSizeListArray for vector literal");
257+
258+
let scalar = ScalarValue::try_from_array(&list_array, 0)
259+
.expect("Failed to create scalar from array");
260+
261+
lit(scalar)
262+
}
244263
VE::Parameter(name) => {
245264
// TODO: Implement proper parameter resolution
246265
// Parameters ($param) should be resolved to literal values from the query's

rust/lance-graph/src/parser.rs

Lines changed: 129 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use crate::error::{GraphError, Result};
1111
use nom::{
1212
branch::alt,
1313
bytes::complete::{tag, tag_no_case, take_while1},
14-
character::complete::{char, multispace0, multispace1},
15-
combinator::{map, opt, peek, recognize},
14+
character::complete::{char, digit0, digit1, multispace0, multispace1, one_of},
15+
combinator::{map, map_res, opt, peek, recognize},
1616
multi::{many0, separated_list0, separated_list1},
1717
sequence::{delimited, pair, preceded, tuple},
1818
IResult,
@@ -438,8 +438,9 @@ fn comparison_operator(input: &str) -> IResult<&str, ComparisonOperator> {
438438
// Parse a basic value expression (without vector functions to avoid circular dependency)
439439
fn basic_value_expression(input: &str) -> IResult<&str, ValueExpression> {
440440
alt((
441-
parse_parameter, // Try $parameter
442-
function_call, // Regular function calls
441+
parse_vector_literal, // Try vector literal first [0.1, 0.2]
442+
parse_parameter, // Try $parameter
443+
function_call, // Regular function calls
443444
map(property_value, ValueExpression::Literal), // Try literals BEFORE property references
444445
map(property_reference, ValueExpression::Property),
445446
map(identifier, |id| ValueExpression::Variable(id.to_string())),
@@ -603,6 +604,44 @@ fn value_expression_list(input: &str) -> IResult<&str, Vec<ValueExpression>> {
603604
)(input)
604605
}
605606

607+
// Parse a float32 literal for vectors
608+
fn float32_literal(input: &str) -> IResult<&str, f32> {
609+
map_res(
610+
recognize(tuple((
611+
opt(char('-')),
612+
alt((
613+
// Scientific notation: 1e-3, 2.5e2
614+
recognize(tuple((
615+
digit1,
616+
opt(tuple((char('.'), digit0))),
617+
one_of("eE"),
618+
opt(one_of("+-")),
619+
digit1,
620+
))),
621+
// Regular float: 1.23 or integer: 123
622+
recognize(tuple((digit1, opt(tuple((char('.'), digit0)))))),
623+
)),
624+
))),
625+
|s: &str| s.parse::<f32>(),
626+
)(input)
627+
}
628+
629+
// Parse vector literal: [0.1, 0.2, 0.3]
630+
fn parse_vector_literal(input: &str) -> IResult<&str, ValueExpression> {
631+
let (input, _) = char('[')(input)?;
632+
let (input, _) = multispace0(input)?;
633+
634+
let (input, values) = separated_list1(
635+
tuple((multispace0, char(','), multispace0)),
636+
float32_literal,
637+
)(input)?;
638+
639+
let (input, _) = multispace0(input)?;
640+
let (input, _) = char(']')(input)?;
641+
642+
Ok((input, ValueExpression::VectorLiteral(values)))
643+
}
644+
606645
// Parse a property reference: variable.property
607646
fn property_reference(input: &str) -> IResult<&str, PropertyRef> {
608647
let (input, variable) = identifier(input)?;
@@ -1597,4 +1636,90 @@ mod tests {
15971636
_ => panic!("Expected AND expression"),
15981637
}
15991638
}
1639+
1640+
#[test]
fn test_parse_vector_literal() {
    // A plain three-component literal must parse into a `VectorLiteral`
    // carrying the components in source order.
    let parsed = parse_vector_literal("[0.1, 0.2, 0.3]");
    assert!(parsed.is_ok());

    let components = match parsed.unwrap().1 {
        ValueExpression::VectorLiteral(components) => components,
        _ => panic!("Expected VectorLiteral"),
    };

    assert_eq!(components.len(), 3);
    assert_eq!(components[0], 0.1);
    assert_eq!(components[1], 0.2);
    assert_eq!(components[2], 0.3);
}
1655+
1656+
#[test]
fn test_parse_vector_literal_with_negative_values() {
    // Components with a leading minus sign must keep their sign.
    let parsed = parse_vector_literal("[-0.1, 0.2, -0.3]");
    assert!(parsed.is_ok());

    let components = match parsed.unwrap().1 {
        ValueExpression::VectorLiteral(components) => components,
        _ => panic!("Expected VectorLiteral"),
    };

    assert_eq!(components.len(), 3);
    assert_eq!(components[0], -0.1);
    assert_eq!(components[2], -0.3);
}
1670+
1671+
#[test]
fn test_parse_vector_literal_scientific_notation() {
    // Exponent forms (with and without a fractional mantissa or sign) must be
    // accepted and evaluated to the expected magnitudes.
    let parsed = parse_vector_literal("[1e-3, 2.5e2, -3e-1]");
    assert!(parsed.is_ok());

    let components = match parsed.unwrap().1 {
        ValueExpression::VectorLiteral(components) => components,
        _ => panic!("Expected VectorLiteral"),
    };

    assert_eq!(components.len(), 3);
    // Compare with a tolerance rather than exact equality.
    let expected = [0.001_f32, 250.0, -0.3];
    for (index, want) in expected.iter().enumerate() {
        assert!((components[index] - *want).abs() < 1e-6);
    }
}
1686+
1687+
#[test]
fn test_vector_distance_with_literal() {
    // A vector literal must be usable as the second argument of
    // `vector_distance` inside a WHERE comparison.
    let parsed = parse_cypher_query(
        "MATCH (p:Person) WHERE vector_distance(p.embedding, [0.1, 0.2], l2) < 0.5 RETURN p",
    );
    assert!(parsed.is_ok());

    let query_ast = parsed.unwrap();
    let where_clause = query_ast.where_clause.expect("Expected WHERE clause");

    // Peel the comparison apart first, then the distance call on its left side.
    let (comparison_lhs, operator) = match where_clause.expression {
        BooleanExpression::Comparison { left, operator, .. } => (left, operator),
        _ => panic!("Expected comparison"),
    };

    let (distance_lhs, distance_rhs, metric) = match comparison_lhs {
        ValueExpression::VectorDistance {
            left,
            right,
            metric,
        } => (left, right, metric),
        _ => panic!("Expected VectorDistance"),
    };

    // First argument: the property reference `p.embedding`.
    assert!(matches!(*distance_lhs, ValueExpression::Property(_)));

    // Second argument: the inline vector literal.
    match *distance_rhs {
        ValueExpression::VectorLiteral(components) => {
            assert_eq!(components.len(), 2);
            assert_eq!(components[0], 0.1);
            assert_eq!(components[1], 0.2);
        }
        _ => panic!("Expected VectorLiteral"),
    }

    assert_eq!(metric, DistanceMetric::L2);
    assert_eq!(operator, ComparisonOperator::LessThan);
}
16001725
}

rust/lance-graph/src/semantic.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,18 @@ impl SemanticAnalyzer {
378378
});
379379
}
380380
}
381+
ValueExpression::VectorLiteral(values) => {
382+
// Validate non-empty
383+
if values.is_empty() {
384+
return Err(GraphError::PlanError {
385+
message: "Vector literal cannot be empty".to_string(),
386+
location: snafu::Location::new(file!(), line!(), column!()),
387+
});
388+
}
389+
390+
// Note: Very large vectors (>4096 dimensions) may impact performance
391+
// but we don't enforce a hard limit here
392+
}
381393
ValueExpression::Parameter(_) => {
382394
// Parameters are always valid (resolved at runtime)
383395
}

rust/lance-graph/src/simple_executor/expr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,5 +172,6 @@ pub(crate) fn to_df_value_expr_simple(
172172
VE::VectorDistance { .. } => lit(0.0f32),
173173
VE::VectorSimilarity { .. } => lit(1.0f32),
174174
VE::Parameter(_) => lit(0),
175+
VE::VectorLiteral(_) => lit(0.0f32),
175176
}
176177
}

rust/lance-graph/tests/test_vector_search.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,36 @@ async fn test_vector_distance_l2() -> Result<()> {
154154
Ok(())
155155
}
156156

157+
#[tokio::test]
158+
async fn test_vector_distance_l2_with_literal() -> Result<()> {
159+
let (config, datasets) = create_person_graph_with_embeddings();
160+
161+
// Same test as above but using vector literal instead of cross product
162+
// Find people similar to [1, 0, 0] (Alice's embedding)
163+
let query = CypherQuery::new(
164+
"MATCH (p:Person) \
165+
WHERE vector_distance(p.embedding, [1.0, 0.0, 0.0], l2) < 0.2 \
166+
RETURN p.name ORDER BY p.name",
167+
)?
168+
.with_config(config);
169+
170+
let result = query
171+
.execute(datasets, Some(ExecutionStrategy::DataFusion))
172+
.await?;
173+
174+
// Should return Alice (exact match) and Bob (very similar)
175+
assert_eq!(result.num_rows(), 2);
176+
let names = result
177+
.column(0)
178+
.as_any()
179+
.downcast_ref::<StringArray>()
180+
.unwrap();
181+
assert_eq!(names.value(0), "Alice");
182+
assert_eq!(names.value(1), "Bob");
183+
184+
Ok(())
185+
}
186+
157187
#[tokio::test]
158188
async fn test_vector_distance_cosine() -> Result<()> {
159189
let (config, datasets) = create_person_graph_with_embeddings();
@@ -261,6 +291,37 @@ async fn test_vector_distance_order_by() -> Result<()> {
261291
Ok(())
262292
}
263293

294+
#[tokio::test]
295+
async fn test_vector_distance_order_by_with_literal() -> Result<()> {
296+
let (config, datasets) = create_person_graph_with_embeddings();
297+
298+
// Same as above but using vector literal - order by distance to [1,0,0] (Alice's vector)
299+
let query = CypherQuery::new(
300+
"MATCH (p:Person) \
301+
RETURN p.name \
302+
ORDER BY vector_distance(p.embedding, [1.0, 0.0, 0.0], cosine) ASC \
303+
LIMIT 3",
304+
)?
305+
.with_config(config);
306+
307+
let result = query
308+
.execute(datasets, Some(ExecutionStrategy::DataFusion))
309+
.await?;
310+
311+
// Should return Alice (closest), Bob (second), Eve (third)
312+
assert_eq!(result.num_rows(), 3);
313+
let names = result
314+
.column(0)
315+
.as_any()
316+
.downcast_ref::<StringArray>()
317+
.unwrap();
318+
assert_eq!(names.value(0), "Alice");
319+
assert_eq!(names.value(1), "Bob");
320+
assert_eq!(names.value(2), "Eve");
321+
322+
Ok(())
323+
}
324+
264325
#[tokio::test]
265326
async fn test_vector_similarity_order_by() -> Result<()> {
266327
let (config, datasets) = create_person_graph_with_embeddings();
@@ -350,6 +411,36 @@ async fn test_hybrid_query_property_and_vector() -> Result<()> {
350411
Ok(())
351412
}
352413

414+
#[tokio::test]
415+
async fn test_hybrid_query_with_vector_literal() -> Result<()> {
416+
let (config, datasets) = create_person_graph_with_embeddings();
417+
418+
// Combine property filter with vector literal search
419+
let query = CypherQuery::new(
420+
"MATCH (p:Person) \
421+
WHERE p.age > 25 \
422+
AND vector_distance(p.embedding, [1.0, 0.0, 0.0], l2) < 0.3 \
423+
RETURN p.name ORDER BY p.name",
424+
)?
425+
.with_config(config);
426+
427+
let result = query
428+
.execute(datasets, Some(ExecutionStrategy::DataFusion))
429+
.await?;
430+
431+
// Should return only Alice (age 30 > 25, close to [1,0,0])
432+
// Bob is age 25, not > 25
433+
assert_eq!(result.num_rows(), 1);
434+
let names = result
435+
.column(0)
436+
.as_any()
437+
.downcast_ref::<StringArray>()
438+
.unwrap();
439+
assert_eq!(names.value(0), "Alice");
440+
441+
Ok(())
442+
}
443+
353444
#[tokio::test]
354445
async fn test_vector_distance_dot_product() -> Result<()> {
355446
let (config, datasets) = create_person_graph_with_embeddings();
@@ -621,3 +712,40 @@ async fn test_vector_similarity_self_comparison() -> Result<()> {
621712

622713
Ok(())
623714
}
715+
716+
#[tokio::test]
717+
async fn test_vector_literal_in_return_clause() -> Result<()> {
718+
let (config, datasets) = create_person_graph_with_embeddings();
719+
720+
// Use vector literal in RETURN to compute distances
721+
let query = CypherQuery::new(
722+
"MATCH (p:Person) \
723+
RETURN p.name, vector_distance(p.embedding, [0.5, 0.5, 0.0], l2) AS dist \
724+
ORDER BY dist ASC \
725+
LIMIT 1",
726+
)?
727+
.with_config(config);
728+
729+
let result = query
730+
.execute(datasets, Some(ExecutionStrategy::DataFusion))
731+
.await?;
732+
733+
// Should return Eve (closest to [0.5, 0.5, 0.0])
734+
assert_eq!(result.num_rows(), 1);
735+
let names = result
736+
.column(0)
737+
.as_any()
738+
.downcast_ref::<StringArray>()
739+
.unwrap();
740+
assert_eq!(names.value(0), "Eve");
741+
742+
// Distance should be 0 (exact match)
743+
let distances = result
744+
.column(1)
745+
.as_any()
746+
.downcast_ref::<Float32Array>()
747+
.unwrap();
748+
assert!(distances.value(0) < 0.001);
749+
750+
Ok(())
751+
}

0 commit comments

Comments
 (0)