|
23 | 23 | order_by, |
24 | 24 | return_, |
25 | 25 | rows, |
| 26 | + semi_apply_mark, |
26 | 27 | serialize_binding_ops, |
27 | 28 | select, |
28 | 29 | skip, |
@@ -5993,9 +5994,12 @@ def _is_variable_length_relationship_pattern(relationship: RelationshipPattern) |
5993 | 5994 | def _reject_unsupported_variable_length_where_pattern_predicates(query: CypherQuery) -> None: |
5994 | 5995 | if query.where is None: |
5995 | 5996 | return |
5996 | | - for predicate in query.where.predicates: |
5997 | | - if not isinstance(predicate, WherePatternPredicate): |
5998 | | - continue |
| 5997 | + predicates: List[WherePatternPredicate] = [ |
| 5998 | + predicate for predicate in query.where.predicates if isinstance(predicate, WherePatternPredicate) |
| 5999 | + ] |
| 6000 | + if query.where.expr_tree is not None: |
| 6001 | + predicates.extend(_where_expr_tree_pattern_predicates(query.where.expr_tree)) |
| 6002 | + for predicate in predicates: |
5999 | 6003 | relationships = [ |
6000 | 6004 | element |
6001 | 6005 | for element in predicate.pattern |
@@ -6356,6 +6360,155 @@ def _predicate_pattern_aliases(predicate: WherePatternPredicate) -> List[str]: |
6356 | 6360 | return aliases |
6357 | 6361 |
|
6358 | 6362 |
|
| 6363 | +def _where_expr_tree_pattern_predicates(expr: BooleanExpr) -> List[WherePatternPredicate]: |
| 6364 | + out: List[WherePatternPredicate] = [] |
| 6365 | + stack: List[BooleanExpr] = [expr] |
| 6366 | + while stack: |
| 6367 | + cur = stack.pop() |
| 6368 | + if cur.op == "pattern": |
| 6369 | + if cur.pattern is None: |
| 6370 | + raise _unsupported( |
| 6371 | + "Cypher WHERE pattern predicates must include a relationship", |
| 6372 | + field="where", |
| 6373 | + value=cur.atom_text, |
| 6374 | + line=cur.span.line, |
| 6375 | + column=cur.span.column, |
| 6376 | + ) |
| 6377 | + out.append(WherePatternPredicate(pattern=cur.pattern, span=cur.span, negated=False)) |
| 6378 | + continue |
| 6379 | + if cur.left is not None: |
| 6380 | + stack.append(cur.left) |
| 6381 | + if cur.right is not None: |
| 6382 | + stack.append(cur.right) |
| 6383 | + return out |
| 6384 | + |
| 6385 | + |
| 6386 | +def _lower_pattern_predicate_to_row_marker( |
| 6387 | + predicate: WherePatternPredicate, |
| 6388 | + *, |
| 6389 | + alias_targets: Mapping[str, ASTObject], |
| 6390 | + params: Optional[Mapping[str, Any]], |
| 6391 | + out_col: str, |
| 6392 | +) -> ASTCall: |
| 6393 | + if len(predicate.pattern) < 3: |
| 6394 | + raise _unsupported( |
| 6395 | + "Cypher WHERE pattern predicates must include a relationship", |
| 6396 | + field="where", |
| 6397 | + value=None, |
| 6398 | + line=predicate.span.line, |
| 6399 | + column=predicate.span.column, |
| 6400 | + ) |
| 6401 | + |
| 6402 | + predicate_aliases = _predicate_pattern_aliases(predicate) |
| 6403 | + if not predicate_aliases: |
| 6404 | + raise _unsupported( |
| 6405 | + "Cypher WHERE pattern predicates currently require at least one shared bound alias", |
| 6406 | + field="where", |
| 6407 | + value=None, |
| 6408 | + line=predicate.span.line, |
| 6409 | + column=predicate.span.column, |
| 6410 | + ) |
| 6411 | + |
| 6412 | + introduced_aliases = sorted(alias for alias in predicate_aliases if alias not in alias_targets) |
| 6413 | + if introduced_aliases: |
| 6414 | + raise _unsupported( |
| 6415 | + "Cypher WHERE pattern predicates cannot introduce new aliases in this phase", |
| 6416 | + field="where", |
| 6417 | + value=introduced_aliases, |
| 6418 | + line=predicate.span.line, |
| 6419 | + column=predicate.span.column, |
| 6420 | + ) |
| 6421 | + |
| 6422 | + shared_aliases = [alias for alias in predicate_aliases if alias in alias_targets] |
| 6423 | + if not shared_aliases: |
| 6424 | + raise _unsupported( |
| 6425 | + "Cypher WHERE pattern predicates currently require at least one shared bound alias", |
| 6426 | + field="where", |
| 6427 | + value=predicate_aliases, |
| 6428 | + line=predicate.span.line, |
| 6429 | + column=predicate.span.column, |
| 6430 | + ) |
| 6431 | + |
| 6432 | + pattern_clause = MatchClause( |
| 6433 | + patterns=(predicate.pattern,), |
| 6434 | + span=predicate.span, |
| 6435 | + optional=False, |
| 6436 | + pattern_aliases=(None,), |
| 6437 | + pattern_alias_kinds=("pattern",), |
| 6438 | + ) |
| 6439 | + pattern_ops = lower_match_clause(pattern_clause, params=params) |
| 6440 | + return semi_apply_mark( |
| 6441 | + binding_ops=serialize_binding_ops(pattern_ops), |
| 6442 | + join_aliases=shared_aliases, |
| 6443 | + out_col=out_col, |
| 6444 | + ) |
| 6445 | + |
| 6446 | + |
| 6447 | +def _rewrite_where_expr_patterns_to_markers( |
| 6448 | + *, |
| 6449 | + where: WhereClause, |
| 6450 | + alias_targets: Mapping[str, ASTObject], |
| 6451 | + params: Optional[Mapping[str, Any]], |
| 6452 | +) -> Tuple[Optional[ExpressionText], List[ASTCall]]: |
| 6453 | + if where.expr_tree is None: |
| 6454 | + return None, [] |
| 6455 | + |
| 6456 | + pattern_preds = _where_expr_tree_pattern_predicates(where.expr_tree) |
| 6457 | + if not pattern_preds: |
| 6458 | + return _where_clause_expr_text(where), [] |
| 6459 | + |
| 6460 | + marker_ops: List[ASTCall] = [] |
| 6461 | + marker_counter = 0 |
| 6462 | + |
| 6463 | + def _fresh_marker_col(span: SourceSpan) -> str: |
| 6464 | + nonlocal marker_counter |
| 6465 | + marker_counter += 1 |
| 6466 | + return ( |
| 6467 | + "__gfql_where_pattern_" |
| 6468 | + f"{span.line}_{span.column}_{span.end_line}_{span.end_column}_{marker_counter}__" |
| 6469 | + ) |
| 6470 | + |
| 6471 | + def _rewrite(expr: BooleanExpr) -> BooleanExpr: |
| 6472 | + if expr.op == "pattern": |
| 6473 | + if expr.pattern is None: |
| 6474 | + raise _unsupported( |
| 6475 | + "Cypher WHERE pattern predicates must include a relationship", |
| 6476 | + field="where", |
| 6477 | + value=expr.atom_text, |
| 6478 | + line=expr.span.line, |
| 6479 | + column=expr.span.column, |
| 6480 | + ) |
| 6481 | + marker_col = _fresh_marker_col(expr.span) |
| 6482 | + marker_ops.append( |
| 6483 | + _lower_pattern_predicate_to_row_marker( |
| 6484 | + WherePatternPredicate(pattern=expr.pattern, span=expr.span, negated=False), |
| 6485 | + alias_targets=alias_targets, |
| 6486 | + params=params, |
| 6487 | + out_col=marker_col, |
| 6488 | + ) |
| 6489 | + ) |
| 6490 | + return BooleanExpr( |
| 6491 | + op="atom", |
| 6492 | + span=expr.span, |
| 6493 | + atom_text=marker_col, |
| 6494 | + atom_span=expr.atom_span or expr.span, |
| 6495 | + ) |
| 6496 | + if expr.op in {"atom"}: |
| 6497 | + return expr |
| 6498 | + if expr.op == "not": |
| 6499 | + return replace(expr, left=_rewrite(cast(BooleanExpr, expr.left))) |
| 6500 | + if expr.op in {"and", "or", "xor"}: |
| 6501 | + return replace( |
| 6502 | + expr, |
| 6503 | + left=_rewrite(cast(BooleanExpr, expr.left)), |
| 6504 | + right=_rewrite(cast(BooleanExpr, expr.right)), |
| 6505 | + ) |
| 6506 | + return expr |
| 6507 | + |
| 6508 | + rewritten = _rewrite(where.expr_tree) |
| 6509 | + return ExpressionText(text=boolean_expr_to_text(rewritten), span=where.span), marker_ops |
| 6510 | + |
| 6511 | + |
6359 | 6512 | def _lower_negated_pattern_predicate_to_row_filter( |
6360 | 6513 | predicate: WherePatternPredicate, |
6361 | 6514 | *, |
@@ -6449,7 +6602,12 @@ def lower_match_query( |
6449 | 6602 | row_where: Optional[ExpressionText] = None |
6450 | 6603 | row_where_predicates: List[str] = list(dynamic_row_where_predicates) |
6451 | 6604 | if query.where is not None: |
6452 | | - where_expr = _where_clause_expr_text(query.where) |
| 6605 | + where_expr, where_pattern_row_filters = _rewrite_where_expr_patterns_to_markers( |
| 6606 | + where=query.where, |
| 6607 | + alias_targets=alias_targets, |
| 6608 | + params=params, |
| 6609 | + ) |
| 6610 | + row_pre_filters.extend(where_pattern_row_filters) |
6453 | 6611 | if where_expr is not None: |
6454 | 6612 | type_where = _extract_relationship_type_where( |
6455 | 6613 | where_expr, |
@@ -8270,7 +8428,12 @@ def _apply_where_to_ops( |
8270 | 8428 | row_pre_filters: List[ASTCall] = [] |
8271 | 8429 | if where is None: |
8272 | 8430 | return where_out, row_expr_filters, row_pre_filters |
8273 | | - where_expr = _where_clause_expr_text(where) |
| 8431 | + where_expr, where_pattern_row_filters = _rewrite_where_expr_patterns_to_markers( |
| 8432 | + where=where, |
| 8433 | + alias_targets=alias_targets, |
| 8434 | + params=params, |
| 8435 | + ) |
| 8436 | + row_pre_filters.extend(where_pattern_row_filters) |
8274 | 8437 | if where_expr is not None: |
8275 | 8438 | type_where = _extract_relationship_type_where( |
8276 | 8439 | where_expr, |
|
0 commit comments