Skip to content

Commit 4c80fd5

Browse files
committed
feat(cypher): support OR/XOR-around-pattern WHERE lowering (#1236)
1 parent 6891d39 commit 4c80fd5

8 files changed

Lines changed: 336 additions & 38 deletions

File tree

graphistry/compute/ast.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,6 +1764,28 @@ def anti_semi_apply(
17641764
)
17651765

17661766

1767+
def semi_apply_mark(
    *,
    binding_ops: List[Dict[str, Any]],
    join_aliases: Sequence[str],
    out_col: str,
) -> ASTCall:
    """Build the ``semi_apply_mark`` call for a correlated pattern-existence check.

    The returned call carries three parameters:

    - ``binding_ops``: the pattern to evaluate, passed through unchanged.
    - ``join_aliases``: the shared aliases used as join keys, copied into a
      fresh list.
    - ``out_col``: name of the boolean marker column; True means the pattern
      matched for that row.
    """
    call_params: Dict[str, Any] = {
        "binding_ops": binding_ops,
        "join_aliases": list(join_aliases),
        "out_col": out_col,
    }
    return ASTCall("semi_apply_mark", call_params)
1787+
1788+
17671789
def order_by(keys: Iterable[Tuple[Any, str]]) -> ASTCall:
17681790
"""Create an ORDER BY operation for GFQL row pipelines."""
17691791
return ASTCall("order_by", {"keys": list(keys)})

graphistry/compute/gfql/call/validation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,13 @@ def _group_by_requires_node_cols(params: Dict[str, object]) -> List[str]:
217217
return out
218218

219219

220+
def _semi_apply_mark_added_node_cols(params: Dict[str, object]) -> Set[str]:
221+
out_col = params.get("out_col")
222+
if isinstance(out_col, str) and out_col != "":
223+
return {out_col}
224+
return set()
225+
226+
220227
# Parser-backed helpers stay local because tests monkeypatch parser availability
221228
# and capability behavior through this module.
222229

@@ -271,6 +278,18 @@ def _group_by_requires_node_cols(params: Dict[str, object]) -> List[str]:
271278
schema_effects=NO_SCHEMA_EFFECTS,
272279
),
273280

