Commit 73fa242

lmeyerov and claude committed
feat(gfql): count(*) short-circuit — O(1) count_table op, no full-frame materialize
A lone `RETURN count(*)` over a single node/edge pattern used to materialize the whole matched frame and run a constant-key group_by just to count rows (the ~770x-vs-Ladybug loss in the fair Cypher benchmark: count = 2543-8206ms while the df-path count is ~0.01ms). The Cypher lowering now emits a new `count_table` row op that reads the scanned table's height directly (or sums the boolean alias-mask column when the pattern filters) in a single reduction: no frame copy, no group_by.

Guarded to the provably-equivalent shapes only: exactly one non-DISTINCT count(*), no group keys / post-agg exprs / row-level WHERE / UNWIND / paging / multi-relationship binding, and either a pure node scan (relationship_count==0) or a single relationship counted on its edge alias (== the cases _reject_unsound_relationship_multiplicity_aggregates permits). Every other shape falls through to the general aggregate path unchanged. The op replaces the rows()+with_+group_by prefix and keeps the identical trailing projection, so downstream/result is byte-identical.

count_table is a native frame op (like rows/limit) routed via _POLARS_NATIVE_ROW_PIPELINE_CALLS -> op.execute -> frame_ops.count_table, so one implementation covers pandas / cuDF / polars / polars-gpu.

Validation:
- CPU parity (pandas+polars node/edge count == oracle, fast path confirmed taken)
- full test_lowering suite: 1394 passed
- coverage ledger green (count_table accounted in _rowop_exercised + _ROW_OP_CASES + CALL_KNOWN_UNCOVERED)
- dgx 4-engine conformance (count_all_nodes/edges + count_table rowop, chain+dag): 55 passed on pandas/cuDF/polars/polars-gpu
- ruff + mypy clean

The 5M/20M Ladybug head-to-head benchmark is the next step (needs the ladybug bench harness rebased onto this).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent ca2b45e commit 73fa242

8 files changed

Lines changed: 188 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
 ### Added
 - **GFQL native Polars engine — more cypher row coverage (`toFloat`, `collect`/`collect(DISTINCT)`, `WHERE … IN`)**: three surfaces that previously raised `NotImplementedError` on `engine='polars'` now run natively, parity-validated vs the pandas oracle across all four engines (and honest-NIE where pandas can't be matched). **`toFloat(x)`** lowers int/uint/bool/float → `Float64` (NaN preserved — float64 has no separate null sentinel, unlike `toInteger`); a non-numeric String declines (NIE) because pandas `astype(float)` *raises* rather than null-on-failure. **`collect(x)` / `collect(DISTINCT x)`** aggregations complete the native `group_by` surface (every other agg was already native): drop nulls, preserve within-group first-occurrence order (`collect` keeps dups; `DISTINCT` dedups keep-first), all-null group → `[]`. **`where_rows`/`WHERE … IN [list]`** membership lowers to `is_in` (a null cell is excluded per openCypher 3VL). No change to any already-native path.
 - **GFQL Polars-CPU streaming collect (opt-in, large traversals)**: `GFQL_POLARS_CPU_STREAMING=1` runs the polars-CPU lazy collects (`hop`/`chain`) on the polars **streaming** executor instead of the default in-memory collect. Benchmarked ~1.04–1.11× faster on big multi-hop traversals (10M nodes / 80M edges: 20.0→18.0 s) and parity-identical, but ~0.86× (slower) on small/interactive sizes (streaming overhead) — so it is **opt-in, default off** (no change to default behavior). Use for large batch traversals where CPU is the target.
+- **GFQL Cypher `count(*)` short-circuit — O(1) instead of O(N) materialize**: a lone `RETURN count(*)` over a single node or edge pattern (`MATCH (n) RETURN count(*)`, `MATCH ()-[r]->() RETURN count(*)`) previously materialized the entire matched frame and ran a constant-key `group_by` just to count its rows. The lowering now emits a new `count_table` row op that reads the scanned table's height directly (or, when the pattern applies a filter, sums the boolean alias-mask column) in a single reduction — no full-frame copy, no group_by. Applies only to the provably-equivalent shapes: exactly one non-DISTINCT `count(*)`, no group keys / post-aggregate exprs / row-level `WHERE` / `UNWIND` / paging / multi-relationship binding, and either a pure node scan (`relationship_count == 0`) or a single relationship counted on its edge alias — every other shape (node-alias-over-relationship, multi-hop paths, `count(DISTINCT …)`, grouped counts) falls through to the general aggregate path unchanged. Engine-polymorphic across pandas/cuDF/polars/polars-gpu; differential parity verified on all four engines (`count_all_nodes`/`count_all_edges` cypher conformance cases + a `count_table` row-op subject case, chain and DAG surfaces). No change to any result value — only the execution path.

 ### Fixed
 - **GFQL `ne()` / `<>` on NULL now follows openCypher/SQL 3-valued logic (pandas)**: `n({"col": ne(x)})` and cypher `WHERE n.col <> x` over a NULL/NA cell used to KEEP the null row on the pandas engine (`NaN != x` → True), diverging from cuDF and the polars engine (both drop it) — and even from pandas' own `WHERE NOT n.col = x` path. Per openCypher/SQL three-valued logic, `null <> x` is `null` (an unknown value cannot be proven unequal to `x`), so a null cell is **not** a match and the row is excluded — consistent with `eq`/`gt`/`lt`/`IN` (which already dropped nulls). Fixed the `NE` predicate to mask out nulls; this corrects both the `filter_dict` predicate path and the single-entity cypher `<>` WHERE path on pandas. cuDF/polars/polars-gpu were already conformant. Verified across all four engines (`ne`, `<>`, `NOT =`, `NOT IN` all drop the null). Note: this is a behavior change for `ne()` on nullable columns under the default pandas engine. (Broader openCypher null-semantics alignment + docs tracked in #1664.)

graphistry/compute/ast.py

Lines changed: 22 additions & 0 deletions
@@ -1751,6 +1751,28 @@ def drop_cols(cols: Iterable[str]) -> ASTCall:
     return ASTCall("drop_cols", {"cols": list(cols)})


+def count_table(
+    table: str = "nodes",
+    source: Optional[str] = None,
+    alias: str = "count(*)",
+) -> ASTCall:
+    """Count matched rows without materializing them (fast path for a lone ``count(*)``).
+
+    Emitted by the Cypher lowering when a RETURN is exactly ``count(*)`` over a
+    single node/edge pattern (no DISTINCT, GROUP BY, row-level WHERE, UNWIND,
+    paging, or multi-relationship binding). Produces a one-row table
+    ``{alias: n}`` where ``n`` is the height of the active ``table`` (or, when a
+    ``source`` alias-mask column is present, the count of its truthy rows). This
+    avoids the full-frame materialize + constant-key ``group_by`` the general
+    aggregate path performs — the win that turns count(*) from O(N) into a single
+    reduction. See plans/gfql-engine-followups (BEAT LADYBUG).
+    """
+    params: Dict[str, Any] = {"table": table, "alias": alias}
+    if source is not None:
+        params["source"] = source
+    return ASTCall("count_table", params)
+
+
 def group_by(
     keys: Iterable[str],
     aggregations: Iterable[Sequence[Any]],

graphistry/compute/gfql/call/validation.py

Lines changed: 13 additions & 0 deletions
@@ -491,6 +491,19 @@ def _semi_apply_mark_added_node_cols(params: Dict[str, object]) -> Set[str]:
         description='Drop duplicate rows from active row table',
     ),

+    'count_table': _safelist_entry(
+        {'table', 'source', 'alias'},
+        param_validators={
+            'table': lambda v: v in ['nodes', 'edges'],
+            'source': is_string_or_none,
+            'alias': is_non_empty_string,
+        },
+        description='Count matched rows (node/edge table height or source-alias mask) into a one-row table, without materializing the frame — fast path for a lone count(*)',
+        schema_effects=_schema_effects(
+            adds_node_cols=lambda p: [p.get('alias', 'count(*)')],
+        ),
+    ),
+
     'get_degrees': _safelist_entry(
         {'col', 'degree_in', 'degree_out', 'engine'},
         param_validators={
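
The safelist entry above restricts `count_table` to three parameters and validates each one. A minimal stand-alone sketch of that kind of check follows; `validate_count_table_params` and `COUNT_TABLE_VALIDATORS` are hypothetical names for illustration, and the two helper validators are assumed re-implementations of `is_string_or_none` / `is_non_empty_string`, not the library's own code.

```python
from typing import Any, Callable, Dict


# Assumed behavior of the validators named in the diff (hypothetical re-implementations).
def is_string_or_none(v: Any) -> bool:
    return v is None or isinstance(v, str)


def is_non_empty_string(v: Any) -> bool:
    return isinstance(v, str) and len(v) > 0


# Hypothetical table mirroring the three validators in the safelist entry.
COUNT_TABLE_VALIDATORS: Dict[str, Callable[[Any], bool]] = {
    "table": lambda v: v in ("nodes", "edges"),
    "source": is_string_or_none,
    "alias": is_non_empty_string,
}


def validate_count_table_params(params: Dict[str, Any]) -> None:
    """Reject unknown parameter names, then run each per-parameter validator."""
    unknown = set(params) - set(COUNT_TABLE_VALIDATORS)
    if unknown:
        raise ValueError(f"unknown count_table params: {sorted(unknown)}")
    for name, value in params.items():
        if not COUNT_TABLE_VALIDATORS[name](value):
            raise ValueError(f"invalid value for {name!r}: {value!r}")
```

For example, `validate_count_table_params({"table": "nodes", "alias": "c"})` passes silently, while `{"table": "paths"}` or an empty `alias` raises `ValueError`.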

graphistry/compute/gfql/cypher/lowering.py

Lines changed: 84 additions & 0 deletions
@@ -13,6 +13,7 @@
     ASTObject,
     ASTNode,
     anti_semi_apply,
+    count_table,
     distinct,
     drop_cols,
     e_forward,
@@ -2308,6 +2309,60 @@ def _match_relationship_count(clause: MatchClause) -> int:
     return sum(1 for element in _match_pattern_elements(clause) if isinstance(element, RelationshipPattern))


+def _is_pure_count_star_shortcircuit(
+    *,
+    aggregate_specs: Sequence[_AggregateSpec],
+    pre_items: Sequence[Tuple[str, Any]],
+    row_steps: Sequence[ASTObject],
+    query: CypherQuery,
+    binding_row_aliases: AbstractSet[str],
+    relationship_count: int,
+    active_match_alias: Optional[str],
+    alias_targets: Mapping[str, ASTObject],
+) -> bool:
+    """True when a RETURN is exactly ``count(*)`` over a single node/edge scan.
+
+    Guards the count_table fast path (skip the full-frame materialize + constant-
+    key group_by): the count then equals the height (or source-mask sum) of the
+    active table. Requires a lone non-DISTINCT ``count(*)`` with no group keys,
+    post-aggregate exprs, row-level WHERE, UNWIND, or multi-relationship binding,
+    and a plain ``rows(table=nodes|edges[, source])`` as the only prior step.
+    Sound only for a pure node scan (``relationship_count == 0``) or a single
+    relationship counted on its edge alias (``relationship_count == 1`` with an
+    ``ASTEdge`` active alias) — exactly the cases
+    ``_reject_unsound_relationship_multiplicity_aggregates`` permits; any other
+    shape (node-alias-over-relationship, multi-hop paths) falls through to the
+    general aggregate path (which counts bindings or rejects as unsound).
+    """
+    if len(aggregate_specs) != 1:
+        return False
+    agg = aggregate_specs[0]
+    if agg.func != "count" or agg.expr_text is not None or agg.distinct:
+        return False
+    if pre_items or binding_row_aliases or query.unwinds:
+        return False
+    # Exactly the initial rows() step: any row-level WHERE or UNWIND would have
+    # appended further steps, so len == 1 proves the count is over the raw scan.
+    if len(row_steps) != 1:
+        return False
+    base = row_steps[0]
+    if not (isinstance(base, ASTCall) and base.function == "rows"):
+        return False
+    if base.params.get("table") not in ("nodes", "edges"):
+        return False
+    if base.params.get("binding_ops") is not None or base.params.get("alias_endpoints") is not None:
+        return False
+    if relationship_count == 0:
+        return True
+    if (
+        relationship_count == 1
+        and active_match_alias is not None
+        and isinstance(alias_targets.get(active_match_alias), ASTEdge)
+    ):
+        return True
+    return False
+
+
 def _reject_unsound_relationship_multiplicity_aggregates_common(
     *,
     aggregate_specs: Sequence[_AggregateSpec],
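
To make the guard's decision table easier to follow, here is a reduced sketch that keeps only its main checks. `Agg`, `Step`, and `can_short_circuit` are hypothetical stand-ins, not the project's `_AggregateSpec` / `ASTCall` types, and the `pre_items`, `binding_row_aliases`, `query.unwinds`, `binding_ops`, and `alias_endpoints` checks are deliberately omitted.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Sequence


@dataclass
class Agg:
    func: str
    expr_text: Optional[str] = None  # None stands for the `*` in count(*)
    distinct: bool = False


@dataclass
class Step:
    function: str
    params: Dict[str, Any] = field(default_factory=dict)


def can_short_circuit(
    aggs: Sequence[Agg],
    steps: Sequence[Step],
    relationship_count: int,
    active_alias_is_edge: bool = False,
) -> bool:
    # Exactly one plain count(*): no expression argument, no DISTINCT.
    if len(aggs) != 1:
        return False
    agg = aggs[0]
    if agg.func != "count" or agg.expr_text is not None or agg.distinct:
        return False
    # Only the initial rows() scan may precede the aggregate.
    if len(steps) != 1 or steps[0].function != "rows":
        return False
    if steps[0].params.get("table") not in ("nodes", "edges"):
        return False
    # Pure node scan, or a single relationship counted on its edge alias.
    if relationship_count == 0:
        return True
    return relationship_count == 1 and active_alias_is_edge
```

Under these assumptions, a `MATCH (n) RETURN count(*)` shape (one `rows` step, zero relationships) qualifies, while adding a second step such as a row-level filter, a `DISTINCT`, or a second relationship sends the query back to the general aggregate path.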
@@ -6573,6 +6628,35 @@ def _lower_general_row_projection(
             if len(pre_items) > 0:
                 row_steps.append(with_(pre_items, extend=bindings_row_path))
             row_steps.append(group_by(key_names, aggregations, key_prefixes=alias_key_prefixes))
+        elif _is_pure_count_star_shortcircuit(
+            aggregate_specs=aggregate_specs,
+            pre_items=pre_items,
+            row_steps=row_steps,
+            query=query,
+            binding_row_aliases=binding_row_aliases,
+            relationship_count=relationship_count,
+            active_match_alias=active_match_alias,
+            alias_targets=alias_targets,
+        ):
+            # Fast path: count_table reads the scanned table's height (or the
+            # source-alias mask sum) with one reduction — no full-frame
+            # materialize + constant-key group_by. It replaces the sole rows()
+            # step and produces the same ``count(*)`` column the group_by would,
+            # so the trailing identity projection (elif not key_names, below) is
+            # unchanged. empty_result_row is belt-and-suspenders here (count_table
+            # always emits a 1-row result), kept for parity with the group_by path.
+            base_rows = cast(ASTCall, row_steps[0])
+            count_alias = aggregate_specs[0].output_name
+            row_steps = [
+                count_table(
+                    table=cast(str, base_rows.params.get("table", "nodes")),
+                    source=cast(Optional[str], base_rows.params.get("source")),
+                    alias=count_alias,
+                )
+            ]
+            available_columns = {count_alias}
+            empty_aggregate_row = _empty_aggregate_row(aggregate_specs)
+            empty_result_row = empty_aggregate_row
         else:
             global_key = _fresh_temp_name(temp_names, "__cypher_group__")
             row_steps.append(with_([(global_key, 1)] + pre_items))
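
The plan rewrite in this hunk swaps the single `rows()` step for a single `count_table` step that carries over `table` and `source`. A small sketch of that substitution, using a hypothetical `Call` record in place of `ASTCall` and a hypothetical helper name `rewrite_to_count_table`:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass(frozen=True)
class Call:
    """Hypothetical stand-in for an AST call node: a name plus its params."""
    function: str
    params: Dict[str, Any] = field(default_factory=dict)


def rewrite_to_count_table(row_steps: List[Call], count_alias: str) -> List[Call]:
    """Replace the sole rows() step with count_table, keeping table/source."""
    base = row_steps[0]
    params: Dict[str, Any] = {
        "table": base.params.get("table", "nodes"),
        "alias": count_alias,
    }
    source = base.params.get("source")
    if source is not None:
        # Only forward `source` when the scan carried an alias-mask column.
        params["source"] = source
    return [Call("count_table", params)]
```

Because the output column keeps the aggregate's own output name, whatever projection follows the aggregate can stay as it was, which is the property the inline comment in the hunk relies on.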

graphistry/compute/gfql/row/frame_ops.py

Lines changed: 58 additions & 0 deletions
@@ -167,6 +167,64 @@ def rows(
     return row_table(ctx, table_df)


+def count_table(
+    ctx: Any,
+    table: str = "nodes",
+    source: Optional[str] = None,
+    alias: str = "count(*)",
+) -> "Plottable":
+    """Count matched rows and set a one-row ``{alias: n}`` result table.
+
+    Fast path for a lone ``count(*)``: reads the height of the active node/edge
+    table (or the truthy count of the ``source`` alias-mask column) with a single
+    reduction, never materializing/copying the whole frame the way ``rows`` +
+    ``group_by`` would. Engine-polymorphic across pandas/cuDF/polars (eager or
+    lazy). See ``graphistry.compute.ast.count_table`` and the Cypher lowering
+    short-circuit.
+    """
+    if table not in {"nodes", "edges"}:
+        raise ValueError(
+            f"count_table(table=...) must be one of 'nodes' or 'edges', got {table!r}"
+        )
+    table_df = ctx._nodes if table == "nodes" else ctx._edges
+
+    if table_df is None:
+        return row_table(ctx, pd.DataFrame({alias: [0]}))
+
+    if _is_polars(table_df):
+        import polars as pl
+        if source is not None:
+            # LazyFrame lacks .columns without a resolve; collect_schema is lazy-safe.
+            cols = table_df.collect_schema().names()
+            if source not in cols:
+                raise ValueError(
+                    f"count_table(source=...) alias column not found: {source!r}"
+                )
+            count_expr = pl.col(source).fill_null(False).cast(pl.Boolean).sum()
+        else:
+            count_expr = pl.len()
+        res = table_df.select(count_expr.alias(alias))
+        # eager DataFrame.select -> DataFrame (no collect); LazyFrame.select -> LazyFrame.
+        if hasattr(res, "collect"):
+            res = res.collect()
+        n = int(res.item())
+        return row_table(ctx, pl.DataFrame({alias: [n]}))
+
+    # pandas / cuDF (API-compatible)
+    if source is not None:
+        if source not in table_df.columns:
+            raise ValueError(
+                f"count_table(source=...) alias column not found: {source!r}"
+            )
+        mask = table_df[source]
+        if hasattr(mask, "fillna"):
+            mask = mask.fillna(False)
+        n = int(mask.astype(bool).sum())
+    else:
+        n = int(len(table_df))
+    return row_table(ctx, template_df_cons(table_df, {alias: [n]}))
+
+
 def drop_cols(ctx: Any, cols: Sequence[str]) -> "Plottable":
     """Drop named columns from the active row table, ignoring any that don't exist."""
     table_df = get_active_table(ctx)
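
The equivalence the fast path depends on can be checked in isolation with plain pandas: counting through a constant-key `groupby` gives the same number as reading the frame's length, or summing the mask column when one is present. The sketch below is an illustration under that assumption only; `count_via_group_by` and `count_via_reduction` are hypothetical helpers and do not use the project's `row_table` or context objects.

```python
from typing import Optional

import pandas as pd


def count_via_group_by(df: pd.DataFrame, source: Optional[str] = None) -> int:
    """General shape: filter rows, add a constant key, group, and count."""
    rows = df if source is None else df[df[source].fillna(False).astype(bool)]
    keyed = rows.assign(__group__=1)  # full-frame copy with a constant key
    return int(keyed.groupby("__group__").size().sum())


def count_via_reduction(df: pd.DataFrame, source: Optional[str] = None) -> int:
    """Fast shape: table height, or the truthy count of the mask column."""
    if source is None:
        return len(df)
    return int(df[source].fillna(False).astype(bool).sum())


nodes = pd.DataFrame(
    {
        "id": [1, 2, 3, 4],
        # Nullable boolean mask: a null entry counts as "not matched".
        "n": pd.array([True, False, None, True], dtype="boolean"),
    }
)

assert count_via_reduction(nodes) == count_via_group_by(nodes) == 4
assert count_via_reduction(nodes, "n") == count_via_group_by(nodes, "n") == 2
```

Note that the mask-sum branch still scans one column, so it is a single O(N) reduction rather than a constant-time lookup; only the unfiltered height read avoids touching the data.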

graphistry/compute/gfql/row/pipeline.py

Lines changed: 4 additions & 1 deletion
@@ -180,6 +180,7 @@ def _gfql_cudf_list_sort_series_requires_host_bridge(series: Any) -> bool:
         "unwind",
         "group_by",
         "drop_cols",
+        "count_table",
     }
 )

@@ -4403,6 +4404,7 @@ def order_by(self, keys: List[Any]) -> "Plottable":
     skip = row_frame_ops.skip
     limit = row_frame_ops.limit
     distinct = row_frame_ops.distinct
+    count_table = row_frame_ops.count_table

     def unwind(self, expr: Any, as_: str = "value") -> "Plottable":
         """Vectorized UNWIND for column or literal list expressions."""
@@ -4653,14 +4655,15 @@ def bind(self) -> "Plottable":
     "unwind": RowPipelineMixin.unwind,
     "group_by": RowPipelineMixin.group_by,
     "drop_cols": RowPipelineMixin.drop_cols,
+    "count_table": RowPipelineMixin.count_table,
 }


 # Row-pipeline ops with native polars implementations (frame-level only — no
 # cypher expression engine). Everything else falls back through the guard below
 # until lowered natively. See plans/gfql-polars-engine (Phase 2).
 _POLARS_NATIVE_ROW_PIPELINE_CALLS = frozenset(
-    {"rows", "skip", "limit", "distinct", "drop_cols"}
+    {"rows", "skip", "limit", "distinct", "drop_cols", "count_table"}
 )

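
The registration pattern in this file (a name-to-method table plus a frozenset of ops considered native on polars) can be pictured with the toy dispatcher below. Everything in it is hypothetical: `ROW_OPS`, `NATIVE_OPS`, and `dispatch` are illustration names, the "tables" are plain lists, and raising `NotImplementedError` for a non-native op is an assumption about the guard, whose actual behavior is not shown in this diff.

```python
from typing import Any, Callable, Dict, FrozenSet, List

Table = List[Dict[str, Any]]

# Hypothetical op table: op name -> implementation over a list-of-dicts "table".
ROW_OPS: Dict[str, Callable[..., Table]] = {
    "limit": lambda table, n: table[:n],
    "order_by": lambda table, key: sorted(table, key=lambda row: row[key]),
    "count_table": lambda table, alias="count(*)": [{alias: len(table)}],
}

# Hypothetical allow-list of ops that have a native implementation.
NATIVE_OPS: FrozenSet[str] = frozenset({"limit", "count_table"})


def dispatch(name: str, table: Table, native_only: bool = False, **params: Any) -> Table:
    """Look up an op by name; refuse non-native ops when native_only is set."""
    if name not in ROW_OPS:
        raise KeyError(f"unknown row op: {name}")
    if native_only and name not in NATIVE_OPS:
        raise NotImplementedError(f"{name} has no native implementation")
    return ROW_OPS[name](table, **params)
```

In this toy model, adding an op takes two registrations, one in the op table and one in the native allow-list, which mirrors the two one-line additions the commit makes for `count_table`.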

graphistry/tests/compute/gfql/test_conformance_ledger.py

Lines changed: 1 addition & 0 deletions
@@ -271,6 +271,7 @@ def test_known_uncovered_function_reasons_are_nonempty():
     "order_by": "row sort; native on polars chain, NIE via call()/DAG executor; not asserted via call(). TODO.",
     "unwind": "list explode; native on polars chain, NIE via call()/DAG executor; not asserted via call(). TODO.",
     "group_by": "grouped aggregation; native on polars chain, NIE via call()/DAG executor; not asserted via call(). TODO.",
+    "count_table": "count(*) short-circuit fast path (table height / source-mask sum); native frame op emitted by the cypher lowering, exercised as a labeled subject via _ROW_OP_CASES + the count_all_nodes/edges cypher cases, not via a direct call() consistency label. TODO.",
     "semi_apply_mark": "correlated EXISTS-mark; row-pipeline op honest-NIE under polars; not asserted. TODO.",
     "anti_semi_apply": "anti-semi correlated filter; row-pipeline op honest-NIE under polars; not asserted. TODO.",
     "join_apply": "correlated row join; row-pipeline op honest-NIE under polars; not asserted. TODO.",

graphistry/tests/compute/gfql/test_engine_polars_conformance_matrix.py

Lines changed: 5 additions & 0 deletions
@@ -262,6 +262,9 @@ def _cypher_expression_queries():
         ("count_distinct_grouped", "MATCH (n) RETURN n.flag AS k, count(DISTINCT n.num) AS cd"),
         ("count_distinct_all", "MATCH (n) RETURN count(DISTINCT n.flag) AS cd"),
         ("count_grouped", "MATCH (n) RETURN n.flag AS k, count(n.num) AS c"),
+        # count(*) short-circuit (count_table fast path): whole-graph node / edge counts.
+        ("count_all_nodes", "MATCH (n) RETURN count(*) AS c"),
+        ("count_all_edges", "MATCH ()-[r]->() RETURN count(*) AS c"),
         ("size_str", "MATCH (n) RETURN n.id AS id, size(n.name) AS sz"),
         ("substring3", "MATCH (n) RETURN n.id AS id, substring(n.name, 0, 4) AS sub"),
         ("substring2", "MATCH (n) RETURN n.id AS id, substring(n.name, 2) AS sub"),
@@ -506,6 +509,7 @@ def _rowop_exercised():
         "with_", "unwind",
         "rows", "skip", "limit", "distinct", "drop_cols",
         "order_by", "select", "return_", "where_rows", "group_by",
+        "count_table",
     }


@@ -1086,6 +1090,7 @@ def test_conformance_unwind_chain_vs_cypher_consistent():
     ("where_rows", [n(), rows(), call("where_rows", {"expr": "num > 50"})]),
     ("group_by", [n(), rows(), call("group_by", {"keys": ["flag"],
                                                   "aggregations": [("c", "count"), ("s", "sum", "num")]})]),
+    ("count_table", [n(), rows(), call("count_table", {"table": "nodes", "alias": "cnt"})]),
 ]

