diff --git a/rust/lance-graph/src/datafusion_planner.rs b/rust/lance-graph/src/datafusion_planner.rs index 15a54362..56c038d7 100644 --- a/rust/lance-graph/src/datafusion_planner.rs +++ b/rust/lance-graph/src/datafusion_planner.rs @@ -3,19 +3,34 @@ //! DataFusion-based physical planner for graph queries //! -//! This module implements the proper graph-to-relational mapping: -//! - Nodes as Tables: Each node label becomes a table +//! This module translates graph logical plans into DataFusion logical plans for execution. +//! It implements a two-phase planning approach: +//! +//! ## Phase 1: Analysis +//! - Extracts metadata from the graph logical plan (from `logical_plan.rs`) +//! - Assigns unique IDs to relationship instances to avoid column name conflicts +//! - Collects variable-to-label mappings and required datasets +//! +//! ## Phase 2: Plan Building +//! - Converts graph operations to relational operations +//! - Nodes as Tables: Each node label becomes a table scan //! - Relationships as Tables: Each relationship type becomes a linking table -//! - Cypher traversal becomes SQL joins +//! - Graph traversals become SQL joins with qualified column names //! -//! Uses DataFusion's LogicalPlan and optimizer for world-class query optimization. +//! ## Key Design Decisions +//! - **Unique relationship aliases**: Each relationship expansion gets a unique alias +//! (e.g., `knows_1`, `knows_2`) to support multi-hop queries without column conflicts +//! - **Relationship variables**: User-specified variables (e.g., `[r:KNOWS]`) take precedence +//! - **Column qualification**: All columns are qualified as `{variable}__{column}` to avoid ambiguity +use crate::ast::RelationshipDirection; use crate::error::Result; use crate::logical_plan::*; use crate::source_catalog::GraphSourceCatalog; use datafusion::logical_expr::{ col, lit, BinaryExpr, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, }; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; /// Planner abstraction for graph-to-physical planning @@ -25,7 +40,6 @@ pub trait GraphPhysicalPlanner { /// DataFusion-based physical planner pub struct DataFusionPlanner { - #[allow(dead_code)] config: crate::config::GraphConfig, catalog: Option>, } @@ -49,11 +63,96 @@ impl DataFusionPlanner { } } +// ============================================================================ +// Query Analysis Phase +// ============================================================================ + +/// Analysis result containing all metadata needed for planning +#[derive(Debug, Clone, Default)] +pub struct QueryAnalysis { + /// Variable → Label mappings (e.g., "n" → "Person") + pub var_to_label: HashMap, + + /// Relationship instances with unique IDs to avoid column conflicts + pub relationship_instances: Vec, + + /// All datasets required for this query + pub required_datasets: HashSet, +} + +/// Represents a single relationship expansion with a unique instance ID +#[derive(Debug, Clone)] +pub struct RelationshipInstance { + pub id: usize, // Unique instance number + pub rel_type: String, + pub source_var: String, + pub target_var: String, + pub direction: RelationshipDirection, + pub alias: String, // e.g., "friend_of_1", "friend_of_2" +} + +/// Parameters for joining source node to relationship +struct SourceJoinParams<'a> { + source_variable: &'a str, + rel_qualifier: &'a str, + node_id_field: &'a str, + rel_map: &'a crate::config::RelationshipMapping, + direction: &'a RelationshipDirection, +} + +/// Parameters for joining relationship to target node +struct TargetJoinParams<'a> { + source_variable: &'a str, + target_variable: &'a str, + rel_qualifier: &'a str, + node_map: &'a crate::config::NodeMapping, + rel_map: &'a crate::config::RelationshipMapping, + direction: &'a RelationshipDirection, +} + +/// Planning context that tracks state during plan building +pub struct PlanningContext<'a> { + pub analysis: &'a QueryAnalysis, + relationship_instance_idx: HashMap, +} + +impl<'a> PlanningContext<'a> { + pub fn new(analysis: &'a QueryAnalysis) -> Self { + Self { + analysis, + relationship_instance_idx: HashMap::new(), + } + } + + /// Get the next relationship instance for a given type (returns a clone) + pub fn next_relationship_instance(&mut self, rel_type: &str) -> Result { + let idx = self + .relationship_instance_idx + .entry(rel_type.to_string()) + .and_modify(|i| *i += 1) + .or_insert(0); + + self.analysis + .relationship_instances + .iter() + .filter(|r| r.rel_type == rel_type) + .nth(*idx) + .cloned() + .ok_or_else(|| crate::error::GraphError::PlanError { + message: format!("No relationship instance found for: {}", rel_type), + location: snafu::Location::new(file!(), line!(), column!()), + }) + } +} + impl GraphPhysicalPlanner for DataFusionPlanner { fn plan(&self, logical_plan: &LogicalOperator) -> Result { - use std::collections::HashMap; - let mut var_labels: HashMap = HashMap::new(); - self.plan_operator_with_ctx(logical_plan, &mut var_labels) + // Phase 1: Analyze query structure + let analysis = self.analyze(logical_plan)?; + + // Phase 2: Build execution plan with context + let mut ctx = PlanningContext::new(&analysis); + self.build_operator(&mut ctx, logical_plan) } } @@ -65,18 +164,17 @@ impl DataFusionPlanner { datasets: &std::collections::HashMap, ) -> Result { use crate::source_catalog::{InMemoryCatalog, SimpleTableSource}; - use std::collections::HashMap as StdHashMap; use std::sync::Arc; - // Collect variable -> label mappings from the logical plan - let mut variable_mappings: StdHashMap = StdHashMap::new(); - Self::collect_variable_mappings(logical_plan, &mut variable_mappings)?; + // Use the new analyze() method to extract metadata + let analysis = self.analyze(logical_plan)?; // Build an in-memory catalog from provided datasets (nodes and relationships) let mut catalog = InMemoryCatalog::new(); - let mut added_labels = std::collections::HashSet::new(); - for label in variable_mappings.values() { - if added_labels.insert(label.clone()) { + + // Register node sources from required datasets + for label in &analysis.required_datasets { + if self.config.node_mappings.contains_key(label) { if let Some(batch) = datasets.get(label) { let src = Arc::new(SimpleTableSource::new(batch.schema())); catalog = catalog.with_node_source(label, src); @@ -84,11 +182,13 @@ impl DataFusionPlanner { } } - // Register relationship sources if datasets include them - for rel_type in self.config.relationship_mappings.keys() { - if let Some(batch) = datasets.get(rel_type) { - let src = Arc::new(SimpleTableSource::new(batch.schema())); - catalog = catalog.with_relationship_source(rel_type, src); + // Register relationship sources from required datasets + for rel_type in &analysis.required_datasets { + if self.config.relationship_mappings.contains_key(rel_type) { + if let Some(batch) = datasets.get(rel_type) { + let src = Arc::new(SimpleTableSource::new(batch.schema())); + catalog = catalog.with_relationship_source(rel_type, src); + } } } @@ -98,50 +198,108 @@ impl DataFusionPlanner { planner_with_cat.plan(logical_plan) } - /// Collect variable to label mappings from logical plan - fn collect_variable_mappings( - op: &LogicalOperator, - mappings: &mut std::collections::HashMap, - ) -> Result<()> { - match op { - LogicalOperator::ScanByLabel { - variable, label, .. - } => { - mappings.insert(variable.clone(), label.clone()); - } - LogicalOperator::Expand { - input, - source_variable, - target_variable, - .. - } => { - Self::collect_variable_mappings(input, mappings)?; - // For expand, we need to infer the target variable's label - // For now, assume target has same label as source (simplified) - if let Some(source_label) = mappings.get(source_variable).cloned() { - mappings.insert(target_variable.clone(), source_label); - } - } - LogicalOperator::Filter { input, .. } - | LogicalOperator::Project { input, .. } - | LogicalOperator::Distinct { input } - | LogicalOperator::Limit { input, .. } - | LogicalOperator::Offset { input, .. } - | LogicalOperator::Sort { input, .. } => { - Self::collect_variable_mappings(input, mappings)?; + /// Phase 1: Analyze the logical plan to extract metadata + fn analyze(&self, logical_plan: &LogicalOperator) -> Result { + let mut analysis = QueryAnalysis::default(); + let mut rel_counter: HashMap = HashMap::new(); + + analyze_operator(logical_plan, &mut analysis, &mut rel_counter)?; + Ok(analysis) + } +} + +/// Recursively analyze operators to build QueryAnalysis +fn analyze_operator( + op: &LogicalOperator, + analysis: &mut QueryAnalysis, + rel_counter: &mut HashMap, +) -> Result<()> { + match op { + LogicalOperator::ScanByLabel { + variable, label, .. + } => { + analysis + .var_to_label + .insert(variable.clone(), label.clone()); + analysis.required_datasets.insert(label.clone()); + } + LogicalOperator::Expand { + input, + source_variable, + target_variable, + relationship_types, + direction, + relationship_variable, + .. + } + | LogicalOperator::VariableLengthExpand { + input, + source_variable, + target_variable, + relationship_types, + direction, + relationship_variable, + .. + } => { + // Recursively analyze input first + analyze_operator(input, analysis, rel_counter)?; + + // Infer target variable's label from source variable + // For (a:Person)-[:KNOWS]->(b), b also gets label Person + if let Some(source_label) = analysis.var_to_label.get(source_variable).cloned() { + analysis + .var_to_label + .insert(target_variable.clone(), source_label); } - LogicalOperator::VariableLengthExpand { input, .. } - | LogicalOperator::Join { left: input, .. } => { - Self::collect_variable_mappings(input, mappings)?; + + // Assign unique instance ID for this relationship + if let Some(rel_type) = relationship_types.first() { + let instance_id = rel_counter + .entry(rel_type.clone()) + .and_modify(|c| *c += 1) + .or_insert(1); + + // Use relationship variable if provided, otherwise use type_instanceId + let alias = if let Some(rel_var) = relationship_variable { + rel_var.clone() + } else { + format!("{}_{}", rel_type.to_lowercase(), instance_id) + }; + + analysis.relationship_instances.push(RelationshipInstance { + id: *instance_id, + rel_type: rel_type.clone(), + source_var: source_variable.clone(), + target_var: target_variable.clone(), + direction: direction.clone(), + alias, + }); + + analysis.required_datasets.insert(rel_type.clone()); } } - Ok(()) + LogicalOperator::Filter { input, .. } + | LogicalOperator::Project { input, .. } + | LogicalOperator::Sort { input, .. } + | LogicalOperator::Limit { input, .. } + | LogicalOperator::Offset { input, .. } + | LogicalOperator::Distinct { input } => { + analyze_operator(input, analysis, rel_counter)?; + } + LogicalOperator::Join { left, right, .. } => { + analyze_operator(left, analysis, rel_counter)?; + analyze_operator(right, analysis, rel_counter)?; + } } + Ok(()) +} - fn plan_operator_with_ctx( +impl DataFusionPlanner { + /// Phase 2: Build DataFusion LogicalPlan from logical operator with context + fn build_operator( &self, + ctx: &mut PlanningContext, op: &LogicalOperator, - var_labels: &mut std::collections::HashMap, ) -> Result { match op { LogicalOperator::ScanByLabel { @@ -149,65 +307,9 @@ impl DataFusionPlanner { label, properties, .. - } => { - // Track variable -> label mapping - var_labels.insert(variable.clone(), label.clone()); - - // Try to use catalog if available - if let Some(cat) = &self.catalog { - if let Some(source) = cat.node_source(label) { - // Get schema before moving source - let schema = source.schema(); - let mut builder = LogicalPlanBuilder::scan(label, source, None).unwrap(); - - // Apply property filters using unqualified names (before aliasing) - for (k, v) in properties.iter() { - let lit_expr = self - .to_df_value_expr(&crate::ast::ValueExpression::Literal(v.clone())); - let filter_expr = Expr::BinaryExpr(BinaryExpr { - left: Box::new(col(k)), - op: Operator::Eq, - right: Box::new(lit_expr), - }); - builder = builder.filter(filter_expr).unwrap(); - } - - // Create qualified column aliases: variable__property - let qualified_exprs: Vec = schema - .fields() - .iter() - .map(|field| { - let qualified_name = format!("{}__{}", variable, field.name()); - col(field.name()).alias(&qualified_name) - }) - .collect(); - - // Add projection with qualified aliases - builder = builder.project(qualified_exprs).unwrap(); - - return Ok(builder.build().unwrap()); - } - } - - // Fallback: create a simple table reference that DataFusion can resolve at execution time - // Use LogicalPlanBuilder to create a proper scan - let empty_source = Arc::new(crate::source_catalog::SimpleTableSource::empty()); - let builder = LogicalPlanBuilder::scan(label, empty_source, None).map_err(|e| { - crate::error::GraphError::PlanError { - message: format!("Failed to create table scan for {}: {}", label, e), - location: snafu::Location::new(file!(), line!(), column!()), - } - })?; - - Ok(builder - .build() - .map_err(|e| crate::error::GraphError::PlanError { - message: format!("Failed to build table scan for {}: {}", label, e), - location: snafu::Location::new(file!(), line!(), column!()), - })?) - } + } => self.build_scan(ctx, variable, label, properties), LogicalOperator::Filter { input, predicate } => { - let input_plan = self.plan_operator_with_ctx(input, var_labels)?; + let input_plan = self.build_operator(ctx, input)?; let expr = self.to_df_boolean_expr(predicate); Ok(LogicalPlanBuilder::from(input_plan) .filter(expr) @@ -216,7 +318,7 @@ impl DataFusionPlanner { .unwrap()) } LogicalOperator::Project { input, projections } => { - let input_plan = self.plan_operator_with_ctx(input, var_labels)?; + let input_plan = self.build_operator(ctx, input)?; let exprs: Vec = projections .iter() .map(|p| self.to_df_value_expr(&p.expression)) @@ -228,7 +330,7 @@ impl DataFusionPlanner { .unwrap()) } LogicalOperator::Distinct { input } => { - let input_plan = self.plan_operator_with_ctx(input, var_labels)?; + let input_plan = self.build_operator(ctx, input)?; Ok(LogicalPlanBuilder::from(input_plan) .distinct() .unwrap() @@ -237,10 +339,10 @@ impl DataFusionPlanner { } LogicalOperator::Sort { input, .. } => { // Schema-less placeholder: skip sort for now - self.plan_operator_with_ctx(input, var_labels) + self.build_operator(ctx, input) } LogicalOperator::Limit { input, count } => { - let input_plan = self.plan_operator_with_ctx(input, var_labels)?; + let input_plan = self.build_operator(ctx, input)?; Ok(LogicalPlanBuilder::from(input_plan) .limit(0, Some((*count) as usize)) .unwrap() @@ -248,7 +350,7 @@ impl DataFusionPlanner { .unwrap()) } LogicalOperator::Offset { input, offset } => { - let input_plan = self.plan_operator_with_ctx(input, var_labels)?; + let input_plan = self.build_operator(ctx, input)?; Ok(LogicalPlanBuilder::from(input_plan) .limit((*offset) as usize, None) .unwrap() @@ -261,7 +363,6 @@ impl DataFusionPlanner { target_variable, relationship_types, direction, - relationship_variable, .. } | LogicalOperator::VariableLengthExpand { @@ -270,182 +371,273 @@ impl DataFusionPlanner { target_variable, relationship_types, direction, - relationship_variable, .. - } => { - let left_plan = self.plan_operator_with_ctx(input, var_labels)?; - // TODO(two-hop+): Support chaining multiple hops in the physical plan. - // For single hop we scan the relationship table and filter with an ON expression. - // For two-hop (e.g., a-[:R1]->m-[:R2]->b), we should: - // 1) Join a with R1 (as done here) - // 2) Join the result with R2 - // 3) Join the result with the b node scan - // Ensure we maintain/propagate variable->label mapping (var_labels) and - // project/qualify columns to avoid ambiguity across joins. - // For VariableLengthExpand with bounds, consider unrolling small fixed bounds - // (e.g., *1..2) into a UNION ALL of 1-hop and 2-hop plans. - // Attempt first hop: source node -> relationship table - if let (Some(cat), Some(rel_type)) = (&self.catalog, relationship_types.first()) { - if let Some(rel_map) = self.config.relationship_mappings.get(rel_type) { - if let Some(src_label) = var_labels.get(source_variable) { - if let Some(node_map) = self.config.node_mappings.get(src_label) { - if let Some(rel_source) = - cat.relationship_source(&rel_map.relationship_type) - { - // Create relationship scan with qualified column aliases - let rel_schema = rel_source.schema(); - let rel_builder = LogicalPlanBuilder::scan( - &rel_map.relationship_type, - rel_source, - None, - ) - .unwrap(); - - // Create qualified column aliases for relationship - // Use relationship variable if available, otherwise use relationship type (lowercase) - let rel_type_lower = rel_map.relationship_type.to_lowercase(); - let rel_qualifier = - relationship_variable.as_deref().unwrap_or(&rel_type_lower); - let rel_qualified_exprs: Vec = rel_schema - .fields() - .iter() - .map(|field| { - let qualified_name = - format!("{}__{}", rel_qualifier, field.name()); - col(field.name()).alias(&qualified_name) - }) - .collect(); - - let rel_scan = rel_builder - .project(rel_qualified_exprs) - .unwrap() - .build() - .unwrap(); - - // Determine join keys based on direction - let (left_key, right_key) = match direction { - crate::ast::RelationshipDirection::Outgoing => { - (&node_map.id_field, &rel_map.source_id_field) - } - crate::ast::RelationshipDirection::Incoming => { - (&node_map.id_field, &rel_map.target_id_field) - } - crate::ast::RelationshipDirection::Undirected => { - (&node_map.id_field, &rel_map.source_id_field) - } - }; - // Use qualified column names for both sides of the join - let qualified_left_key = - format!("{}__{}", source_variable, left_key); - let qualified_right_key = - format!("{}__{}", rel_qualifier, right_key); - - // Use proper inner join instead of CrossJoin + Filter - let mut builder = LogicalPlanBuilder::from(left_plan) - .join( - rel_scan, - JoinType::Inner, - ( - vec![qualified_left_key.clone()], - vec![qualified_right_key.clone()], - ), - None, - ) - .unwrap(); - - // Add target node scan and join - // For now, assume target has same label as source (simplified) - if let Some(target_label) = - var_labels.get(source_variable).cloned() - { - if let Some(target_source) = cat.node_source(&target_label) - { - // Create target node scan with qualified column aliases - let target_schema = target_source.schema(); - let target_builder = LogicalPlanBuilder::scan( - &target_label, - target_source, - None, - ) - .unwrap(); - - // Create qualified column aliases for target: target_variable__property - let target_qualified_exprs: Vec = target_schema - .fields() - .iter() - .map(|field| { - let qualified_name = format!( - "{}__{}", - target_variable, - field.name() - ); - col(field.name()).alias(&qualified_name) - }) - .collect(); - - let target_scan = target_builder - .project(target_qualified_exprs) - .unwrap() - .build() - .unwrap(); - - // Determine target join keys - let target_key = match direction { - crate::ast::RelationshipDirection::Outgoing => { - &rel_map.target_id_field - } - crate::ast::RelationshipDirection::Incoming => { - &rel_map.source_id_field - } - crate::ast::RelationshipDirection::Undirected => { - &rel_map.target_id_field - } - }; - let qualified_rel_target_key = - format!("{}__{}", rel_qualifier, target_key); - let qualified_target_key = format!( - "{}__{}", - target_variable, &node_map.id_field - ); - - // Use proper inner join for relationship->target - builder = builder - .join( - target_scan, - JoinType::Inner, - ( - vec![qualified_rel_target_key], - vec![qualified_target_key], - ), - None, - ) - .unwrap(); - - // Track target variable label - var_labels - .insert(target_variable.clone(), target_label); - } - } - - return Ok(builder.build().unwrap()); - } - } - } - } - } - // Fallback: pass-through - var_labels - .entry(target_variable.clone()) - .or_insert_with(|| "Node".to_string()); - Ok(self.plan_operator_with_ctx(input, var_labels)?) - } + } => self.build_expand( + ctx, + input, + source_variable, + target_variable, + relationship_types, + direction, + ), LogicalOperator::Join { left, .. } => { // Not yet implemented: explicit join. For now, use left branch - self.plan_operator_with_ctx(left, var_labels) + self.build_operator(ctx, left) } } } + // ============================================================================ + // Component Builders + // ============================================================================ + + /// Build a qualified node scan with property filters and column aliasing + fn build_scan( + &self, + _ctx: &PlanningContext, + variable: &str, + label: &str, + properties: &HashMap, + ) -> Result { + // Try to use catalog if available + if let Some(cat) = &self.catalog { + if let Some(source) = cat.node_source(label) { + // Get schema before moving source + let schema = source.schema(); + let mut builder = LogicalPlanBuilder::scan(label, source, None).unwrap(); + + // Apply property filters using unqualified names (before aliasing) + for (k, v) in properties.iter() { + let lit_expr = + self.to_df_value_expr(&crate::ast::ValueExpression::Literal(v.clone())); + let filter_expr = Expr::BinaryExpr(BinaryExpr { + left: Box::new(col(k)), + op: Operator::Eq, + right: Box::new(lit_expr), + }); + builder = builder.filter(filter_expr).unwrap(); + } + + // Create qualified column aliases: variable__property + let qualified_exprs: Vec = schema + .fields() + .iter() + .map(|field| { + let qualified_name = format!("{}__{}", variable, field.name()); + col(field.name()).alias(&qualified_name) + }) + .collect(); + + // Add projection with qualified aliases + builder = builder.project(qualified_exprs).unwrap(); + + return Ok(builder.build().unwrap()); + } + } + + // Fallback: create a simple table reference that DataFusion can resolve at execution time + let empty_source = Arc::new(crate::source_catalog::SimpleTableSource::empty()); + let builder = LogicalPlanBuilder::scan(label, empty_source, None).map_err(|e| { + crate::error::GraphError::PlanError { + message: format!("Failed to create table scan for {}: {}", label, e), + location: snafu::Location::new(file!(), line!(), column!()), + } + })?; + + builder + .build() + .map_err(|e| crate::error::GraphError::PlanError { + message: format!("Failed to build table scan for {}: {}", label, e), + location: snafu::Location::new(file!(), line!(), column!()), + }) + } + + /// Build a relationship expansion (graph traversal) as a series of joins + fn build_expand( + &self, + ctx: &mut PlanningContext, + input: &LogicalOperator, + source_variable: &str, + target_variable: &str, + relationship_types: &[String], + direction: &RelationshipDirection, + ) -> Result { + let left_plan = self.build_operator(ctx, input)?; + + // Get the unique relationship instance for this expand operation + let Some(cat) = &self.catalog else { + // Fallback: pass-through if catalog not available + return Ok(left_plan); + }; + + let Some(rel_type) = relationship_types.first() else { + return Ok(left_plan); + }; + + let rel_instance = ctx.next_relationship_instance(rel_type)?; + let Some(rel_map) = self.config.relationship_mappings.get(rel_type) else { + return Ok(left_plan); + }; + + let Some(src_label) = ctx.analysis.var_to_label.get(source_variable) else { + return Ok(left_plan); + }; + + let Some(node_map) = self.config.node_mappings.get(src_label) else { + return Ok(left_plan); + }; + + let Some(rel_source) = cat.relationship_source(&rel_map.relationship_type) else { + return Ok(left_plan); + }; + + // Build relationship scan with qualified columns + let rel_scan = self.build_relationship_scan(&rel_instance, rel_source)?; + + // Join source node with relationship + let source_params = SourceJoinParams { + source_variable, + rel_qualifier: &rel_instance.alias, + node_id_field: &node_map.id_field, + rel_map, + direction, + }; + let builder = self.join_source_to_relationship(left_plan, rel_scan, &source_params)?; + + // Join relationship with target node + let target_params = TargetJoinParams { + source_variable, + target_variable, + rel_qualifier: &rel_instance.alias, + node_map, + rel_map, + direction, + }; + self.join_relationship_to_target(builder, cat, ctx, &target_params) + } + + /// Build a qualified relationship scan + fn build_relationship_scan( + &self, + rel_instance: &RelationshipInstance, + rel_source: Arc, + ) -> Result { + let rel_schema = rel_source.schema(); + let rel_builder = + LogicalPlanBuilder::scan(&rel_instance.rel_type, rel_source, None).unwrap(); + + // Use unique alias from rel_instance to avoid column conflicts + let rel_qualified_exprs: Vec = rel_schema + .fields() + .iter() + .map(|field| { + let qualified_name = format!("{}__{}", rel_instance.alias, field.name()); + col(field.name()).alias(&qualified_name) + }) + .collect(); + + Ok(rel_builder + .project(rel_qualified_exprs) + .unwrap() + .build() + .unwrap()) + } + + /// Join source node plan with relationship scan + fn join_source_to_relationship( + &self, + left_plan: LogicalPlan, + rel_scan: LogicalPlan, + params: &SourceJoinParams, + ) -> Result { + // Determine join keys based on direction + let right_key = match params.direction { + RelationshipDirection::Outgoing => ¶ms.rel_map.source_id_field, + RelationshipDirection::Incoming => ¶ms.rel_map.target_id_field, + RelationshipDirection::Undirected => ¶ms.rel_map.source_id_field, + }; + + let qualified_left_key = format!("{}__{}", params.source_variable, params.node_id_field); + let qualified_right_key = format!("{}__{}", params.rel_qualifier, right_key); + + Ok(LogicalPlanBuilder::from(left_plan) + .join( + rel_scan, + JoinType::Inner, + (vec![qualified_left_key], vec![qualified_right_key]), + None, + ) + .unwrap()) + } + + /// Join relationship with target node scan + fn join_relationship_to_target( + &self, + mut builder: LogicalPlanBuilder, + cat: &Arc, + ctx: &PlanningContext, + params: &TargetJoinParams, + ) -> Result { + // For now, assume target has same label as source (simplified) + let Some(target_label) = ctx + .analysis + .var_to_label + .get(params.source_variable) + .cloned() + else { + return Ok(builder.build().unwrap()); + }; + + let Some(target_source) = cat.node_source(&target_label) else { + return Ok(builder.build().unwrap()); + }; + + // Create target node scan with qualified column aliases + let target_schema = target_source.schema(); + let target_builder = LogicalPlanBuilder::scan(&target_label, target_source, None).unwrap(); + + let target_qualified_exprs: Vec = target_schema + .fields() + .iter() + .map(|field| { + let qualified_name = format!("{}__{}", params.target_variable, field.name()); + col(field.name()).alias(&qualified_name) + }) + .collect(); + + let target_scan = target_builder + .project(target_qualified_exprs) + .unwrap() + .build() + .unwrap(); + + // Determine target join keys + let target_key = match params.direction { + RelationshipDirection::Outgoing => ¶ms.rel_map.target_id_field, + RelationshipDirection::Incoming => ¶ms.rel_map.source_id_field, + RelationshipDirection::Undirected => ¶ms.rel_map.target_id_field, + }; + + let qualified_rel_target_key = format!("{}__{}", params.rel_qualifier, target_key); + let qualified_target_key = + format!("{}__{}", params.target_variable, ¶ms.node_map.id_field); + + builder = builder + .join( + target_scan, + JoinType::Inner, + (vec![qualified_rel_target_key], vec![qualified_target_key]), + None, + ) + .unwrap(); + + Ok(builder.build().unwrap()) + } + + // ============================================================================ + // Expression Translators + // ============================================================================ + fn to_df_boolean_expr(&self, expr: &crate::ast::BooleanExpression) -> Expr { use crate::ast::{BooleanExpression as BE, ComparisonOperator as CO}; match expr { @@ -799,7 +991,7 @@ mod tests { s ); assert!( - s.contains("KNOWS__src_person_id") || s.contains("knows__src_person_id"), + s.contains("knows_1__src_person_id"), "missing qualified rel key in join: {}", s ); @@ -1026,7 +1218,7 @@ mod tests { let df_plan = planner.plan(&project).unwrap(); let s = format!("{:?}", df_plan); assert!( - s.contains("KNOWS__dst_person_id") || s.contains("knows__dst_person_id"), + s.contains("knows_1__dst_person_id"), "incoming join should use dst key: {}", s ); @@ -1068,7 +1260,7 @@ mod tests { let df_plan = planner.plan(&project).unwrap(); let s = format!("{:?}", df_plan); assert!( - s.contains("KNOWS__src_person_id") || s.contains("knows__src_person_id"), + s.contains("knows_1__src_person_id"), "undirected uses src key side for predicate: {}", s ); @@ -1290,14 +1482,133 @@ mod tests { s ); } -} -/* -The main issues to fix: -1. Column import path: Use datafusion::common::Column instead of datafusion::logical_expr::Column -2. TableSource trait: Need to use LogicalTableSource or create proper table sources -3. ScalarValue::Null needs Option parameter -4. SortExpr type issues with DataFusion's Expr system + #[test] + fn test_query_analysis_single_hop() { + // Test that analysis correctly identifies relationship instances + let scan_a = LogicalOperator::ScanByLabel { + variable: "a".into(), + label: "Person".into(), + properties: Default::default(), + }; + let expand = LogicalOperator::Expand { + input: Box::new(scan_a), + source_variable: "a".into(), + target_variable: "b".into(), + relationship_types: vec!["KNOWS".into()], + direction: crate::ast::RelationshipDirection::Outgoing, + relationship_variable: None, + properties: Default::default(), + }; + + let cfg = crate::config::GraphConfig::builder() + .with_node_label("Person", "id") + .with_relationship("KNOWS", "src_id", "dst_id") + .build() + .unwrap(); + let planner = DataFusionPlanner::new(cfg); + let analysis = planner.analyze(&expand).unwrap(); + + // Should have two variable mappings: a and b both map to Person + assert_eq!(analysis.var_to_label.len(), 2); + assert_eq!(analysis.var_to_label.get("a"), Some(&"Person".to_string())); + assert_eq!(analysis.var_to_label.get("b"), Some(&"Person".to_string())); + + // Should have one relationship instance + assert_eq!(analysis.relationship_instances.len(), 1); + assert_eq!(analysis.relationship_instances[0].rel_type, "KNOWS"); + assert_eq!(analysis.relationship_instances[0].alias, "knows_1"); + assert_eq!(analysis.relationship_instances[0].id, 1); + } -Reference implementation should be here when these issues are resolved. -*/ + #[test] + fn test_query_analysis_two_hop() { + // Test that two-hop queries get unique relationship instances + let scan_a = LogicalOperator::ScanByLabel { + variable: "a".into(), + label: "Person".into(), + properties: Default::default(), + }; + let expand1 = LogicalOperator::Expand { + input: Box::new(scan_a), + source_variable: "a".into(), + target_variable: "b".into(), + relationship_types: vec!["KNOWS".into()], + direction: crate::ast::RelationshipDirection::Outgoing, + relationship_variable: None, + properties: Default::default(), + }; + let expand2 = LogicalOperator::Expand { + input: Box::new(expand1), + source_variable: "b".into(), + target_variable: "c".into(), + relationship_types: vec!["KNOWS".into()], + direction: crate::ast::RelationshipDirection::Outgoing, + relationship_variable: None, + properties: Default::default(), + }; + + let cfg = crate::config::GraphConfig::builder() + .with_node_label("Person", "id") + .with_relationship("KNOWS", "src_id", "dst_id") + .build() + .unwrap(); + let planner = DataFusionPlanner::new(cfg); + let analysis = planner.analyze(&expand2).unwrap(); + + // Should have two relationship instances with UNIQUE aliases + assert_eq!(analysis.relationship_instances.len(), 2); + assert_eq!(analysis.relationship_instances[0].alias, "knows_1"); + assert_eq!(analysis.relationship_instances[1].alias, "knows_2"); + + // Both should be KNOWS but with different IDs + assert_eq!(analysis.relationship_instances[0].rel_type, "KNOWS"); + assert_eq!(analysis.relationship_instances[1].rel_type, "KNOWS"); + assert_eq!(analysis.relationship_instances[0].id, 1); + assert_eq!(analysis.relationship_instances[1].id, 2); + } + + #[test] + fn test_planning_context_tracks_instances() { + // Test that PlanningContext correctly iterates through instances + let instances = vec![ + RelationshipInstance { + id: 1, + rel_type: "KNOWS".to_string(), + source_var: "a".to_string(), + target_var: "b".to_string(), + direction: crate::ast::RelationshipDirection::Outgoing, + alias: "knows_1".to_string(), + }, + RelationshipInstance { + id: 2, + rel_type: "KNOWS".to_string(), + source_var: "b".to_string(), + target_var: "c".to_string(), + direction: crate::ast::RelationshipDirection::Outgoing, + alias: "knows_2".to_string(), + }, + ]; + + let analysis = QueryAnalysis { + var_to_label: Default::default(), + relationship_instances: instances, + required_datasets: Default::default(), + }; + + let mut ctx = PlanningContext::new(&analysis); + + // First call should return knows_1 + let inst1 = ctx.next_relationship_instance("KNOWS").unwrap(); + assert_eq!(inst1.alias, "knows_1"); + assert_eq!(inst1.id, 1); + + // Second call should return knows_2 + let inst2 = ctx.next_relationship_instance("KNOWS").unwrap(); + assert_eq!(inst2.alias, "knows_2"); + assert_eq!(inst2.id, 2); + + // Third call should fail (no more instances) + assert!(ctx.next_relationship_instance("KNOWS").is_err()); + } +} diff --git a/rust/lance-graph/tests/integration_datafusion_pipeline.rs b/rust/lance-graph/tests/integration_datafusion_pipeline.rs index 4b3644a9..203969c9 100644 --- a/rust/lance-graph/tests/integration_datafusion_pipeline.rs +++ b/rust/lance-graph/tests/integration_datafusion_pipeline.rs @@ -5,6 +5,53 @@ use lance_graph::query::CypherQuery; use std::collections::HashMap; use std::sync::Arc; +// ============================================================================ +// Test Data Structure +// ============================================================================ +// +// Person Dataset (5 nodes): +// | ID | Name | Age | City | +// |----|---------|-----|---------------| +// | 1 | Alice | 25 | New York | +// | 2 | Bob | 35 | San Francisco | +// | 3 | Charlie | 30 | Chicago | +// | 4 | David | 40 | NULL | +// | 5 | Eve | 28 | Seattle | +// +// KNOWS Relationship Dataset (5 edges): +// | src_person_id | dst_person_id | since_year | +// |---------------|---------------|------------| +// | 1 | 2 | 2020 | +// | 2 | 3 | 2019 | +// | 3 | 4 | 2021 | +// | 4 | 5 | NULL | +// | 1 | 3 | 2018 | +// +// Visual Graph Structure: +// +// Alice(1) ──2020──> Bob(2) ──2019──> Charlie(3) ──2021──> David(4) ──NULL──> Eve(5) +// │ ▲ +// └──────────────2018──────────────────┘ +// +// Single-hop paths (5 edges): +// 1. Alice → Bob +// 2. Bob → Charlie +// 3. Charlie → David +// 4. David → Eve +// 5. Alice → Charlie (shortcut) +// +// Two-hop paths (4 paths): +// 1. Alice → Bob → Charlie +// 2. Bob → Charlie → David +// 3. Charlie → David → Eve +// 4. Alice → Charlie → David +// +// Key characteristics: +// - Eve (5): Has no outgoing edges (dead end) +// - Alice (1): Has 2 outgoing edges (most connections) +// - David (4): Has NULL since_year and NULL city values +// ============================================================================ + /// Helper function to create a Person dataset fn create_person_dataset() -> RecordBatch { let schema = Arc::new(Schema::new(vec![ @@ -608,3 +655,301 @@ async fn test_datafusion_one_hop_filtered_source_age_strict() { .collect(); assert_eq!(set, expected); } + +#[tokio::test] +async fn test_datafusion_two_hop_basic() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + let knows_batch = create_knows_dataset(); + + // Query: Find friends of friends + // Edges: 1->2, 2->3, 3->4, 4->5, 1->3 + // Two-hop paths: 1->2->3, 2->3->4, 3->4->5, 1->3->4 + let query = CypherQuery::new( + "MATCH (a:Person)-[:KNOWS]->(b:Person)-[:KNOWS]->(c:Person) RETURN c.name", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + datasets.insert("KNOWS".to_string(), knows_batch); + + let out = query.execute_datafusion(datasets).await.unwrap(); + + // Should return: Charlie (from 1->2->3), David (from 2->3->4 and 1->3->4), Eve (from 3->4->5) + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), 4); // 4 two-hop paths + + let names = out + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + let mut counts = HashMap::::new(); + for i in 0..out.num_rows() { + *counts.entry(names.value(i).to_string()).or_insert(0) += 1; + } + + // Verify counts: Charlie:1, David:2, Eve:1 + assert_eq!(counts.get("Charlie"), Some(&1)); + assert_eq!(counts.get("David"), Some(&2)); + assert_eq!(counts.get("Eve"), Some(&1)); + assert!(!counts.contains_key("Alice")); + assert!(!counts.contains_key("Bob")); +} + +#[tokio::test] +async fn test_datafusion_two_hop_return_intermediate() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + let knows_batch = create_knows_dataset(); + + // Query: Return the intermediate node in two-hop paths + let query = CypherQuery::new( + "MATCH (a:Person)-[:KNOWS]->(b:Person)-[:KNOWS]->(c:Person) RETURN b.name", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + datasets.insert("KNOWS".to_string(), knows_batch); + + let out = query.execute_datafusion(datasets).await.unwrap(); + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), 4); + + let names = out + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + let mut counts = HashMap::::new(); + for i in 0..out.num_rows() { + *counts.entry(names.value(i).to_string()).or_insert(0) += 1; + } + + // Intermediate nodes: Bob (1->2->3), Charlie (2->3->4 and 1->3->4), David (3->4->5) + assert_eq!(counts.get("Bob"), Some(&1)); + assert_eq!(counts.get("Charlie"), Some(&2)); + assert_eq!(counts.get("David"), Some(&1)); +} + +#[tokio::test] +async fn test_datafusion_two_hop_return_all_three() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + let knows_batch = create_knows_dataset(); + + // Query: Return all three nodes in the path + let query = CypherQuery::new( + "MATCH (a:Person)-[:KNOWS]->(b:Person)-[:KNOWS]->(c:Person) RETURN a.name, b.name, c.name", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + datasets.insert("KNOWS".to_string(), knows_batch); + + let out = query.execute_datafusion(datasets).await.unwrap(); + assert_eq!(out.num_columns(), 3); + assert_eq!(out.num_rows(), 4); + + let a_names = out + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let b_names = out + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let c_names = out + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + // Collect all paths + let mut paths = Vec::new(); + for i in 0..out.num_rows() { + paths.push(( + a_names.value(i).to_string(), + b_names.value(i).to_string(), + c_names.value(i).to_string(), + )); + } + + // Expected paths: Alice->Bob->Charlie, Bob->Charlie->David, Charlie->David->Eve, Alice->Charlie->David + assert!(paths.contains(&( + "Alice".to_string(), + "Bob".to_string(), + "Charlie".to_string() + ))); + assert!(paths.contains(&( + "Bob".to_string(), + "Charlie".to_string(), + "David".to_string() + ))); + assert!(paths.contains(&( + "Charlie".to_string(), + "David".to_string(), + "Eve".to_string() + ))); + assert!(paths.contains(&( + "Alice".to_string(), + "Charlie".to_string(), + "David".to_string() + ))); +} + +#[tokio::test] +async fn test_datafusion_two_hop_with_filter() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + let knows_batch = create_knows_dataset(); + + // Query: Two-hop with filter on intermediate node + let query = CypherQuery::new( + "MATCH (a:Person)-[:KNOWS]->(b:Person)-[:KNOWS]->(c:Person) WHERE b.age > 30 RETURN c.name", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + datasets.insert("KNOWS".to_string(), knows_batch); + + let out = query.execute_datafusion(datasets).await.unwrap(); + + // Filter: b.age > 30 means b can be Bob(35), David(40) + // Paths with Bob as intermediate: 1->2->3 (Alice->Bob->Charlie) + // Paths with David as intermediate: 3->4->5 (Charlie->David->Eve) + // No paths with Charlie(30) as intermediate + assert_eq!(out.num_rows(), 2); + + let names = out + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + let result_names: Vec = (0..out.num_rows()) + .map(|i| names.value(i).to_string()) + .collect(); + + assert!(result_names.contains(&"Charlie".to_string())); + assert!(result_names.contains(&"Eve".to_string())); +} + +#[tokio::test] +async fn test_datafusion_two_hop_with_relationship_variable() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + let knows_batch = create_knows_dataset(); + + // Query: Two-hop with relationship variables + let query = CypherQuery::new( + "MATCH (a:Person)-[r1:KNOWS]->(b:Person)-[r2:KNOWS]->(c:Person) RETURN a.name, c.name", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + datasets.insert("KNOWS".to_string(), knows_batch); + + let out = query.execute_datafusion(datasets).await.unwrap(); + assert_eq!(out.num_columns(), 2); + assert_eq!(out.num_rows(), 4); + + let a_names = out + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let c_names = out + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + // Verify we get the correct source->target pairs + let mut pairs = Vec::new(); + for i in 0..out.num_rows() { + pairs.push((a_names.value(i).to_string(), c_names.value(i).to_string())); + } + + assert!(pairs.contains(&("Alice".to_string(), "Charlie".to_string()))); + assert!(pairs.contains(&("Bob".to_string(), "David".to_string()))); + assert!(pairs.contains(&("Charlie".to_string(), "Eve".to_string()))); + assert!(pairs.contains(&("Alice".to_string(), "David".to_string()))); +} + +#[tokio::test] +async fn test_datafusion_two_hop_distinct() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + let knows_batch = create_knows_dataset(); + + // Query: Get distinct final destinations in two-hop paths + let query = CypherQuery::new( + "MATCH (a:Person)-[:KNOWS]->(b:Person)-[:KNOWS]->(c:Person) RETURN DISTINCT c.name", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + datasets.insert("KNOWS".to_string(), knows_batch); + + let out = query.execute_datafusion(datasets).await.unwrap(); + + // Distinct destinations: Charlie, David, Eve + assert_eq!(out.num_rows(), 3); + + let names = out + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + let result_set: std::collections::HashSet = (0..out.num_rows()) + .map(|i| names.value(i).to_string()) + .collect(); + + let expected: std::collections::HashSet = ["Charlie", "David", "Eve"] + .into_iter() + .map(|s| s.to_string()) + .collect(); + + assert_eq!(result_set, expected); +} + +#[tokio::test] +async fn test_datafusion_two_hop_no_results() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + let knows_batch = create_knows_dataset(); + + // Query: Two-hop starting from Eve (who has no outgoing edges) + let query = CypherQuery::new( + "MATCH (a:Person)-[:KNOWS]->(b:Person)-[:KNOWS]->(c:Person) WHERE a.name = 'Eve' RETURN c.name" + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + datasets.insert("KNOWS".to_string(), knows_batch); + + let out = query.execute_datafusion(datasets).await.unwrap(); + + // Eve has no outgoing edges, so no two-hop paths + assert_eq!(out.num_rows(), 0); +}