Skip to content

Commit 7625f3d

Browse files
authored
feat: implement relational variable join in the datafusion planner (#41)
1 parent 5731adf commit 7625f3d

3 files changed

Lines changed: 383 additions & 22 deletions

File tree

rust/lance-graph/src/datafusion_planner/builder.rs

Lines changed: 153 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -686,16 +686,27 @@ impl DataFusionPlanner {
686686
/// generates join keys based on the id fields of those shared variables.
687687
///
688688
/// Supports both node variables and relationship variables:
689-
/// - Node variables: Join on node ID field (e.g., `b__id`)
690-
/// - Relationship variables: Currently unsupported - returns empty keys
689+
/// - **Node variables**: Join on node ID field (e.g., `b__id`)
690+
/// - **Relationship variables**: Join on composite keys (src_id AND dst_id)
691691
///
692-
/// # Example
692+
/// # Examples
693+
///
694+
/// **Node variable join:**
693695
/// ```text
694-
/// Left pattern: (a:Person)-[:KNOWS]->(b:Person) -> variables: [a, b]
695-
/// Right pattern: (b:Person)-[:WORKS_AT]->(c:Company) -> variables: [b, c]
696+
/// Left: (a:Person)-[:KNOWS]->(b:Person) -> variables: [a, b]
697+
/// Right: (b:Person)-[:WORKS_AT]->(c:Company) -> variables: [b, c]
696698
/// Shared: [b]
697699
/// Result: (left_keys=["b__id"], right_keys=["b__id"])
698700
/// ```
701+
///
702+
/// **Relationship variable join:**
703+
/// ```text
704+
/// Left: (a:Person)-[r:KNOWS]->(b:Person) -> variables: [a, b, r]
705+
/// Right: (c:Person)-[r:KNOWS]->(d:Person) -> variables: [c, d, r]
706+
/// Shared: [r]
707+
/// Result: (left_keys=["r__src_id", "r__dst_id"],
708+
/// right_keys=["r__src_id", "r__dst_id"])
709+
/// ```
699710
fn infer_join_keys(
700711
&self,
701712
ctx: &PlanningContext,
@@ -735,24 +746,39 @@ impl DataFusionPlanner {
735746
left_keys.push(left_key);
736747
right_keys.push(right_key);
737748
}
749+
} else {
750+
// Not a node variable - check if it's a relationship variable
751+
// Look up the relationship instance by its alias (the variable name)
752+
if let Some(rel_instance) = ctx
753+
.analysis
754+
.relationship_instances
755+
.iter()
756+
.find(|r| r.alias == *var)
757+
{
758+
// Get the relationship mapping to find src/dst field names
759+
if let Some(rel_map) = self
760+
.config
761+
.relationship_mappings
762+
.get(&rel_instance.rel_type)
763+
{
764+
// Generate composite join keys for both src_id and dst_id
765+
// This ensures we're matching the exact same relationship instance
766+
// The columns are qualified as: {alias}__{original_field_name}
767+
// Example: var="r", source_id_field="src_person_id"
768+
// -> "r__src_person_id"
769+
let left_src = format!("{}__{}", var, &rel_map.source_id_field);
770+
let right_src = format!("{}__{}", var, &rel_map.source_id_field);
771+
let left_dst = format!("{}__{}", var, &rel_map.target_id_field);
772+
let right_dst = format!("{}__{}", var, &rel_map.target_id_field);
773+
774+
left_keys.push(left_src);
775+
right_keys.push(right_src);
776+
left_keys.push(left_dst);
777+
right_keys.push(right_dst);
778+
}
779+
}
780+
// If not found in either node or relationship variables, skip it
738781
}
739-
// If not a node variable, it might be a relationship variable
740-
// TODO: Implement relationship variable join key generation
741-
//
742-
// For now, we skip relationship variables (they won't generate keys).
743-
// This means patterns with only shared relationship variables will fall back
744-
// to cross join (or error for outer joins).
745-
//
746-
// To implement this:
747-
// 1. Look up the relationship instance in ctx.analysis.relationship_instances
748-
// using the variable name as the key
749-
// 2. Get the relationship mapping from self.config.relationship_mappings
750-
// using the relationship type
751-
// 3. Generate join keys based on a unique relationship ID column
752-
// (may need to add an ID field to RelationshipMapping if not present)
753-
// 4. Consider how to handle the fact that relationships are represented as
754-
// joins in the physical plan - you may need to join on both src_id and dst_id
755-
// to ensure the same relationship instance is matched
756782
}
757783

758784
(left_keys, right_keys)
@@ -2246,4 +2272,109 @@ mod tests {
22462272
"Shared variables should include 'r'"
22472273
);
22482274
}
2275+
2276+
#[test]
2277+
fn test_relationship_variable_join_key_inference() {
2278+
// Test join key inference for two patterns that share node 'b' but use distinct relationship variables (r1, r2)
2279+
//
2280+
// Note: This tests the key generation logic, not the full plan execution.
2281+
// In practice, joining on shared relationship variables across disconnected patterns
2282+
// doesn't make semantic sense in Cypher (a relationship can't have two sources).
2283+
//
2284+
// The implementation is expected to do the following (item 2 is not asserted here, since r1 and r2 are not shared):
2285+
// 1. Detects relationship variables in both patterns
2286+
// 2. Generates composite keys (src_id + dst_id) for relationship variables
2287+
// 3. Generates single keys for node variables
2288+
use crate::datafusion_planner::analysis;
2289+
use crate::logical_plan::LogicalOperator;
2290+
2291+
let cfg = crate::config::GraphConfig::builder()
2292+
.with_node_label("Person", "id")
2293+
.with_relationship("KNOWS", "src_person_id", "dst_person_id")
2294+
.build()
2295+
.unwrap();
2296+
let planner = DataFusionPlanner::with_catalog(cfg, make_catalog());
2297+
2298+
// Left: (a:Person)-[r1:KNOWS]->(b:Person)
2299+
let scan_a = LogicalOperator::ScanByLabel {
2300+
variable: "a".to_string(),
2301+
label: "Person".to_string(),
2302+
properties: Default::default(),
2303+
};
2304+
let expand_left = LogicalOperator::Expand {
2305+
input: Box::new(scan_a),
2306+
source_variable: "a".to_string(),
2307+
target_variable: "b".to_string(),
2308+
target_label: "Person".to_string(),
2309+
relationship_types: vec!["KNOWS".to_string()],
2310+
direction: crate::ast::RelationshipDirection::Outgoing,
2311+
relationship_variable: Some("r1".to_string()),
2312+
properties: Default::default(),
2313+
target_properties: Default::default(),
2314+
};
2315+
2316+
// Right: (b:Person)-[r2:KNOWS]->(c:Person) - shares node 'b'
2317+
let scan_b = LogicalOperator::ScanByLabel {
2318+
variable: "b".to_string(),
2319+
label: "Person".to_string(),
2320+
properties: Default::default(),
2321+
};
2322+
let expand_right = LogicalOperator::Expand {
2323+
input: Box::new(scan_b),
2324+
source_variable: "b".to_string(),
2325+
target_variable: "c".to_string(),
2326+
target_label: "Person".to_string(),
2327+
relationship_types: vec!["KNOWS".to_string()],
2328+
direction: crate::ast::RelationshipDirection::Outgoing,
2329+
relationship_variable: Some("r2".to_string()),
2330+
properties: Default::default(),
2331+
target_properties: Default::default(),
2332+
};
2333+
2334+
// Analyze both patterns to build the context
2335+
let left_analysis = analysis::analyze(&expand_left).unwrap();
2336+
let left_ctx = analysis::PlanningContext::new(&left_analysis);
2337+
2338+
// Test the key inference logic directly
2339+
let (left_keys, right_keys) =
2340+
planner.infer_join_keys(&left_ctx, &expand_left, &expand_right);
2341+
2342+
// Should generate join keys for shared node variable 'b'
2343+
assert!(
2344+
!left_keys.is_empty(),
2345+
"Should generate join keys for shared node 'b'"
2346+
);
2347+
assert_eq!(
2348+
left_keys.len(),
2349+
right_keys.len(),
2350+
"Left and right keys should match"
2351+
);
2352+
2353+
// Should contain b__id (the shared node)
2354+
assert!(
2355+
left_keys.iter().any(|k| k.contains("b__id")),
2356+
"Should have join key for shared node 'b': {:?}",
2357+
left_keys
2358+
);
2359+
2360+
// Verify that relationship variables r1 and r2 are collected
2361+
let left_vars = planner.extract_variables(&expand_left);
2362+
let right_vars = planner.extract_variables(&expand_right);
2363+
2364+
assert!(left_vars.contains(&"r1".to_string()), "Left should have r1");
2365+
assert!(
2366+
right_vars.contains(&"r2".to_string()),
2367+
"Right should have r2"
2368+
);
2369+
2370+
// r1 and r2 are different, so they shouldn't be in shared variables
2371+
let shared: Vec<String> = left_vars
2372+
.iter()
2373+
.filter(|v| right_vars.contains(v))
2374+
.cloned()
2375+
.collect();
2376+
assert!(!shared.contains(&"r1".to_string()), "r1 is not shared");
2377+
assert!(!shared.contains(&"r2".to_string()), "r2 is not shared");
2378+
assert!(shared.contains(&"b".to_string()), "b is shared");
2379+
}
22492380
}

