Commit d920320

Fix/count not working (#34)
* Fix the issue on count in a sql
* Update README
1 parent d4bdfed commit d920320

5 files changed

Lines changed: 468 additions & 14 deletions

README.md

Lines changed: 61 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ PyMongoSQL implements the DB API 2.0 interfaces to provide SQL-like access to Mo
2828
- **PartiQL-based SQL Syntax**: Built on [PartiQL](https://partiql.org/tutorial.html) (SQL for semi-structured data), enabling seamless SQL querying of nested and hierarchical MongoDB documents
2929
- **Nested Structure Support**: Query and filter deeply nested fields and arrays within MongoDB documents using standard SQL syntax
3030
- **MongoDB Aggregate Pipeline Support**: Execute native MongoDB aggregation pipelines using SQL-like syntax with the `aggregate()` function
31+
- **SQL Aggregate Functions**: `COUNT(*)`, `SUM`, `AVG`, `MIN`, `MAX` translated to MongoDB aggregation pipelines
3132
- **SQLAlchemy Integration**: Complete ORM and Core support with dedicated MongoDB dialect
3233
- **SQL Query Support**: SELECT statements with WHERE conditions, field selection, and aliases
3334
- **DML Support**: Full support for INSERT, UPDATE, and DELETE operations using PartiQL syntax
@@ -87,6 +88,7 @@ pip install -e .
8788
- [WHERE Clauses](#where-clauses)
8889
- [Nested Field Support](#nested-field-support)
8990
- [Sorting and Limiting](#sorting-and-limiting)
91+
- [SQL Aggregate Functions](#sql-aggregate-functions)
9092
- [MongoDB Aggregate Function](#mongodb-aggregate-function)
9193
- [INSERT Statements](#insert-statements)
9294
- [UPDATE Statements](#update-statements)
@@ -293,6 +295,42 @@ Both functions:
293295
- **LIMIT**: `LIMIT 10`
294296
- **Combined**: `ORDER BY created_at DESC LIMIT 5`
295297

298+
### SQL Aggregate Functions
299+
300+
PyMongoSQL supports standard SQL aggregate functions that are automatically translated into MongoDB aggregation pipelines.
301+
302+
**Supported Functions**: `COUNT(*)`, `SUM(field)`, `AVG(field)`, `MIN(field)`, `MAX(field)`
303+
304+
**Basic Count**
305+
306+
```python
307+
cursor.execute("SELECT COUNT(*) AS total FROM users")
308+
row = cursor.fetchone()
309+
print(f"Total users: {row[0]}")
310+
```
311+
312+
**Multiple Aggregates**
313+
314+
```python
315+
cursor.execute(
316+
"SELECT COUNT(*) AS cnt, AVG(price) AS avg_price, MIN(price) AS cheapest, MAX(price) AS priciest FROM products"
317+
)
318+
```
319+
320+
**Aggregate with WHERE**
321+
322+
```python
323+
cursor.execute("SELECT COUNT(*) AS total FROM users WHERE active = true AND age > 30")
324+
```
325+
326+
**Aggregate with OR Conditions**
327+
328+
```python
329+
cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE age < 26 OR age > 40")
330+
```
331+
332+
**Note:** Aggregate functions are translated into a MongoDB `aggregate()` pipeline with a `$match` stage (only when a WHERE clause is present), a `$group` stage (with accumulators), and a `$project` stage. `COUNT(*)` maps to `{$sum: 1}`, while `SUM`, `AVG`, `MIN`, and `MAX` map to their corresponding MongoDB accumulators (`$sum`, `$avg`, `$min`, `$max`).
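
The translation described in the note above can be sketched as a small standalone function. This is an illustration only, not the library's code: `sketch_pipeline` and `ACCUMULATORS` are hypothetical names, and the `$gt` filter passed as `where` is an assumption about how the WHERE clause would already have been translated.

```python
from typing import Any, Dict, List, Optional

# SQL aggregate name -> MongoDB accumulator, as listed in the note above.
ACCUMULATORS = {"COUNT": "$sum", "SUM": "$sum", "AVG": "$avg", "MIN": "$min", "MAX": "$max"}


def sketch_pipeline(
    aggregates: List[Dict[str, str]],
    where: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """Hypothetical helper: build $match/$group/$project stages for SQL aggregates."""
    pipeline: List[Dict[str, Any]] = []

    # $match is emitted only when there is a filter to apply.
    if where:
        pipeline.append({"$match": where})

    # A single group (_id: None) collapses all matching documents into one row.
    group: Dict[str, Any] = {"_id": None}
    for agg in aggregates:
        # COUNT adds 1 per document; the others read the named field.
        operand = 1 if agg["function"] == "COUNT" else f"${agg['argument']}"
        group[agg["alias"]] = {ACCUMULATORS[agg["function"]]: operand}
    pipeline.append({"$group": group})

    # $project drops _id and keeps only the aliased results.
    project: Dict[str, Any] = {"_id": 0}
    project.update({agg["alias"]: 1 for agg in aggregates})
    pipeline.append({"$project": project})
    return pipeline


# SELECT COUNT(*) AS cnt, AVG(price) AS avg_price FROM products WHERE price > 10
pipeline = sketch_pipeline(
    [
        {"function": "COUNT", "argument": "*", "alias": "cnt"},
        {"function": "AVG", "argument": "price", "alias": "avg_price"},
    ],
    where={"price": {"$gt": 10}},  # assumed translation of the WHERE clause
)
```

Without a `where` argument the sketch returns only the `$group` and `$project` stages.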
333+
296334
### MongoDB Aggregate Function
297335

298336
PyMongoSQL supports executing native MongoDB aggregation pipelines using SQL-like syntax with the `aggregate()` function. This allows you to leverage MongoDB's powerful aggregation framework while maintaining SQL-style query patterns.
@@ -608,20 +646,24 @@ The table below shows how PyMongoSQL translates SQL operations into MongoDB comm
608646

609647
### SQL Operations to MongoDB Commands
610648

611-
| SQL Operation | MongoDB Command | Equivalent PyMongo Method |
612-
|---|---|---|
613-
| `SELECT ... FROM col` | `{find: col, projection: {...}}` | `db.command("find", ...)` |
614-
| `SELECT ... FROM col WHERE ...` | `{find: col, filter: {...}}` | `db.command("find", ...)` |
615-
| `SELECT ... ORDER BY col ASC/DESC` | `{find: ..., sort: {col: 1/-1}}` | `db.command("find", ...)` |
616-
| `SELECT ... LIMIT n` | `{find: ..., limit: n}` | `db.command("find", ...)` |
617-
| `SELECT ... OFFSET n` | `{find: ..., skip: n}` | `db.command("find", ...)` |
618-
| `SELECT * FROM col.aggregate(...)` | `collection.aggregate(pipeline)` | `collection.aggregate()` |
619-
| `INSERT INTO col ...` | `{insert: col, documents: [...]}` | `db.command("insert", ...)` |
620-
| `UPDATE col SET ... WHERE ...` | `{update: col, updates: [{q: filter, u: {$set: {...}}, multi: true}]}` | `db.command("update", ...)` |
621-
| `DELETE FROM col WHERE ...` | `{delete: col, deletes: [{q: filter, limit: 0}]}` | `db.command("delete", ...)` |
622-
| `CREATE VIEW v ON col AS '[...]'` | `{create: v, viewOn: col, pipeline: [...]}` | `db.command("create", ...)` |
623-
| `DROP VIEW v` | `{drop: v}` | `db.command("drop", ...)` |
624-
| `EXPLAIN <select>` | `{explain: <find\|aggregate cmd>, verbosity: "queryPlanner"}` | `db.command("explain", ...)` |
649+
| SQL Operation | MongoDB Command |
650+
|---|---|
651+
| `SELECT ... FROM col` | `{find: col, projection: {...}}` |
652+
| `SELECT ... FROM col WHERE ...` | `{find: col, filter: {...}}` |
653+
| `SELECT ... ORDER BY col ASC/DESC` | `{find: ..., sort: {col: 1/-1}}` |
654+
| `SELECT ... LIMIT n` | `{find: ..., limit: n}` |
655+
| `SELECT ... OFFSET n` | `{find: ..., skip: n}` |
656+
| `SELECT COUNT(*) FROM col` | `collection.aggregate([{$group: {_id: null, ...}}, {$project: ...}])` |
657+
| `SELECT AVG(field) FROM col` | `collection.aggregate([{$group: {_id: null, ...}}, {$project: ...}])` |
658+
| `SELECT MIN/MAX(field) FROM col` | `collection.aggregate([{$group: {_id: null, ...}}, {$project: ...}])` |
659+
| `SELECT SUM(field) FROM col` | `collection.aggregate([{$group: {_id: null, ...}}, {$project: ...}])` |
660+
| `SELECT * FROM col.aggregate(...)` | `collection.aggregate(pipeline)` |
661+
| `INSERT INTO col ...` | `{insert: col, documents: [...]}` |
662+
| `UPDATE col SET ... WHERE ...` | `{update: col, updates: [{q: filter, u: {$set: {...}}, multi: true}]}` |
663+
| `DELETE FROM col WHERE ...` | `{delete: col, deletes: [{q: filter, limit: 0}]}` |
664+
| `CREATE VIEW v ON col AS '[...]'` | `{create: v, viewOn: col, pipeline: [...]}` |
665+
| `DROP VIEW v` | `{drop: v}` |
666+
| `EXPLAIN <select>` | `{explain: <find\|aggregate cmd>, verbosity: "queryPlanner"}` |
625667

626668
### SQL Clauses to MongoDB Query Components
627669

@@ -635,6 +677,11 @@ The table below shows how PyMongoSQL translates SQL operations into MongoDB comm
635677
| `ORDER BY col DESC` | `sort: {col: -1}` | Descending sort |
636678
| `LIMIT n` | `limit: n` | Restrict result count |
637679
| `OFFSET n` | `skip: n` | Skip first n results |
680+
| `COUNT(*)` | `{$group: {_id: null, count: {$sum: 1}}}` | Document count |
681+
| `SUM(field)` | `{$group: {_id: null, sum: {$sum: "$field"}}}` | Field sum |
682+
| `AVG(field)` | `{$group: {_id: null, avg: {$avg: "$field"}}}` | Field average |
683+
| `MIN(field)` | `{$group: {_id: null, min: {$min: "$field"}}}` | Field minimum |
684+
| `MAX(field)` | `{$group: {_id: null, max: {$max: "$field"}}}` | Field maximum |
638685

639686
### WHERE Operators to MongoDB Filter Operators
640687

pymongosql/sql/builder.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
import json
23
import logging
34
from dataclasses import dataclass
45
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
@@ -112,6 +113,11 @@ def build_from_parse_result(
112113
@staticmethod
113114
def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
114115
"""Build a query execution plan from SELECT parsing."""
116+
117+
# Auto-generate aggregate pipeline for SQL aggregate functions (COUNT, SUM, etc.)
118+
if getattr(parse_result, "aggregate_functions", None):
119+
return ExecutionPlanBuilder._build_sql_aggregate_plan(parse_result)
120+
115121
builder = BuilderFactory.create_query_builder().collection(parse_result.collection)
116122

117123
builder.filter(parse_result.filter_conditions).project(parse_result.projection).column_aliases(
@@ -128,6 +134,60 @@ def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
128134
plan = builder.build()
129135
return plan
130136

137+
@staticmethod
138+
def _build_sql_aggregate_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
139+
"""Build an aggregate execution plan from SQL aggregate functions like COUNT(*), SUM(), etc."""
140+
_FUNCTION_TO_ACCUMULATOR = {
141+
"COUNT": "$sum",
142+
"SUM": "$sum",
143+
"AVG": "$avg",
144+
"MIN": "$min",
145+
"MAX": "$max",
146+
}
147+
148+
builder = BuilderFactory.create_query_builder().collection(parse_result.collection)
149+
150+
pipeline = []
151+
152+
# Add $match stage if there are filter conditions (from WHERE clause)
153+
if parse_result.filter_conditions:
154+
pipeline.append({"$match": parse_result.filter_conditions})
155+
156+
# Build $group stage from aggregate functions
157+
group_stage = {"_id": None}
158+
for func_info in parse_result.aggregate_functions:
159+
alias = func_info["alias"]
160+
func_name = func_info["function"]
161+
arg = func_info["argument"]
162+
accumulator = _FUNCTION_TO_ACCUMULATOR[func_name]
163+
164+
if func_name == "COUNT":
165+
group_stage[alias] = {accumulator: 1}
166+
else:
167+
group_stage[alias] = {accumulator: f"${arg}"}
168+
169+
pipeline.append({"$group": group_stage})
170+
171+
# Add $project to exclude _id
172+
project_stage = {"_id": 0}
173+
for func_info in parse_result.aggregate_functions:
174+
project_stage[func_info["alias"]] = 1
175+
pipeline.append({"$project": project_stage})
176+
177+
# Configure the execution plan as an aggregate query
178+
builder._execution_plan.is_aggregate_query = True
179+
builder._execution_plan.aggregate_pipeline = json.dumps(pipeline)
180+
builder._execution_plan.aggregate_options = json.dumps({})
181+
182+
# Set projection for ResultSet description
183+
agg_projection = {}
184+
for func_info in parse_result.aggregate_functions:
185+
agg_projection[func_info["alias"]] = 1
186+
builder._execution_plan.projection_stage = agg_projection
187+
188+
plan = builder.build()
189+
return plan
190+
131191
@staticmethod
132192
def _build_insert_plan(parse_result: "InsertParseResult") -> "InsertExecutionPlan":
133193
"""Build an INSERT execution plan from INSERT parsing."""

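The plan built in `_build_sql_aggregate_plan` stores the pipeline as a JSON string (`json.dumps(pipeline)`) rather than as a list. A minimal sketch of that round trip follows, assuming the consumer decodes it with `json.loads`; the decoding side is not shown in this diff.

```python
import json

# Pipeline shape produced for: SELECT COUNT(*) AS total FROM users
pipeline = [
    {"$group": {"_id": None, "total": {"$sum": 1}}},
    {"$project": {"_id": 0, "total": 1}},
]

serialized = json.dumps(pipeline)  # Python None is written as JSON null
restored = json.loads(serialized)  # and null decodes back to None
```

One thing to keep in mind with this design: plain `json.dumps` raises `TypeError` for values such as `datetime` objects, so this would only matter if the filter conditions placed in `$match` ever carried such values.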
pymongosql/sql/query_handler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class QueryParseResult:
3232
aggregate_pipeline: Optional[str] = None # JSON string representation of pipeline
3333
aggregate_options: Optional[str] = None # JSON string representation of options
3434

35+
# SQL aggregate functions detected in SELECT (COUNT, SUM, AVG, MIN, MAX)
36+
aggregate_functions: List[Dict[str, Any]] = field(default_factory=list)
37+
3538
# Subquery info (for wrapped subqueries, e.g., Superset outering)
3639
subquery_plan: Optional[Any] = None
3740
subquery_alias: Optional[str] = None
@@ -111,6 +114,12 @@ def handle(self, ctx: PartiQLParser.WhereClauseSelectContext) -> Dict[str, Any]:
111114
class SelectHandler(BaseHandler, ContextUtilsMixin):
112115
"""Handles SELECT statement parsing"""
113116

117+
# Pattern to detect SQL aggregate functions: COUNT(*), SUM(field), AVG(field), etc.
118+
_AGGREGATE_PATTERN = re.compile(
119+
r"^(COUNT|SUM|AVG|MIN|MAX)\s*\(\s*(\*|\w+(?:\.\w+)*)\s*\)$",
120+
re.IGNORECASE,
121+
)
122+
114123
def can_handle(self, ctx: Any) -> bool:
115124
"""Check if this is a select context"""
116125
return hasattr(ctx, "projectionItems")
@@ -122,6 +131,21 @@ def handle_visitor(self, ctx: PartiQLParser.SelectItemsContext, parse_result: "Q
122131
if hasattr(ctx, "projectionItems") and ctx.projectionItems():
123132
for item in ctx.projectionItems().projectionItem():
124133
field_name, alias = self._extract_field_and_alias(item)
134+
135+
# Check if this is an aggregate function (COUNT, SUM, etc.)
136+
agg_match = self._AGGREGATE_PATTERN.match(field_name)
137+
if agg_match:
138+
func_name = agg_match.group(1).upper()
139+
func_arg = agg_match.group(2)
140+
parse_result.aggregate_functions.append(
141+
{
142+
"function": func_name,
143+
"argument": func_arg,
144+
"alias": alias or field_name,
145+
}
146+
)
147+
continue
148+
125149
# Use MongoDB standard projection format: {field: 1} to include field
126150
projection[field_name] = 1
127151
# Store alias if present
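
To see which projection items `_AGGREGATE_PATTERN` treats as aggregates, the pattern can be exercised on its own. The sketch below copies the regex from the hunk above; `classify` is a hypothetical helper, and it assumes the expression text reaches the pattern unmodified (what `_extract_field_and_alias` actually returns is not shown in this diff).

```python
import re

# Same pattern as in the hunk above, copied for a standalone demonstration.
AGGREGATE_PATTERN = re.compile(
    r"^(COUNT|SUM|AVG|MIN|MAX)\s*\(\s*(\*|\w+(?:\.\w+)*)\s*\)$",
    re.IGNORECASE,
)


def classify(expr: str):
    """Hypothetical helper: return (FUNCTION, argument), or None for a non-aggregate."""
    m = AGGREGATE_PATTERN.match(expr)
    return (m.group(1).upper(), m.group(2)) if m else None


samples = ["COUNT(*)", "avg( order.total )", "price", "COUNT(DISTINCT id)", "SUM(*)"]
results = {expr: classify(expr) for expr in samples}
```

Two behaviours follow directly from the regex: `COUNT(DISTINCT id)` does not match, so it falls through to the ordinary projection path, and `*` is accepted as the argument of any function, so `SUM(*)` is classified as an aggregate even though only `COUNT(*)` is documented.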

tests/test_cursor_aggregate.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,165 @@ def test_aggregate_collection_name_with_hyphen(self, conn):
354354
customer_type_idx = col_names.index("customer_type")
355355
for row in rows:
356356
assert row[customer_type_idx] == "premium", "All rows should have customer_type='premium'"
357+
358+
359+
class TestSqlGroupFunctions:
360+
"""Test SQL aggregate functions (COUNT, AVG, MIN, MAX, SUM) translated to MongoDB pipelines."""
361+
362+
def test_count_star(self, conn):
363+
"""SELECT COUNT(*) AS total FROM users → should return document count"""
364+
cursor = conn.cursor()
365+
cursor.execute("SELECT COUNT(*) AS total FROM users")
366+
367+
rows = cursor.fetchall()
368+
assert len(rows) == 1
369+
370+
col_names = [desc[0] for desc in cursor.description]
371+
assert "total" in col_names
372+
373+
total_idx = col_names.index("total")
374+
assert rows[0][total_idx] == 22 # 22 users in test data
375+
376+
def test_count_star_no_alias(self, conn):
377+
"""SELECT COUNT(*) FROM users → column name defaults to COUNT(*)"""
378+
cursor = conn.cursor()
379+
cursor.execute("SELECT COUNT(*) FROM users")
380+
381+
rows = cursor.fetchall()
382+
assert len(rows) == 1
383+
384+
col_names = [desc[0] for desc in cursor.description]
385+
assert "COUNT(*)" in col_names
386+
assert rows[0][col_names.index("COUNT(*)")] == 22
387+
388+
def test_count_star_with_where(self, conn):
389+
"""SELECT COUNT(*) AS total FROM users WHERE age > 30 → filtered count"""
390+
cursor = conn.cursor()
391+
cursor.execute("SELECT COUNT(*) AS total FROM users WHERE age > 30")
392+
393+
rows = cursor.fetchall()
394+
assert len(rows) == 1
395+
396+
col_names = [desc[0] for desc in cursor.description]
397+
total = rows[0][col_names.index("total")]
398+
assert isinstance(total, (int, float))
399+
assert total > 0
400+
assert total < 22 # Must be less than total users
401+
402+
def test_avg(self, conn):
403+
"""SELECT AVG(age) AS avg_age FROM users"""
404+
cursor = conn.cursor()
405+
cursor.execute("SELECT AVG(age) AS avg_age FROM users")
406+
407+
rows = cursor.fetchall()
408+
assert len(rows) == 1
409+
410+
col_names = [desc[0] for desc in cursor.description]
411+
avg_age = rows[0][col_names.index("avg_age")]
412+
assert isinstance(avg_age, (int, float))
413+
assert 24 <= avg_age <= 45 # Must be within the age range
414+
415+
def test_min(self, conn):
416+
"""SELECT MIN(age) AS youngest FROM users"""
417+
cursor = conn.cursor()
418+
cursor.execute("SELECT MIN(age) AS youngest FROM users")
419+
420+
rows = cursor.fetchall()
421+
assert len(rows) == 1
422+
423+
col_names = [desc[0] for desc in cursor.description]
424+
youngest = rows[0][col_names.index("youngest")]
425+
assert youngest == 24 # Min age in test data
426+
427+
def test_max(self, conn):
428+
"""SELECT MAX(age) AS oldest FROM users"""
429+
cursor = conn.cursor()
430+
cursor.execute("SELECT MAX(age) AS oldest FROM users")
431+
432+
rows = cursor.fetchall()
433+
assert len(rows) == 1
434+
435+
col_names = [desc[0] for desc in cursor.description]
436+
oldest = rows[0][col_names.index("oldest")]
437+
assert oldest == 45 # Max age in test data
438+
439+
def test_sum(self, conn):
440+
"""SELECT SUM(price) AS total_price FROM products"""
441+
cursor = conn.cursor()
442+
cursor.execute("SELECT SUM(price) AS total_price FROM products")
443+
444+
rows = cursor.fetchall()
445+
assert len(rows) == 1
446+
447+
col_names = [desc[0] for desc in cursor.description]
448+
total_price = rows[0][col_names.index("total_price")]
449+
assert isinstance(total_price, (int, float))
450+
assert total_price > 0
451+
452+
def test_multiple_aggregates(self, conn):
453+
"""SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products"""
454+
cursor = conn.cursor()
455+
cursor.execute(
456+
"SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products"
457+
)
458+
459+
rows = cursor.fetchall()
460+
assert len(rows) == 1
461+
462+
col_names = [desc[0] for desc in cursor.description]
463+
row = rows[0]
464+
465+
cnt = row[col_names.index("cnt")]
466+
cheapest = row[col_names.index("cheapest")]
467+
priciest = row[col_names.index("priciest")]
468+
avg_price = row[col_names.index("avg_price")]
469+
470+
assert cnt == 50
471+
assert cheapest <= avg_price <= priciest
472+
473+
def test_min_max_on_products(self, conn):
474+
"""SELECT MIN(price) AS low, MAX(price) AS high FROM products"""
475+
cursor = conn.cursor()
476+
cursor.execute("SELECT MIN(price) AS low, MAX(price) AS high FROM products")
477+
478+
rows = cursor.fetchall()
479+
assert len(rows) == 1
480+
481+
col_names = [desc[0] for desc in cursor.description]
482+
low = rows[0][col_names.index("low")]
483+
high = rows[0][col_names.index("high")]
484+
assert low < high
485+
486+
def test_count_with_and_or_conditions(self, conn):
487+
"""COUNT(*) with AND-only, OR-only, and three-way AND conditions in the WHERE clause"""
488+
cursor = conn.cursor()
489+
490+
# AND-only: active users over 30
491+
cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE active = true AND age > 30")
492+
rows = cursor.fetchall()
493+
col_names = [desc[0] for desc in cursor.description]
494+
and_count = rows[0][col_names.index("cnt")]
495+
assert isinstance(and_count, (int, float))
496+
assert and_count > 0
497+
assert and_count < 22
498+
499+
# OR-only: very young or very old
500+
cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE age < 26 OR age > 40")
501+
rows = cursor.fetchall()
502+
col_names = [desc[0] for desc in cursor.description]
503+
or_count = rows[0][col_names.index("cnt")]
504+
assert isinstance(or_count, (int, float))
505+
assert or_count > 0
506+
assert or_count < 22
507+
508+
# Three AND conditions
509+
cursor.execute(
510+
"SELECT COUNT(*) AS cnt, AVG(age) AS avg_age FROM users WHERE active = true AND age >= 25 AND age <= 40"
511+
)
512+
rows = cursor.fetchall()
513+
col_names = [desc[0] for desc in cursor.description]
514+
cnt = rows[0][col_names.index("cnt")]
515+
avg_age = rows[0][col_names.index("avg_age")]
516+
assert cnt > 0
517+
assert cnt < 22
518+
assert 25 <= avg_age <= 40
