|
13 | 13 | ASTObject, |
14 | 14 | ASTNode, |
15 | 15 | anti_semi_apply, |
| 16 | + count_table, |
16 | 17 | distinct, |
17 | 18 | drop_cols, |
18 | 19 | e_forward, |
@@ -2308,6 +2309,60 @@ def _match_relationship_count(clause: MatchClause) -> int: |
2308 | 2309 | return sum(1 for element in _match_pattern_elements(clause) if isinstance(element, RelationshipPattern)) |
2309 | 2310 |
|
2310 | 2311 |
|
def _is_pure_count_star_shortcircuit(
    *,
    aggregate_specs: Sequence[_AggregateSpec],
    pre_items: Sequence[Tuple[str, Any]],
    row_steps: Sequence[ASTObject],
    query: CypherQuery,
    binding_row_aliases: AbstractSet[str],
    relationship_count: int,
    active_match_alias: Optional[str],
    alias_targets: Mapping[str, ASTObject],
) -> bool:
    """Decide whether a RETURN is a bare ``count(*)`` over one node/edge scan.

    This predicate gates the ``count_table`` fast path, which avoids the
    full-frame materialize plus constant-key ``group_by`` and instead reads
    the height (or source-mask sum) of the active table.

    All of the following must hold for a ``True`` result:

    * ``aggregate_specs`` holds exactly one spec whose ``func`` is
      ``"count"``, whose ``expr_text`` is ``None`` and which is not DISTINCT;
    * ``pre_items``, ``binding_row_aliases`` and ``query.unwinds`` are all
      empty/falsy;
    * ``row_steps`` holds exactly one step, an ``ASTCall`` to ``rows`` whose
      ``table`` param is ``"nodes"`` or ``"edges"`` and whose ``binding_ops``
      and ``alias_endpoints`` params are both ``None``/absent;
    * the match is either a pure node scan (``relationship_count == 0``) or
      a single relationship counted on its edge alias
      (``relationship_count == 1`` and ``alias_targets[active_match_alias]``
      is an ``ASTEdge``).

    Those relationship shapes are the ones
    ``_reject_unsound_relationship_multiplicity_aggregates`` permits; every
    other shape (node alias over a relationship, multi-hop paths) yields
    ``False`` so the caller falls through to the general aggregate path.
    """
    # --- aggregate shape: one plain, non-DISTINCT count(*) -----------------
    if len(aggregate_specs) != 1:
        return False
    only_spec = aggregate_specs[0]
    is_count_star = (
        only_spec.func == "count"
        and only_spec.expr_text is None
        and not only_spec.distinct
    )
    if not is_count_star:
        return False

    # --- nothing may feed or reshape the rows before aggregation -----------
    has_row_level_inputs = bool(pre_items or binding_row_aliases or query.unwinds)
    if has_row_level_inputs:
        return False

    # A row-level WHERE or UNWIND would have appended further steps, so a
    # single remaining step means the count runs over the raw scan itself.
    if len(row_steps) != 1:
        return False
    scan = row_steps[0]
    if not isinstance(scan, ASTCall) or scan.function != "rows":
        return False

    scan_params = scan.params
    if scan_params.get("table") not in ("nodes", "edges"):
        return False
    for forbidden_param in ("binding_ops", "alias_endpoints"):
        if scan_params.get(forbidden_param) is not None:
            return False

    # --- relationship multiplicity ------------------------------------------
    if relationship_count == 0:
        # Pure node scan.
        return True
    if relationship_count != 1 or active_match_alias is None:
        return False
    # Single relationship: only sound when counted on the edge alias itself.
    return isinstance(alias_targets.get(active_match_alias), ASTEdge)
| 2364 | + |
| 2365 | + |
2311 | 2366 | def _reject_unsound_relationship_multiplicity_aggregates_common( |
2312 | 2367 | *, |
2313 | 2368 | aggregate_specs: Sequence[_AggregateSpec], |
@@ -6573,6 +6628,35 @@ def _lower_general_row_projection( |
6573 | 6628 | if len(pre_items) > 0: |
6574 | 6629 | row_steps.append(with_(pre_items, extend=bindings_row_path)) |
6575 | 6630 | row_steps.append(group_by(key_names, aggregations, key_prefixes=alias_key_prefixes)) |
| 6631 | + elif _is_pure_count_star_shortcircuit( |
| 6632 | + aggregate_specs=aggregate_specs, |
| 6633 | + pre_items=pre_items, |
| 6634 | + row_steps=row_steps, |
| 6635 | + query=query, |
| 6636 | + binding_row_aliases=binding_row_aliases, |
| 6637 | + relationship_count=relationship_count, |
| 6638 | + active_match_alias=active_match_alias, |
| 6639 | + alias_targets=alias_targets, |
| 6640 | + ): |
| 6641 | + # Fast path: count_table reads the scanned table's height (or the |
| 6642 | + # source-alias mask sum) with one reduction — no full-frame |
| 6643 | + # materialize + constant-key group_by. It replaces the sole rows() |
| 6644 | + # step and produces the same ``count(*)`` column the group_by would, |
| 6645 | + # so the trailing identity projection (elif not key_names, below) is |
| 6646 | + # unchanged. empty_result_row is belt-and-suspenders here (count_table |
| 6647 | + # always emits a 1-row result), kept for parity with the group_by path. |
| 6648 | + base_rows = cast(ASTCall, row_steps[0]) |
| 6649 | + count_alias = aggregate_specs[0].output_name |
| 6650 | + row_steps = [ |
| 6651 | + count_table( |
| 6652 | + table=cast(str, base_rows.params.get("table", "nodes")), |
| 6653 | + source=cast(Optional[str], base_rows.params.get("source")), |
| 6654 | + alias=count_alias, |
| 6655 | + ) |
| 6656 | + ] |
| 6657 | + available_columns = {count_alias} |
| 6658 | + empty_aggregate_row = _empty_aggregate_row(aggregate_specs) |
| 6659 | + empty_result_row = empty_aggregate_row |
6576 | 6660 | else: |
6577 | 6661 | global_key = _fresh_temp_name(temp_names, "__cypher_group__") |
6578 | 6662 | row_steps.append(with_([(global_key, 1)] + pre_items)) |
|
0 commit comments