@@ -566,6 +566,221 @@ def test_select_with_constant_no_alias(self):
566566 self .assertEqual (df ["42" ].tolist (), [42 ])
567567 self .assertEqual (df ["hello" ].tolist (), ["hello" ])
568568
def test_query_select_with_subquery_and_where(self):
    """
    Test that `query` returns a table response for `SELECT *` over a subquery
    that has a WHERE clause.

    The inner SELECT projects `name` and `runtime` from `movies` with
    `runtime > 120`; the outer SELECT wraps it with `*`.

    NOTE: the mocked `aggregate` has a fixed return value, so it hands back the
    same rows whatever arguments it receives. This test therefore checks how
    the result is shaped into a data frame, not the filter that is generated
    for the WHERE clause.
    """
    # Resolve the mocked database once instead of repeating the lookup chain.
    database = self.mock_connect.return_value[self.dummy_connection_data["database"]]
    database.list_collection_names.return_value = ["movies"]
    database["movies"].aggregate.return_value = [
        {"name": "The Dark Knight", "runtime": 152},
        {"name": "Inception", "runtime": 148},
    ]

    runtime_filter = ast.BinaryOperation(
        op=">",
        args=[ast.Identifier(parts=["runtime"]), ast.Constant(120)],
    )
    subquery = ast.Select(
        targets=[
            ast.Identifier(parts=["name"]),
            ast.Identifier(parts=["runtime"]),
        ],
        from_table=ast.Identifier("movies"),
        where=runtime_filter,
    )
    main_query = ast.Select(targets=[Star()], from_table=subquery)

    response = self.handler.query(main_query)

    # assertIsInstance (unlike a bare `assert`) is not stripped under `python -O`
    # and reports the actual type on failure.
    self.assertIsInstance(response, Response)
    self.assertEqual(response.type, RESPONSE_TYPE.TABLE)

    df = response.data_frame
    self.assertEqual(len(df), 2)
    self.assertEqual(df.columns.tolist(), ["name", "runtime"])
    self.assertEqual(df["name"].tolist(), ["The Dark Knight", "Inception"])
    self.assertEqual(df["runtime"].tolist(), [152, 148])
615+
def test_query_select_nested_field_projection(self):
    """
    Test that `query` handles projection of nested fields addressed with
    two-part identifiers (`financials.profit_margin`,
    `financials.account_balance`).

    The mocked `aggregate` returns rows already keyed by the dotted names, and
    the test expects those dotted names to appear as the data frame columns.
    """
    # Resolve the mocked database once instead of repeating the lookup chain.
    database = self.mock_connect.return_value[self.dummy_connection_data["database"]]
    database.list_collection_names.return_value = ["clients"]
    database["clients"].aggregate.return_value = [
        {
            "financials.profit_margin": 0.18,
            "financials.account_balance": 150000,
        },
        {
            "financials.profit_margin": 0.22,
            "financials.account_balance": 85000,
        },
    ]

    query = ast.Select(
        targets=[
            ast.Identifier(parts=["financials", "profit_margin"]),
            ast.Identifier(parts=["financials", "account_balance"]),
        ],
        from_table=ast.Identifier("clients"),
    )

    response = self.handler.query(query)

    # assertIsInstance (unlike a bare `assert`) is not stripped under `python -O`
    # and reports the actual type on failure.
    self.assertIsInstance(response, Response)
    self.assertEqual(response.type, RESPONSE_TYPE.TABLE)

    df = response.data_frame
    self.assertEqual(len(df), 2)
    self.assertEqual(
        df.columns.tolist(),
        ["financials.profit_margin", "financials.account_balance"],
    )
    self.assertEqual(df["financials.profit_margin"].tolist(), [0.18, 0.22])
    self.assertEqual(df["financials.account_balance"].tolist(), [150000, 85000])
