@@ -354,3 +354,165 @@ def test_aggregate_collection_name_with_hyphen(self, conn):
354354 customer_type_idx = col_names .index ("customer_type" )
355355 for row in rows :
356356 assert row [customer_type_idx ] == "premium" , "All rows should have customer_type='premium'"
357+
358+
359+ class TestSqlGroupFunctions :
360+ """Test SQL aggregate functions (COUNT, AVG, MIN, MAX, SUM) translated to MongoDB pipelines."""
361+
362+ def test_count_star (self , conn ):
363+ """SELECT COUNT(*) AS total FROM users → should return document count"""
364+ cursor = conn .cursor ()
365+ cursor .execute ("SELECT COUNT(*) AS total FROM users" )
366+
367+ rows = cursor .fetchall ()
368+ assert len (rows ) == 1
369+
370+ col_names = [desc [0 ] for desc in cursor .description ]
371+ assert "total" in col_names
372+
373+ total_idx = col_names .index ("total" )
374+ assert rows [0 ][total_idx ] == 22 # 22 users in test data
375+
376+ def test_count_star_no_alias (self , conn ):
377+ """SELECT COUNT(*) FROM users → column name defaults to COUNT(*)"""
378+ cursor = conn .cursor ()
379+ cursor .execute ("SELECT COUNT(*) FROM users" )
380+
381+ rows = cursor .fetchall ()
382+ assert len (rows ) == 1
383+
384+ col_names = [desc [0 ] for desc in cursor .description ]
385+ assert "COUNT(*)" in col_names
386+ assert rows [0 ][col_names .index ("COUNT(*)" )] == 22
387+
388+ def test_count_star_with_where (self , conn ):
389+ """SELECT COUNT(*) AS total FROM users WHERE age > 30 → filtered count"""
390+ cursor = conn .cursor ()
391+ cursor .execute ("SELECT COUNT(*) AS total FROM users WHERE age > 30" )
392+
393+ rows = cursor .fetchall ()
394+ assert len (rows ) == 1
395+
396+ col_names = [desc [0 ] for desc in cursor .description ]
397+ total = rows [0 ][col_names .index ("total" )]
398+ assert isinstance (total , (int , float ))
399+ assert total > 0
400+ assert total < 22 # Must be less than total users
401+
402+ def test_avg (self , conn ):
403+ """SELECT AVG(age) AS avg_age FROM users"""
404+ cursor = conn .cursor ()
405+ cursor .execute ("SELECT AVG(age) AS avg_age FROM users" )
406+
407+ rows = cursor .fetchall ()
408+ assert len (rows ) == 1
409+
410+ col_names = [desc [0 ] for desc in cursor .description ]
411+ avg_age = rows [0 ][col_names .index ("avg_age" )]
412+ assert isinstance (avg_age , (int , float ))
413+ assert 24 <= avg_age <= 45 # Must be within the age range
414+
415+ def test_min (self , conn ):
416+ """SELECT MIN(age) AS youngest FROM users"""
417+ cursor = conn .cursor ()
418+ cursor .execute ("SELECT MIN(age) AS youngest FROM users" )
419+
420+ rows = cursor .fetchall ()
421+ assert len (rows ) == 1
422+
423+ col_names = [desc [0 ] for desc in cursor .description ]
424+ youngest = rows [0 ][col_names .index ("youngest" )]
425+ assert youngest == 24 # Min age in test data
426+
427+ def test_max (self , conn ):
428+ """SELECT MAX(age) AS oldest FROM users"""
429+ cursor = conn .cursor ()
430+ cursor .execute ("SELECT MAX(age) AS oldest FROM users" )
431+
432+ rows = cursor .fetchall ()
433+ assert len (rows ) == 1
434+
435+ col_names = [desc [0 ] for desc in cursor .description ]
436+ oldest = rows [0 ][col_names .index ("oldest" )]
437+ assert oldest == 45 # Max age in test data
438+
439+ def test_sum (self , conn ):
440+ """SELECT SUM(price) AS total_price FROM products"""
441+ cursor = conn .cursor ()
442+ cursor .execute ("SELECT SUM(price) AS total_price FROM products" )
443+
444+ rows = cursor .fetchall ()
445+ assert len (rows ) == 1
446+
447+ col_names = [desc [0 ] for desc in cursor .description ]
448+ total_price = rows [0 ][col_names .index ("total_price" )]
449+ assert isinstance (total_price , (int , float ))
450+ assert total_price > 0
451+
452+ def test_multiple_aggregates (self , conn ):
453+ """SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products"""
454+ cursor = conn .cursor ()
455+ cursor .execute (
456+ "SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products"
457+ )
458+
459+ rows = cursor .fetchall ()
460+ assert len (rows ) == 1
461+
462+ col_names = [desc [0 ] for desc in cursor .description ]
463+ row = rows [0 ]
464+
465+ cnt = row [col_names .index ("cnt" )]
466+ cheapest = row [col_names .index ("cheapest" )]
467+ priciest = row [col_names .index ("priciest" )]
468+ avg_price = row [col_names .index ("avg_price" )]
469+
470+ assert cnt == 50
471+ assert cheapest <= avg_price <= priciest
472+
473+ def test_min_max_on_products (self , conn ):
474+ """SELECT MIN(price) AS low, MAX(price) AS high FROM products"""
475+ cursor = conn .cursor ()
476+ cursor .execute ("SELECT MIN(price) AS low, MAX(price) AS high FROM products" )
477+
478+ rows = cursor .fetchall ()
479+ assert len (rows ) == 1
480+
481+ col_names = [desc [0 ] for desc in cursor .description ]
482+ low = rows [0 ][col_names .index ("low" )]
483+ high = rows [0 ][col_names .index ("high" )]
484+ assert low < high
485+
486+ def test_count_with_and_or_conditions (self , conn ):
487+ """SELECT COUNT(*) AS cnt FROM users WHERE (active = true AND age > 30) OR age < 25"""
488+ cursor = conn .cursor ()
489+
490+ # AND-only: active users over 30
491+ cursor .execute ("SELECT COUNT(*) AS cnt FROM users WHERE active = true AND age > 30" )
492+ rows = cursor .fetchall ()
493+ col_names = [desc [0 ] for desc in cursor .description ]
494+ and_count = rows [0 ][col_names .index ("cnt" )]
495+ assert isinstance (and_count , (int , float ))
496+ assert and_count > 0
497+ assert and_count < 22
498+
499+ # OR-only: very young or very old
500+ cursor .execute ("SELECT COUNT(*) AS cnt FROM users WHERE age < 26 OR age > 40" )
501+ rows = cursor .fetchall ()
502+ col_names = [desc [0 ] for desc in cursor .description ]
503+ or_count = rows [0 ][col_names .index ("cnt" )]
504+ assert isinstance (or_count , (int , float ))
505+ assert or_count > 0
506+ assert or_count < 22
507+
508+ # Three AND conditions
509+ cursor .execute (
510+ "SELECT COUNT(*) AS cnt, AVG(age) AS avg_age FROM users " "WHERE active = true AND age >= 25 AND age <= 40"
511+ )
512+ rows = cursor .fetchall ()
513+ col_names = [desc [0 ] for desc in cursor .description ]
514+ cnt = rows [0 ][col_names .index ("cnt" )]
515+ avg_age = rows [0 ][col_names .index ("avg_age" )]
516+ assert cnt > 0
517+ assert cnt < 22
518+ assert 25 <= avg_age <= 40
0 commit comments