281+
'semi_apply_mark': _method_entry(
282+
allowed_params={'binding_ops', 'join_aliases', 'out_col'},
283+
required_params={'binding_ops', 'join_aliases', 'out_col'},
284+
param_validators={
285+
'binding_ops': is_list_of_dicts,
286+
'join_aliases': is_non_empty_list_of_strings,
287+
'out_col': is_non_empty_string,
288+
},
289+
description='Annotate active rows with correlated pattern-existence booleans',
290+
schema_effects=_schema_effects(adds_node_cols=_semi_apply_mark_added_node_cols),
291+
),
292+
274293
'order_by': _method_entry(
275294
allowed_params={'keys'},
276295
required_params={'keys'},

graphistry/compute/gfql/cypher/_boolean_expr_text.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,9 @@ def boolean_expr_to_text(expr: BooleanExpr) -> str:
5757
if expr.op == "atom":
5858
return expr.atom_text or ""
5959
if expr.op == "pattern":
60-
# Unreachable today: top-level AND leaves are lifted out by
61-
# ``_split_top_level_and_pattern_leaves`` before the binder walks
62-
# the tree, and patterns nested under NOT/OR/XOR are rejected
63-
# earlier with E108 errors. Contract for the defensive branch:
64-
# emit raw pattern source for round-trippability.
60+
# Pattern leaves may remain in expr_tree for nested OR/XOR/NOT
61+
# compositions. Emit raw source text for round-trippability; lowering
62+
# can rewrite these leaves to marker columns before row evaluation.
6563
return expr.atom_text or ""
6664
if expr.op == "not":
6765
operand = boolean_expr_to_text(expr.left) if expr.left is not None else ""

graphistry/compute/gfql/cypher/lowering.py

Lines changed: 168 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
order_by,
2424
return_,
2525
rows,
26+
semi_apply_mark,
2627
serialize_binding_ops,
2728
select,
2829
skip,
@@ -5993,9 +5994,12 @@ def _is_variable_length_relationship_pattern(relationship: RelationshipPattern)
59935994
def _reject_unsupported_variable_length_where_pattern_predicates(query: CypherQuery) -> None:
59945995
if query.where is None:
59955996
return
5996-
for predicate in query.where.predicates:
5997-
if not isinstance(predicate, WherePatternPredicate):
5998-
continue
5997+
predicates: List[WherePatternPredicate] = [
5998+
predicate for predicate in query.where.predicates if isinstance(predicate, WherePatternPredicate)
5999+
]
6000+
if query.where.expr_tree is not None:
6001+
predicates.extend(_where_expr_tree_pattern_predicates(query.where.expr_tree))
6002+
for predicate in predicates:
59996003
relationships = [
60006004
element
60016005
for element in predicate.pattern
@@ -6356,6 +6360,155 @@ def _predicate_pattern_aliases(predicate: WherePatternPredicate) -> List[str]:
63566360
return aliases
63576361

63586362

6363+
def _where_expr_tree_pattern_predicates(expr: BooleanExpr) -> List[WherePatternPredicate]:
    """Collect every ``pattern`` leaf under ``expr`` as a ``WherePatternPredicate``.

    Leaves are returned in left-to-right (source) order, so a caller that
    raises on the first offending predicate reports the earliest one in the
    query text. Each predicate is built with ``negated=False`` regardless of
    any enclosing ``not`` node; the boolean structure is not reflected here.

    Raises the error built by ``_unsupported`` when a ``pattern`` leaf carries
    no parsed pattern.
    """
    out: List[WherePatternPredicate] = []
    stack: List[BooleanExpr] = [expr]
    while stack:
        cur = stack.pop()
        if cur.op == "pattern":
            if cur.pattern is None:
                raise _unsupported(
                    "Cypher WHERE pattern predicates must include a relationship",
                    field="where",
                    value=cur.atom_text,
                    line=cur.span.line,
                    column=cur.span.column,
                )
            out.append(WherePatternPredicate(pattern=cur.pattern, span=cur.span, negated=False))
            continue
        # LIFO stack: push the right child first so the left child is popped
        # (and therefore visited) first, giving source-order output.
        if cur.right is not None:
            stack.append(cur.right)
        if cur.left is not None:
            stack.append(cur.left)
    return out
6384+
6385+
6386+
def _lower_pattern_predicate_to_row_marker(
    predicate: WherePatternPredicate,
    *,
    alias_targets: Mapping[str, ASTObject],
    params: Optional[Mapping[str, Any]],
    out_col: str,
) -> ASTCall:
    """Lower one WHERE pattern predicate to a ``semi_apply_mark`` call.

    The predicate's pattern is lowered as a standalone, non-optional MATCH
    clause; the serialized ops become ``binding_ops`` and the pattern's aliases
    become the join keys. The marker is written to ``out_col``.

    Raises the error built by ``_unsupported`` when:

    - the pattern has fewer than three elements (reported as lacking a
      relationship);
    - the pattern references no aliases at all;
    - the pattern references an alias that is not a key of ``alias_targets``.
    """
    if len(predicate.pattern) < 3:
        raise _unsupported(
            "Cypher WHERE pattern predicates must include a relationship",
            field="where",
            value=None,
            line=predicate.span.line,
            column=predicate.span.column,
        )

    predicate_aliases = _predicate_pattern_aliases(predicate)
    if not predicate_aliases:
        raise _unsupported(
            "Cypher WHERE pattern predicates currently require at least one shared bound alias",
            field="where",
            value=None,
            line=predicate.span.line,
            column=predicate.span.column,
        )

    introduced_aliases = sorted(alias for alias in predicate_aliases if alias not in alias_targets)
    if introduced_aliases:
        raise _unsupported(
            "Cypher WHERE pattern predicates cannot introduce new aliases in this phase",
            field="where",
            value=introduced_aliases,
            line=predicate.span.line,
            column=predicate.span.column,
        )

    # At this point the alias list is non-empty and every alias is already
    # bound in ``alias_targets``, so all of them are shared join keys. A
    # separate "no shared alias" check here would be unreachable.
    shared_aliases = list(predicate_aliases)

    pattern_clause = MatchClause(
        patterns=(predicate.pattern,),
        span=predicate.span,
        optional=False,
        pattern_aliases=(None,),
        pattern_alias_kinds=("pattern",),
    )
    pattern_ops = lower_match_clause(pattern_clause, params=params)
    return semi_apply_mark(
        binding_ops=serialize_binding_ops(pattern_ops),
        join_aliases=shared_aliases,
        out_col=out_col,
    )
6445+
6446+
6447+
def _rewrite_where_expr_patterns_to_markers(
    *,
    where: WhereClause,
    alias_targets: Mapping[str, ASTObject],
    params: Optional[Mapping[str, Any]],
) -> Tuple[Optional[ExpressionText], List[ASTCall]]:
    """Replace pattern leaves of a WHERE expression tree with marker columns.

    Every ``pattern`` leaf in ``where.expr_tree`` is lowered to a
    ``semi_apply_mark`` call that writes a uniquely named marker column, and
    the leaf itself is swapped for an ``atom`` whose text is that column name.

    Returns a pair ``(expr_text, marker_ops)``:

    - ``expr_text`` is the text of the rewritten tree. When the clause has no
      expression tree, or the tree holds no pattern leaves, it is whatever
      ``_where_clause_expr_text(where)`` yields.
    - ``marker_ops`` lists the ``semi_apply_mark`` calls in left-to-right leaf
      order; it is empty when nothing was rewritten.

    Raises the error built by ``_unsupported`` when a ``pattern`` leaf carries
    no parsed pattern, plus anything raised while lowering a leaf.
    """
    # Both "nothing to rewrite" cases defer to the plain text helper so that
    # clauses without an expression tree keep whatever that helper produces
    # for them, instead of being forced to ``None`` here.
    # NOTE(review): assumes ``_where_clause_expr_text`` accepts a clause whose
    # ``expr_tree`` is None, as the call sites passed every clause to it
    # before this rewrite step existed - confirm.
    if where.expr_tree is None or not _where_expr_tree_pattern_predicates(where.expr_tree):
        return _where_clause_expr_text(where), []

    marker_ops: List[ASTCall] = []
    marker_counter = 0

    def _fresh_marker_col(span: SourceSpan) -> str:
        # The span coordinates make the name traceable to the source text; the
        # running counter keeps names distinct even for identical spans.
        nonlocal marker_counter
        marker_counter += 1
        return (
            "__gfql_where_pattern_"
            f"{span.line}_{span.column}_{span.end_line}_{span.end_column}_{marker_counter}__"
        )

    def _rewrite(expr: BooleanExpr) -> BooleanExpr:
        if expr.op == "pattern":
            if expr.pattern is None:
                raise _unsupported(
                    "Cypher WHERE pattern predicates must include a relationship",
                    field="where",
                    value=expr.atom_text,
                    line=expr.span.line,
                    column=expr.span.column,
                )
            marker_col = _fresh_marker_col(expr.span)
            marker_ops.append(
                _lower_pattern_predicate_to_row_marker(
                    WherePatternPredicate(pattern=expr.pattern, span=expr.span, negated=False),
                    alias_targets=alias_targets,
                    params=params,
                    out_col=marker_col,
                )
            )
            # The leaf becomes a plain atom that reads the marker column.
            return BooleanExpr(
                op="atom",
                span=expr.span,
                atom_text=marker_col,
                atom_span=expr.atom_span or expr.span,
            )
        if expr.op == "atom":
            return expr
        if expr.op == "not":
            return replace(expr, left=_rewrite(cast(BooleanExpr, expr.left)))
        if expr.op in {"and", "or", "xor"}:
            return replace(
                expr,
                left=_rewrite(cast(BooleanExpr, expr.left)),
                right=_rewrite(cast(BooleanExpr, expr.right)),
            )
        # Any other operator is passed through untouched.
        return expr

    rewritten = _rewrite(where.expr_tree)
    return ExpressionText(text=boolean_expr_to_text(rewritten), span=where.span), marker_ops
6510+
6511+
63596512
def _lower_negated_pattern_predicate_to_row_filter(
63606513
predicate: WherePatternPredicate,
63616514
*,
@@ -6449,7 +6602,12 @@ def lower_match_query(
64496602
row_where: Optional[ExpressionText] = None
64506603
row_where_predicates: List[str] = list(dynamic_row_where_predicates)
64516604
if query.where is not None:
6452-
where_expr = _where_clause_expr_text(query.where)
6605+
where_expr, where_pattern_row_filters = _rewrite_where_expr_patterns_to_markers(
6606+
where=query.where,
6607+
alias_targets=alias_targets,
6608+
params=params,
6609+
)
6610+
row_pre_filters.extend(where_pattern_row_filters)
64536611
if where_expr is not None:
64546612
type_where = _extract_relationship_type_where(
64556613
where_expr,
@@ -8270,7 +8428,12 @@ def _apply_where_to_ops(
82708428
row_pre_filters: List[ASTCall] = []
82718429
if where is None:
82728430
return where_out, row_expr_filters, row_pre_filters
8273-
where_expr = _where_clause_expr_text(where)
8431+
where_expr, where_pattern_row_filters = _rewrite_where_expr_patterns_to_markers(
8432+
where=where,
8433+
alias_targets=alias_targets,
8434+
params=params,
8435+
)
8436+
row_pre_filters.extend(where_pattern_row_filters)
82748437
if where_expr is not None:
82758438
type_where = _extract_relationship_type_where(
82768439
where_expr,

graphistry/compute/gfql/cypher/parser.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,8 @@ def _split_top_level_and_pattern_leaves(
361361
- ``others``: non-pattern conjuncts that should remain in ``expr_tree``.
362362
- ``has_nested_pattern``: True when a pattern atom appears in a deeper
363363
non-AND/non-direct-NOT context (e.g. ``OR`` with a pattern leaf, or
364-
``NOT (and-tree-of-patterns)``). Triggers the legacy E108 reject so
365-
slice 4 / De-Morgan-NOT compositions stay deferred.
364+
``NOT (and-tree-of-patterns)``). Lowering consumes this by keeping
365+
such leaves in ``expr_tree`` instead of lifting them to predicates.
366366
"""
367367
if expr.op == "and":
368368
if expr.left is None or expr.right is None:
@@ -420,17 +420,6 @@ def _build_where_with_pattern_lift(
420420
expr_text: str,
421421
span: SourceSpan,
422422
) -> WhereClause:
423-
if nested_pattern:
424-
raise GFQLValidationError(
425-
ErrorCode.E108,
426-
"Cypher WHERE pattern predicates cannot yet be mixed with generic row expressions",
427-
field="where",
428-
value=expr_text,
429-
suggestion="Use positive top-level pattern predicates joined by AND.",
430-
line=span.line,
431-
column=span.column,
432-
language="cypher",
433-
)
434423
# Slice 3 (#1031): N positive patterns each become a WherePatternPredicate
435424
# (negated=False). Slice 2 (#1031): N NOT-patterns each become a
436425
# WherePatternPredicate (negated=True) for downstream anti-semi-join
@@ -446,6 +435,9 @@ def _build_where_with_pattern_lift(
446435
new_expr_tree = _rebuild_and_tree(other_conjuncts)
447436
if new_expr_tree is None:
448437
return WhereClause(predicates=tuple(pattern_preds), expr_tree=None, span=span)
438+
# Nested pattern leaves (OR/XOR/complex NOT contexts) stay in expr_tree;
439+
# lowering rewrites them to correlated semi-apply marker columns.
440+
_ = nested_pattern
449441
return WhereClause(
450442
predicates=tuple(pattern_preds),
451443
expr_tree=new_expr_tree,

0 commit comments

Comments
 (0)