rust/lance-graph/tests/test_datafusion_pipeline.rs

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2912,3 +2912,160 @@ async fn test_datafusion_disconnected_with_distinct() {
29122912
.collect();
29132913
assert_eq!(name_set, expected);
29142914
}
2915+
2916+
#[tokio::test]
2917+
async fn test_datafusion_shared_node_variable_join() {
2918+
// This should join on shared variable 'b' using b.id
2919+
let config = create_graph_config();
2920+
let person_batch = create_person_dataset();
2921+
let knows_batch = create_knows_dataset();
2922+
2923+
let query = CypherQuery::new(
2924+
"MATCH (a:Person)-[:KNOWS]->(b:Person), (b)-[:KNOWS]->(c:Person) \
2925+
RETURN a.name, b.name, c.name ORDER BY a.name, c.name",
2926+
)
2927+
.unwrap()
2928+
.with_config(config);
2929+
2930+
let mut datasets = HashMap::new();
2931+
datasets.insert("Person".to_string(), person_batch);
2932+
datasets.insert("KNOWS".to_string(), knows_batch);
2933+
2934+
let result = query.execute_datafusion(datasets).await.unwrap();
2935+
2936+
// This is a two-hop path query that should use join key inference on 'b'
2937+
// Alice(1) -> Bob(2) -> Charlie(3)
2938+
// Alice(1) -> Bob(2) -> David(4)
2939+
assert!(
2940+
result.num_rows() >= 2,
2941+
"Should have at least 2 two-hop paths"
2942+
);
2943+
2944+
let a_names = result
2945+
.column(0)
2946+
.as_any()
2947+
.downcast_ref::<StringArray>()
2948+
.unwrap();
2949+
let b_names = result
2950+
.column(1)
2951+
.as_any()
2952+
.downcast_ref::<StringArray>()
2953+
.unwrap();
2954+
let c_names = result
2955+
.column(2)
2956+
.as_any()
2957+
.downcast_ref::<StringArray>()
2958+
.unwrap();
2959+
2960+
// Verify at least one path: Alice -> Bob -> Charlie
2961+
let mut found_path = false;
2962+
for i in 0..result.num_rows() {
2963+
if a_names.value(i) == "Alice" && b_names.value(i) == "Bob" && c_names.value(i) == "Charlie"
2964+
{
2965+
found_path = true;
2966+
break;
2967+
}
2968+
}
2969+
assert!(found_path, "Should find path: Alice -> Bob -> Charlie");
2970+
}
2971+
2972+
#[tokio::test]
2973+
async fn test_datafusion_shared_variable_with_filter() {
2974+
let config = create_graph_config();
2975+
let person_batch = create_person_dataset();
2976+
let knows_batch = create_knows_dataset();
2977+
2978+
let query = CypherQuery::new(
2979+
"MATCH (a:Person)-[:KNOWS]->(b:Person), (b)-[:KNOWS]->(c:Person) \
2980+
WHERE a.age > 20 AND c.age < 40 \
2981+
RETURN a.name, b.name, c.name",
2982+
)
2983+
.unwrap()
2984+
.with_config(config);
2985+
2986+
let mut datasets = HashMap::new();
2987+
datasets.insert("Person".to_string(), person_batch);
2988+
datasets.insert("KNOWS".to_string(), knows_batch);
2989+
2990+
let result = query.execute_datafusion(datasets).await.unwrap();
2991+
2992+
// Should successfully execute with join key inference + filters
2993+
assert!(result.num_rows() > 0, "Should have results with filters");
2994+
2995+
// Verify all results satisfy the age constraints
2996+
// All 'a' nodes should have age > 20 (excludes no one in our dataset)
2997+
// All 'c' nodes should have age < 40 (excludes David who is 40)
2998+
for i in 0..result.num_rows() {
2999+
let c_name = result
3000+
.column(2)
3001+
.as_any()
3002+
.downcast_ref::<StringArray>()
3003+
.unwrap()
3004+
.value(i);
3005+
assert_ne!(c_name, "David", "David (age 40) should be filtered out");
3006+
}
3007+
}
3008+
3009+
#[tokio::test]
3010+
async fn test_datafusion_multiple_shared_variables() {
3011+
let config = create_graph_config();
3012+
let person_batch = create_person_dataset();
3013+
let knows_batch = create_knows_dataset();
3014+
3015+
let query = CypherQuery::new(
3016+
"MATCH (a:Person)-[:KNOWS]->(b:Person), (b)-[:KNOWS]->(c:Person), (c)-[:KNOWS]->(d:Person) \
3017+
RETURN a.name, b.name, c.name, d.name",
3018+
)
3019+
.unwrap()
3020+
.with_config(config);
3021+
3022+
let mut datasets = HashMap::new();
3023+
datasets.insert("Person".to_string(), person_batch);
3024+
datasets.insert("KNOWS".to_string(), knows_batch);
3025+
3026+
let result = query.execute_datafusion(datasets).await.unwrap();
3027+
3028+
// This is a three-hop path query using join key inference on 'b' and 'c'
3029+
// Should successfully execute (may have 0 or more results depending on data)
3030+
assert_eq!(result.num_columns(), 4);
3031+
}
3032+
3033+
#[tokio::test]
3034+
async fn test_datafusion_shared_variable_distinct() {
3035+
let config = create_graph_config();
3036+
let person_batch = create_person_dataset();
3037+
let knows_batch = create_knows_dataset();
3038+
3039+
let query = CypherQuery::new(
3040+
"MATCH (a:Person)-[:KNOWS]->(b:Person), (b)-[:KNOWS]->(c:Person) \
3041+
RETURN DISTINCT b.name ORDER BY b.name",
3042+
)
3043+
.unwrap()
3044+
.with_config(config);
3045+
3046+
let mut datasets = HashMap::new();
3047+
datasets.insert("Person".to_string(), person_batch);
3048+
datasets.insert("KNOWS".to_string(), knows_batch);
3049+
3050+
let result = query.execute_datafusion(datasets).await.unwrap();
3051+
3052+
// Should return distinct intermediate nodes that have both incoming and outgoing KNOWS edges
3053+
assert!(result.num_rows() > 0, "Should have intermediate nodes");
3054+
assert_eq!(result.num_columns(), 1);
3055+
3056+
let names = result
3057+
.column(0)
3058+
.as_any()
3059+
.downcast_ref::<StringArray>()
3060+
.unwrap();
3061+
3062+
// Verify no duplicates
3063+
let name_set: std::collections::HashSet<String> = (0..result.num_rows())
3064+
.map(|i| names.value(i).to_string())
3065+
.collect();
3066+
assert_eq!(
3067+
name_set.len(),
3068+
result.num_rows(),
3069+
"DISTINCT should eliminate duplicates"
3070+
);
3071+
}

0 commit comments

Comments
 (0)