Skip to content

Commit 44953ee

Browse files
Peng RenCopilot
andcommitted
Fix the issue on count in a sql
Co-authored-by: Copilot <copilot@github.com>
1 parent d4bdfed commit 44953ee

4 files changed

Lines changed: 407 additions & 0 deletions

File tree

pymongosql/sql/builder.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
import json
23
import logging
34
from dataclasses import dataclass
45
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
@@ -112,6 +113,11 @@ def build_from_parse_result(
112113
@staticmethod
113114
def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
114115
"""Build a query execution plan from SELECT parsing."""
116+
117+
# Auto-generate aggregate pipeline for SQL aggregate functions (COUNT, SUM, etc.)
118+
if getattr(parse_result, "aggregate_functions", None):
119+
return ExecutionPlanBuilder._build_sql_aggregate_plan(parse_result)
120+
115121
builder = BuilderFactory.create_query_builder().collection(parse_result.collection)
116122

117123
builder.filter(parse_result.filter_conditions).project(parse_result.projection).column_aliases(
@@ -128,6 +134,60 @@ def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
128134
plan = builder.build()
129135
return plan
130136

137+
@staticmethod
138+
def _build_sql_aggregate_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
139+
"""Build an aggregate execution plan from SQL aggregate functions like COUNT(*), SUM(), etc."""
140+
_FUNCTION_TO_ACCUMULATOR = {
141+
"COUNT": "$sum",
142+
"SUM": "$sum",
143+
"AVG": "$avg",
144+
"MIN": "$min",
145+
"MAX": "$max",
146+
}
147+
148+
builder = BuilderFactory.create_query_builder().collection(parse_result.collection)
149+
150+
pipeline = []
151+
152+
# Add $match stage if there are filter conditions (from WHERE clause)
153+
if parse_result.filter_conditions:
154+
pipeline.append({"$match": parse_result.filter_conditions})
155+
156+
# Build $group stage from aggregate functions
157+
group_stage = {"_id": None}
158+
for func_info in parse_result.aggregate_functions:
159+
alias = func_info["alias"]
160+
func_name = func_info["function"]
161+
arg = func_info["argument"]
162+
accumulator = _FUNCTION_TO_ACCUMULATOR[func_name]
163+
164+
if func_name == "COUNT":
165+
group_stage[alias] = {accumulator: 1}
166+
else:
167+
group_stage[alias] = {accumulator: f"${arg}"}
168+
169+
pipeline.append({"$group": group_stage})
170+
171+
# Add $project to exclude _id
172+
project_stage = {"_id": 0}
173+
for func_info in parse_result.aggregate_functions:
174+
project_stage[func_info["alias"]] = 1
175+
pipeline.append({"$project": project_stage})
176+
177+
# Configure the execution plan as an aggregate query
178+
builder._execution_plan.is_aggregate_query = True
179+
builder._execution_plan.aggregate_pipeline = json.dumps(pipeline)
180+
builder._execution_plan.aggregate_options = json.dumps({})
181+
182+
# Set projection for ResultSet description
183+
agg_projection = {}
184+
for func_info in parse_result.aggregate_functions:
185+
agg_projection[func_info["alias"]] = 1
186+
builder._execution_plan.projection_stage = agg_projection
187+
188+
plan = builder.build()
189+
return plan
190+
131191
@staticmethod
132192
def _build_insert_plan(parse_result: "InsertParseResult") -> "InsertExecutionPlan":
133193
"""Build an INSERT execution plan from INSERT parsing."""

pymongosql/sql/query_handler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class QueryParseResult:
3232
aggregate_pipeline: Optional[str] = None # JSON string representation of pipeline
3333
aggregate_options: Optional[str] = None # JSON string representation of options
3434

35+
# SQL aggregate functions detected in SELECT (COUNT, SUM, AVG, MIN, MAX)
36+
aggregate_functions: List[Dict[str, Any]] = field(default_factory=list)
37+
3538
# Subquery info (for wrapped subqueries, e.g., Superset outering)
3639
subquery_plan: Optional[Any] = None
3740
subquery_alias: Optional[str] = None
@@ -111,6 +114,12 @@ def handle(self, ctx: PartiQLParser.WhereClauseSelectContext) -> Dict[str, Any]:
111114
class SelectHandler(BaseHandler, ContextUtilsMixin):
112115
"""Handles SELECT statement parsing"""
113116

117+
# Pattern to detect SQL aggregate functions: COUNT(*), SUM(field), AVG(field), etc.
118+
_AGGREGATE_PATTERN = re.compile(
119+
r"^(COUNT|SUM|AVG|MIN|MAX)\s*\(\s*(\*|\w+(?:\.\w+)*)\s*\)$",
120+
re.IGNORECASE,
121+
)
122+
114123
def can_handle(self, ctx: Any) -> bool:
115124
"""Check if this is a select context"""
116125
return hasattr(ctx, "projectionItems")
@@ -122,6 +131,21 @@ def handle_visitor(self, ctx: PartiQLParser.SelectItemsContext, parse_result: "Q
122131
if hasattr(ctx, "projectionItems") and ctx.projectionItems():
123132
for item in ctx.projectionItems().projectionItem():
124133
field_name, alias = self._extract_field_and_alias(item)
134+
135+
# Check if this is an aggregate function (COUNT, SUM, etc.)
136+
agg_match = self._AGGREGATE_PATTERN.match(field_name)
137+
if agg_match:
138+
func_name = agg_match.group(1).upper()
139+
func_arg = agg_match.group(2)
140+
parse_result.aggregate_functions.append(
141+
{
142+
"function": func_name,
143+
"argument": func_arg,
144+
"alias": alias or field_name,
145+
}
146+
)
147+
continue
148+
125149
# Use MongoDB standard projection format: {field: 1} to include field
126150
projection[field_name] = 1
127151
# Store alias if present

tests/test_cursor_aggregate.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,165 @@ def test_aggregate_collection_name_with_hyphen(self, conn):
354354
customer_type_idx = col_names.index("customer_type")
355355
for row in rows:
356356
assert row[customer_type_idx] == "premium", "All rows should have customer_type='premium'"
357+
358+
359+
class TestSqlGroupFunctions:
360+
"""Test SQL aggregate functions (COUNT, AVG, MIN, MAX, SUM) translated to MongoDB pipelines."""
361+
362+
def test_count_star(self, conn):
363+
"""SELECT COUNT(*) AS total FROM users → should return document count"""
364+
cursor = conn.cursor()
365+
cursor.execute("SELECT COUNT(*) AS total FROM users")
366+
367+
rows = cursor.fetchall()
368+
assert len(rows) == 1
369+
370+
col_names = [desc[0] for desc in cursor.description]
371+
assert "total" in col_names
372+
373+
total_idx = col_names.index("total")
374+
assert rows[0][total_idx] == 22 # 22 users in test data
375+
376+
def test_count_star_no_alias(self, conn):
377+
"""SELECT COUNT(*) FROM users → column name defaults to COUNT(*)"""
378+
cursor = conn.cursor()
379+
cursor.execute("SELECT COUNT(*) FROM users")
380+
381+
rows = cursor.fetchall()
382+
assert len(rows) == 1
383+
384+
col_names = [desc[0] for desc in cursor.description]
385+
assert "COUNT(*)" in col_names
386+
assert rows[0][col_names.index("COUNT(*)")] == 22
387+
388+
def test_count_star_with_where(self, conn):
389+
"""SELECT COUNT(*) AS total FROM users WHERE age > 30 → filtered count"""
390+
cursor = conn.cursor()
391+
cursor.execute("SELECT COUNT(*) AS total FROM users WHERE age > 30")
392+
393+
rows = cursor.fetchall()
394+
assert len(rows) == 1
395+
396+
col_names = [desc[0] for desc in cursor.description]
397+
total = rows[0][col_names.index("total")]
398+
assert isinstance(total, (int, float))
399+
assert total > 0
400+
assert total < 22 # Must be less than total users
401+
402+
def test_avg(self, conn):
403+
"""SELECT AVG(age) AS avg_age FROM users"""
404+
cursor = conn.cursor()
405+
cursor.execute("SELECT AVG(age) AS avg_age FROM users")
406+
407+
rows = cursor.fetchall()
408+
assert len(rows) == 1
409+
410+
col_names = [desc[0] for desc in cursor.description]
411+
avg_age = rows[0][col_names.index("avg_age")]
412+
assert isinstance(avg_age, (int, float))
413+
assert 24 <= avg_age <= 45 # Must be within the age range
414+
415+
def test_min(self, conn):
416+
"""SELECT MIN(age) AS youngest FROM users"""
417+
cursor = conn.cursor()
418+
cursor.execute("SELECT MIN(age) AS youngest FROM users")
419+
420+
rows = cursor.fetchall()
421+
assert len(rows) == 1
422+
423+
col_names = [desc[0] for desc in cursor.description]
424+
youngest = rows[0][col_names.index("youngest")]
425+
assert youngest == 24 # Min age in test data
426+
427+
def test_max(self, conn):
428+
"""SELECT MAX(age) AS oldest FROM users"""
429+
cursor = conn.cursor()
430+
cursor.execute("SELECT MAX(age) AS oldest FROM users")
431+
432+
rows = cursor.fetchall()
433+
assert len(rows) == 1
434+
435+
col_names = [desc[0] for desc in cursor.description]
436+
oldest = rows[0][col_names.index("oldest")]
437+
assert oldest == 45 # Max age in test data
438+
439+
def test_sum(self, conn):
440+
"""SELECT SUM(price) AS total_price FROM products"""
441+
cursor = conn.cursor()
442+
cursor.execute("SELECT SUM(price) AS total_price FROM products")
443+
444+
rows = cursor.fetchall()
445+
assert len(rows) == 1
446+
447+
col_names = [desc[0] for desc in cursor.description]
448+
total_price = rows[0][col_names.index("total_price")]
449+
assert isinstance(total_price, (int, float))
450+
assert total_price > 0
451+
452+
def test_multiple_aggregates(self, conn):
453+
"""SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products"""
454+
cursor = conn.cursor()
455+
cursor.execute(
456+
"SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products"
457+
)
458+
459+
rows = cursor.fetchall()
460+
assert len(rows) == 1
461+
462+
col_names = [desc[0] for desc in cursor.description]
463+
row = rows[0]
464+
465+
cnt = row[col_names.index("cnt")]
466+
cheapest = row[col_names.index("cheapest")]
467+
priciest = row[col_names.index("priciest")]
468+
avg_price = row[col_names.index("avg_price")]
469+
470+
assert cnt == 50
471+
assert cheapest <= avg_price <= priciest
472+
473+
def test_min_max_on_products(self, conn):
474+
"""SELECT MIN(price) AS low, MAX(price) AS high FROM products"""
475+
cursor = conn.cursor()
476+
cursor.execute("SELECT MIN(price) AS low, MAX(price) AS high FROM products")
477+
478+
rows = cursor.fetchall()
479+
assert len(rows) == 1
480+
481+
col_names = [desc[0] for desc in cursor.description]
482+
low = rows[0][col_names.index("low")]
483+
high = rows[0][col_names.index("high")]
484+
assert low < high
485+
486+
def test_count_with_and_or_conditions(self, conn):
487+
"""SELECT COUNT(*) AS cnt FROM users WHERE (active = true AND age > 30) OR age < 25"""
488+
cursor = conn.cursor()
489+
490+
# AND-only: active users over 30
491+
cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE active = true AND age > 30")
492+
rows = cursor.fetchall()
493+
col_names = [desc[0] for desc in cursor.description]
494+
and_count = rows[0][col_names.index("cnt")]
495+
assert isinstance(and_count, (int, float))
496+
assert and_count > 0
497+
assert and_count < 22
498+
499+
# OR-only: very young or very old
500+
cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE age < 26 OR age > 40")
501+
rows = cursor.fetchall()
502+
col_names = [desc[0] for desc in cursor.description]
503+
or_count = rows[0][col_names.index("cnt")]
504+
assert isinstance(or_count, (int, float))
505+
assert or_count > 0
506+
assert or_count < 22
507+
508+
# Three AND conditions
509+
cursor.execute(
510+
"SELECT COUNT(*) AS cnt, AVG(age) AS avg_age FROM users " "WHERE active = true AND age >= 25 AND age <= 40"
511+
)
512+
rows = cursor.fetchall()
513+
col_names = [desc[0] for desc in cursor.description]
514+
cnt = rows[0][col_names.index("cnt")]
515+
avg_age = rows[0][col_names.index("avg_age")]
516+
assert cnt > 0
517+
assert cnt < 22
518+
assert 25 <= avg_age <= 40

0 commit comments

Comments
 (0)