657+
def test_query_select_nested_field_with_where(self):
    """
    Test nested field projection combined with a WHERE clause on the same
    nested field (`financials.profit_margin > 0.15`).

    NOTE: the mocked `aggregate` has a fixed return value, so the rows come
    back regardless of the filter that is built. The test checks that the
    query is accepted and that the result is shaped correctly, not that
    filtering actually happens.
    """
    # Resolve the mocked database once instead of repeating the lookup chain.
    database = self.mock_connect.return_value[self.dummy_connection_data["database"]]
    database.list_collection_names.return_value = ["clients"]
    database["clients"].aggregate.return_value = [
        {"financials.profit_margin": 0.18},
        {"financials.profit_margin": 0.22},
    ]

    margin_filter = ast.BinaryOperation(
        op=">",
        args=[
            ast.Identifier(parts=["financials", "profit_margin"]),
            ast.Constant(0.15),
        ],
    )
    query = ast.Select(
        targets=[ast.Identifier(parts=["financials", "profit_margin"])],
        from_table=ast.Identifier("clients"),
        where=margin_filter,
    )

    response = self.handler.query(query)

    # assertIsInstance (unlike a bare `assert`) is not stripped under `python -O`
    # and reports the actual type on failure.
    self.assertIsInstance(response, Response)
    self.assertEqual(response.type, RESPONSE_TYPE.TABLE)

    df = response.data_frame
    self.assertEqual(len(df), 2)
    self.assertEqual(df.columns.tolist(), ["financials.profit_margin"])
    self.assertEqual(df["financials.profit_margin"].tolist(), [0.18, 0.22])
699+
def test_query_aggregation_on_nested_field(self):
    """
    Test an aggregation function (AVG) applied to a nested field and aliased
    as `avg_margin`.

    The mocked `aggregate` returns a single pre-computed row, so the test
    checks that the aliased column reaches the data frame; it does not check
    the average computation itself.
    """
    # Resolve the mocked database once instead of repeating the lookup chain.
    database = self.mock_connect.return_value[self.dummy_connection_data["database"]]
    database.list_collection_names.return_value = ["clients"]
    database["clients"].aggregate.return_value = [{"avg_margin": 0.191}]

    average_margin = ast.Function(
        op="AVG",
        args=[ast.Identifier(parts=["financials", "profit_margin"])],
        alias=ast.Identifier(parts=["avg_margin"]),
    )
    query = ast.Select(
        targets=[average_margin],
        from_table=ast.Identifier("clients"),
    )

    response = self.handler.query(query)

    # assertIsInstance (unlike a bare `assert`) is not stripped under `python -O`
    # and reports the actual type on failure.
    self.assertIsInstance(response, Response)
    self.assertEqual(response.type, RESPONSE_TYPE.TABLE)

    df = response.data_frame
    self.assertEqual(len(df), 1)
    self.assertEqual(df.columns.tolist(), ["avg_margin"])
    self.assertAlmostEqual(df["avg_margin"].tolist()[0], 0.191, places=3)
735+
def test_query_group_by_with_nested_aggregation(self):
    """
    Test GROUP BY `industry` together with AVG over a nested field aliased as
    `avg_margin`.

    The mocked `aggregate` returns three pre-grouped rows, so the test checks
    that the grouped column and the aliased aggregate both reach the data
    frame in order; it does not check the grouping itself.
    """
    # Resolve the mocked database once instead of repeating the lookup chain.
    database = self.mock_connect.return_value[self.dummy_connection_data["database"]]
    database.list_collection_names.return_value = ["clients"]
    database["clients"].aggregate.return_value = [
        {"industry": "technology", "avg_margin": 0.18},
        {"industry": "finance", "avg_margin": 0.22},
        {"industry": "healthcare", "avg_margin": 0.15},
    ]

    average_margin = ast.Function(
        op="AVG",
        args=[ast.Identifier(parts=["financials", "profit_margin"])],
        alias=ast.Identifier(parts=["avg_margin"]),
    )
    query = ast.Select(
        targets=[ast.Identifier(parts=["industry"]), average_margin],
        from_table=ast.Identifier("clients"),
        group_by=[ast.Identifier(parts=["industry"])],
    )

    response = self.handler.query(query)

    # assertIsInstance (unlike a bare `assert`) is not stripped under `python -O`
    # and reports the actual type on failure.
    self.assertIsInstance(response, Response)
    self.assertEqual(response.type, RESPONSE_TYPE.TABLE)

    df = response.data_frame
    self.assertEqual(len(df), 3)
    self.assertEqual(df.columns.tolist(), ["industry", "avg_margin"])
    self.assertEqual(df["industry"].tolist(), ["technology", "finance", "healthcare"])
    self.assertEqual(df["avg_margin"].tolist(), [0.18, 0.22, 0.15])
783+
569784
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments