Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion helix-db/src/grammar.pest
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ rerank_mmr = { "RerankMMR" ~ "(" ~ "lambda" ~ ":" ~ evaluates_to_number ~ ("," ~
// ---------------------------------------------------------------------
// Vector steps
// ---------------------------------------------------------------------
search_vector = { "SearchV" ~ "<" ~ identifier_upper ~ ">" ~ "(" ~ vector_data ~ "," ~ (integer | identifier) ~ ")" }// ~ ("::" ~ pre_filter)? }
search_vector = { "SearchV" ~ "<" ~ identifier_upper ~ ">" ~ "(" ~ vector_data ~ "," ~ (integer | identifier) ~ ")" ~ ("::" ~ pre_filter)? }
bm25_search = { "SearchBM25" ~ "<" ~ identifier_upper ~ ">" ~ "(" ~ (string_literal | identifier) ~ "," ~ (integer | identifier) ~ ")" }
pre_filter = { "PREFILTER" ~ "(" ~ (evaluates_to_bool | anonymous_traversal) ~ ")" }
BatchAddV = { "BatchAddV" ~ "<" ~ identifier_upper ~ ">" ~ "(" ~ identifier ~ ")" }
Expand Down
9 changes: 9 additions & 0 deletions helix-db/src/helixc/analyzer/methods/graph_step_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,15 @@ pub(crate) fn apply_graph_step<'a>(
{
generate_error!(ctx, original_query, sv.loc.clone(), E103, ty.as_str());
}
if sv.pre_filter.is_some() {
generate_error!(
ctx,
original_query,
sv.loc.clone(),
E601,
"PREFILTER is only supported on root SearchV calls, not graph-step SearchV"
);
}
let vec = match &sv.data {
Some(VectorData::Vector(v)) => {
VecData::Standard(GeneratedValue::Literal(GenRef::Ref(format!(
Expand Down
190 changes: 133 additions & 57 deletions helix-db/src/helixc/analyzer/methods/infer_expr_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::{
statements::Statement as GeneratedStatement,
traversal_steps::{
ShouldCollect, Step as GeneratedStep, Traversal as GeneratedTraversal,
TraversalType, Where, WhereRef,
TraversalType,
},
utils::{GenRef, GeneratedValue, Separator, VecData},
},
Expand All @@ -37,6 +37,90 @@ use crate::{
use paste::paste;
use std::collections::HashMap;

/// Returns `true` only when `traversal` has the single shape accepted inside a
/// `SearchV` PREFILTER: it reads from `DEFAULT_VAR_NAME`, has an anonymous source
/// step, and consists of exactly a property fetch followed by a boolean operation.
fn is_supported_search_vector_prefilter_traversal(traversal: &GeneratedTraversal) -> bool {
    // The traversal has to start from the default variable, either iterated or single.
    let starts_from_default_var = match &traversal.traversal_type {
        TraversalType::FromIter(origin) | TraversalType::FromSingle(origin) => matches!(
            origin,
            GenRef::Std(name) | GenRef::Literal(name) if name == DEFAULT_VAR_NAME
        ),
        _ => false,
    };

    // Prefilter closures currently only support direct property predicates on the vector,
    // so any traversal whose source step is not anonymous is rejected as well.
    if !starts_from_default_var
        || !matches!(traversal.source_step.inner(), SourceStep::Anonymous)
    {
        return false;
    }

    // Exactly two steps are allowed: the property fetch and the boolean operation applied to it.
    match traversal.steps.as_slice() {
        [fetch, predicate] => {
            matches!(fetch.inner(), GeneratedStep::PropertyFetch(_))
                && matches!(predicate.inner(), GeneratedStep::BoolOp(_))
        }
        _ => false,
    }
}

/// Recursively checks that a boolean expression is built solely from supported
/// prefilter traversals.
///
/// `Not`, `And` and `Or` are accepted when every operand is accepted (an `And`/`Or`
/// with no operands passes vacuously); `Exists` and `Empty` are always rejected.
fn is_supported_search_vector_prefilter_expr(expr: &BoExp) -> bool {
    match expr {
        // These forms are never valid inside a prefilter.
        BoExp::Exists(_) | BoExp::Empty => false,
        // Leaf: defer to the traversal shape check.
        BoExp::Expr(leaf) => is_supported_search_vector_prefilter_traversal(leaf),
        BoExp::Not(negated) => is_supported_search_vector_prefilter_expr(negated),
        BoExp::And(operands) | BoExp::Or(operands) => {
            operands
                .iter()
                .all(is_supported_search_vector_prefilter_expr)
        }
    }
}

/// Analyzes the expression passed to a `SearchV` PREFILTER and turns it into the
/// boolean expression list used by the generated search step.
///
/// The expression is run through `infer_expr_type` with `Type::Vector(vector_type)`
/// as the type argument. A generated traversal is wrapped in `BoExp::Expr`; a
/// generated boolean expression is used as-is.
///
/// Returns `Some` holding a one-element vector when the resulting expression passes
/// `is_supported_search_vector_prefilter_expr`. Returns `None` when:
/// - inference produced no statement (no diagnostic is added here),
/// - inference produced some other kind of statement (reports `E306`), or
/// - the expression is not a supported prefilter predicate (reports `E601`).
pub(crate) fn build_search_vector_pre_filter<'a>(
    ctx: &mut Ctx<'a>,
    pre_filter_expr: &'a Expression,
    scope: &mut HashMap<&'a str, VariableInfo>,
    original_query: &'a Query,
    vector_type: Option<String>,
    gen_query: &mut GeneratedQuery,
) -> Option<Vec<BoExp>> {
    let vector_ty = Some(Type::Vector(vector_type));
    let (_, generated) = infer_expr_type(
        ctx,
        pre_filter_expr,
        scope,
        original_query,
        vector_ty,
        gen_query,
    );

    // `?` covers the case where nothing was generated: give up without a new diagnostic.
    let candidate = match generated? {
        GeneratedStatement::BoExp(expr) => expr,
        GeneratedStatement::Traversal(tr) => BoExp::Expr(tr),
        _ => {
            generate_error!(
                ctx,
                original_query,
                pre_filter_expr.loc.clone(),
                E306,
                "PREFILTER"
            );
            return None;
        }
    };

    if is_supported_search_vector_prefilter_expr(&candidate) {
        return Some(vec![candidate]);
    }

    // Well-formed boolean expression, but not one the prefilter closure can express.
    generate_error!(
        ctx,
        original_query,
        pre_filter_expr.loc.clone(),
        E601,
        "PREFILTER only supports simple vector property predicates like _::{field}::EQ(value)"
    );
    None
}

/// Infer the end type of an expression and returns the statement to generate from the expression
///
/// This function is used to infer the end type of an expression and returns the statement to generate from the expression
Expand Down Expand Up @@ -1278,62 +1362,16 @@ pub(crate) fn infer_expr_type<'a>(
}
};

let pre_filter: Option<Vec<BoExp>> = match &sv.pre_filter {
Some(expr) => {
let (_, stmt) = infer_expr_type(
ctx,
expr,
scope,
original_query,
Some(Type::Vector(sv.vector_type.clone())),
gen_query,
);
// Where/boolean ops don't change the element type,
// so `cur_ty` stays the same.
if stmt.is_none() {
return (Type::Vector(sv.vector_type.clone()), None);
}
let stmt = stmt.unwrap();
let mut gen_traversal = GeneratedTraversal {
traversal_type: TraversalType::FromIter(GenRef::Std("v".to_string())),
steps: vec![],
should_collect: ShouldCollect::ToVec,
source_step: Separator::Empty(SourceStep::Anonymous),
..Default::default()
};
match stmt {
GeneratedStatement::Traversal(tr) => {
gen_traversal
.steps
.push(Separator::Period(GeneratedStep::Where(Where::Ref(
WhereRef {
expr: BoExp::Expr(tr),
},
))));
}
GeneratedStatement::BoExp(expr) => {
gen_traversal
.steps
.push(Separator::Period(GeneratedStep::Where(match expr {
BoExp::Exists(mut traversal) => {
traversal.should_collect = ShouldCollect::No;
Where::Ref(WhereRef {
expr: BoExp::Exists(traversal),
})
}
_ => Where::Ref(WhereRef { expr }),
})));
}
// Pre-filter should produce Traversal or BoExp
_ => {
// Fall through - pre-filter will be None
return (Type::Vector(sv.vector_type.clone()), None);
}
}
Some(vec![BoExp::Expr(gen_traversal)])
}
None => None,
};
let pre_filter = sv.pre_filter.as_ref().and_then(|expr| {
build_search_vector_pre_filter(
ctx,
expr,
scope,
original_query,
sv.vector_type.clone(),
gen_query,
)
});

// Search returns nodes that contain the vectors
(
Expand Down Expand Up @@ -2116,4 +2154,42 @@ mod tests {
let (diagnostics, _) = result.unwrap();
assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E660));
}

// A root `SearchV` with a PREFILTER of the form `_::{category}::EQ("tech")` must be
// analyzed without producing an E601 diagnostic.
#[test]
fn test_search_vector_prefilter_simple_property_predicate_valid() {
let source = r#"
V::Document { content: String, category: String, embedding: [F32] }

QUERY test(query_vec: [F64]) =>
docs <- SearchV<Document>(query_vec, 10)::PREFILTER(_::{category}::EQ("tech"))
RETURN docs
"#;

// Write the schema and query to a temp file, parse it, then run the analyzer.
let content = write_to_temp_file(vec![source]);
let parsed = HelixParser::parse_source(&content).unwrap();
let result = crate::helixc::analyzer::analyze(&parsed);

assert!(result.is_ok());
let (diagnostics, _) = result.unwrap();
// Only E601 is checked; other diagnostics are not asserted on by this test.
assert!(!diagnostics.iter().any(|d| d.error_code == ErrorCode::E601));
}

// A PREFILTER that is only a property fetch (`_::{category}`) with no boolean
// operation must be reported with an E601 diagnostic.
#[test]
fn test_search_vector_prefilter_non_boolean_traversal_emits_e601() {
let source = r#"
V::Document { content: String, category: String, embedding: [F32] }

QUERY test(query_vec: [F64]) =>
docs <- SearchV<Document>(query_vec, 10)::PREFILTER(_::{category})
RETURN docs
"#;

// Write the schema and query to a temp file, parse it, then run the analyzer.
let content = write_to_temp_file(vec![source]);
let parsed = HelixParser::parse_source(&content).unwrap();
let result = crate::helixc::analyzer::analyze(&parsed);

// Analysis itself is expected to succeed; the problem surfaces as a diagnostic.
assert!(result.is_ok());
let (diagnostics, _) = result.unwrap();
assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E601));
}
}
63 changes: 12 additions & 51 deletions helix-db/src/helixc/analyzer/methods/traversal_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use crate::{
errors::push_query_err,
methods::{
exclude_validation::validate_exclude, graph_step_validation::apply_graph_step,
infer_expr_type::infer_expr_type, object_validation::validate_object,
infer_expr_type::{build_search_vector_pre_filter, infer_expr_type},
object_validation::validate_object,
},
types::{AggregateInfo, Type},
utils::{
Expand Down Expand Up @@ -639,56 +640,16 @@ pub(crate) fn validate_traversal<'a>(
}
};

// let pre_filter: Option<Vec<BoExp>> = match &sv.pre_filter {
// Some(expr) => {
// let (_, stmt) = infer_expr_type(
// ctx,
// expr,
// scope,
// original_query,
// Some(Type::Vector(sv.vector_type.clone())),
// gen_query,
// );
// // Where/boolean ops don't change the element type,
// // so `cur_ty` stays the same.
// assert!(stmt.is_some());
// let stmt = stmt.unwrap();
// let mut gen_traversal = GeneratedTraversal {
// traversal_type: TraversalType::NestedFrom(GenRef::Std("v".to_string())),
// steps: vec![],
// should_collect: ShouldCollect::ToVec,
// source_step: Separator::Empty(SourceStep::Anonymous),
// };
// match stmt {
// GeneratedStatement::Traversal(tr) => {
// gen_traversal
// .steps
// .push(Separator::Period(GeneratedStep::Where(Where::Ref(
// WhereRef {
// expr: BoExp::Expr(tr),
// },
// ))));
// }
// GeneratedStatement::BoExp(expr) => {
// gen_traversal
// .steps
// .push(Separator::Period(GeneratedStep::Where(match expr {
// BoExp::Exists(mut traversal) => {
// traversal.should_collect = ShouldCollect::No;
// Where::Ref(WhereRef {
// expr: BoExp::Exists(traversal),
// })
// }
// _ => Where::Ref(WhereRef { expr }),
// })));
// }
// _ => unreachable!(),
// }
// Some(vec![BoExp::Expr(gen_traversal)])
// }
// None => None,
// };
let pre_filter = None;
let pre_filter = sv.pre_filter.as_ref().and_then(|expr| {
build_search_vector_pre_filter(
ctx,
expr,
scope,
original_query,
sv.vector_type.clone(),
gen_query,
)
});

gen_traversal.traversal_type = TraversalType::Ref;
gen_traversal.should_collect = ShouldCollect::ToVec;
Expand Down
2 changes: 1 addition & 1 deletion helix-db/src/helixc/generator/source_steps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ impl Display for SearchVector {
self.label,
pre_filter
.iter()
.map(|f| format!("|v: &HVector, txn: &RoTxn| {f}"))
.map(|f| format!("|val: &HVector, _txn: &RoTxn| {f}"))
.collect::<Vec<_>>()
.join(", ")
),
Expand Down
15 changes: 15 additions & 0 deletions helix-db/src/helixc/parser/expression_parse_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,21 @@ mod tests {
assert!(result.is_ok());
}

// Parser-level check: a `SearchV<...>(...)` call followed by `::PREFILTER(...)`
// must parse successfully. No analysis is run in this test.
#[test]
fn test_parse_vector_search_with_prefilter() {
let source = r#"
V::Document { content: String, category: String, embedding: [F32] }

QUERY searchSimilar(queryVec: [F32]) =>
docs <- SearchV<Document>(queryVec, 10)::PREFILTER(_::{category}::EQ("tech"))
RETURN docs
"#;

let content = write_to_temp_file(vec![source]);
let result = HelixParser::parse_source(&content);
assert!(result.is_ok());
}

// ============================================================================
// Assignment Tests
// ============================================================================
Expand Down