Skip to content

Commit c646b08

Browse files
committed
Updates from review comments
1 parent 68264f7 commit c646b08

2 files changed

Lines changed: 227 additions & 11 deletions

File tree

mindsdb/integrations/handlers/mongodb_handler/utils/mongodb_render.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ def _parse_select(self, from_table: Any) -> TypingTuple[str, Dict[str, Any], Opt
9090
# reject complex forms early
9191
# how deep we want to go with subqueries?
9292
if from_table.group_by is not None or from_table.having is not None:
93-
raise NotImplementedError(f"Not supported FROM as {from_table}")
93+
raise NotImplementedError(f"Not supported, subquery has `having` or `group by`: {from_table}")
9494

9595
if not isinstance(from_table.from_table, Identifier):
96-
raise NotImplementedError(f"Not supported FROM as {from_table}")
96+
raise NotImplementedError(f"Only simple subqueries are allowed in {from_table}")
9797

9898
collection = from_table.from_table.parts[-1]
9999

@@ -211,27 +211,28 @@ def select(self, node: Select) -> MongoQuery:
211211
func_name = col.op.lower()
212212
alias = col.alias.parts[-1] if col.alias is not None else func_name
213213

214-
if func_name == "count" and len(col.args) > 0 and isinstance(col.args[0], Star):
214+
if len(col.args) == 0:
215+
raise NotImplementedError(f"Function {func_name.upper()} requires arguments")
216+
217+
arg0 = col.args[0]
218+
219+
if func_name == "count" and isinstance(arg0, Star):
215220
agg_group[alias] = {"$sum": 1}
216-
elif len(col.args) > 0 and isinstance(col.args[0], Identifier):
217-
field_name = ".".join(col.args[0].parts)
221+
elif isinstance(arg0, Identifier):
222+
field_name = ".".join(arg0.parts)
218223

219-
# Map SQL functions to MongoDB operators
220224
if func_name == "avg":
221225
agg_group[alias] = {"$avg": f"${field_name}"}
222226
elif func_name == "sum":
223227
agg_group[alias] = {"$sum": f"${field_name}"}
224228
elif func_name == "count":
225-
if isinstance(col.args[0], Star):
226-
agg_group[alias] = {"$sum": 1}
227-
else:
228-
agg_group[alias] = {"$sum": {"$cond": [f"${field_name}", 1, 0]}}
229+
agg_group[alias] = {"$sum": {"$cond": [{"$ne": [f"${field_name}", None]}, 1, 0]}}
229230
elif func_name == "min":
230231
agg_group[alias] = {"$min": f"${field_name}"}
231232
elif func_name == "max":
232233
agg_group[alias] = {"$max": f"${field_name}"}
233234
else:
234-
raise NotImplementedError(f"Function {func_name} not supported")
235+
raise NotImplementedError(f"Aggregation function '{func_name.upper()}' is not supported")
235236
elif isinstance(col, Constant):
236237
alias = str(col.value) if col.alias is None else col.alias.parts[-1]
237238
project[alias] = col.value

tests/unit/handlers/test_mongodb.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,221 @@ def test_select_with_constant_no_alias(self):
566566
self.assertEqual(df["42"].tolist(), [42])
567567
self.assertEqual(df["hello"].tolist(), ["hello"])
568568

569+
def test_query_select_with_subquery_and_where(self):
570+
"""
571+
Test if the `query` method returns a response object with a data frame
572+
containing the query result for a select with subquery that has WHERE clause.
573+
"""
574+
self.mock_connect.return_value[self.dummy_connection_data["database"]].list_collection_names.return_value = [
575+
"movies"
576+
]
577+
578+
self.mock_connect.return_value[self.dummy_connection_data["database"]]["movies"].aggregate.return_value = [
579+
{
580+
"name": "The Dark Knight",
581+
"runtime": 152,
582+
},
583+
{
584+
"name": "Inception",
585+
"runtime": 148,
586+
},
587+
]
588+
589+
subquery = ast.Select(
590+
targets=[
591+
ast.Identifier(parts=["name"]),
592+
ast.Identifier(parts=["runtime"]),
593+
],
594+
from_table=ast.Identifier("movies"),
595+
where=ast.BinaryOperation(op=">", args=[ast.Identifier(parts=["runtime"]), ast.Constant(120)]),
596+
)
597+
598+
main_query = ast.Select(
599+
targets=[
600+
Star(),
601+
],
602+
from_table=subquery,
603+
)
604+
605+
response = self.handler.query(main_query)
606+
607+
assert isinstance(response, Response)
608+
self.assertEqual(response.type, RESPONSE_TYPE.TABLE)
609+
610+
df = response.data_frame
611+
self.assertEqual(len(df), 2)
612+
self.assertEqual(df.columns.tolist(), ["name", "runtime"])
613+
self.assertEqual(df["name"].tolist(), ["The Dark Knight", "Inception"])
614+
self.assertEqual(df["runtime"].tolist(), [152, 148])
615+
616+
def test_query_select_nested_field_projection(self):
617+
"""
618+
Test if the `query` method correctly handles nested field projection using dot notation.
619+
MongoDB stores nested documents (JSON data) that can be accessed with dot notation.
620+
"""
621+
self.mock_connect.return_value[self.dummy_connection_data["database"]].list_collection_names.return_value = [
622+
"clients"
623+
]
624+
625+
self.mock_connect.return_value[self.dummy_connection_data["database"]]["clients"].aggregate.return_value = [
626+
{
627+
"financials.profit_margin": 0.18,
628+
"financials.account_balance": 150000,
629+
},
630+
{
631+
"financials.profit_margin": 0.22,
632+
"financials.account_balance": 85000,
633+
},
634+
]
635+
636+
query = ast.Select(
637+
targets=[
638+
ast.Identifier(parts=["financials", "profit_margin"]),
639+
ast.Identifier(parts=["financials", "account_balance"]),
640+
],
641+
from_table=ast.Identifier("clients"),
642+
)
643+
644+
response = self.handler.query(query)
645+
646+
assert isinstance(response, Response)
647+
self.assertEqual(response.type, RESPONSE_TYPE.TABLE)
648+
649+
df = response.data_frame
650+
self.assertEqual(len(df), 2)
651+
self.assertEqual(
652+
df.columns.tolist(),
653+
["financials.profit_margin", "financials.account_balance"],
654+
)
655+
self.assertEqual(df["financials.profit_margin"].tolist(), [0.18, 0.22])
656+
self.assertEqual(df["financials.account_balance"].tolist(), [150000, 85000])
657+
658+
def test_query_select_nested_field_with_where(self):
659+
"""
660+
Test nested field projection with WHERE clause on nested field.
661+
Tests that nested fields work correctly in both SELECT and WHERE clauses.
662+
"""
663+
self.mock_connect.return_value[self.dummy_connection_data["database"]].list_collection_names.return_value = [
664+
"clients"
665+
]
666+
667+
self.mock_connect.return_value[self.dummy_connection_data["database"]]["clients"].aggregate.return_value = [
668+
{
669+
"financials.profit_margin": 0.18,
670+
},
671+
{
672+
"financials.profit_margin": 0.22,
673+
},
674+
]
675+
676+
query = ast.Select(
677+
targets=[
678+
ast.Identifier(parts=["financials", "profit_margin"]),
679+
],
680+
from_table=ast.Identifier("clients"),
681+
where=ast.BinaryOperation(
682+
op=">",
683+
args=[
684+
ast.Identifier(parts=["financials", "profit_margin"]),
685+
ast.Constant(0.15),
686+
],
687+
),
688+
)
689+
690+
response = self.handler.query(query)
691+
692+
assert isinstance(response, Response)
693+
self.assertEqual(response.type, RESPONSE_TYPE.TABLE)
694+
695+
df = response.data_frame
696+
self.assertEqual(len(df), 2)
697+
self.assertEqual(df.columns.tolist(), ["financials.profit_margin"])
698+
self.assertEqual(df["financials.profit_margin"].tolist(), [0.18, 0.22])
699+
700+
def test_query_aggregation_on_nested_field(self):
701+
"""
702+
Test aggregation function (AVG) on nested field.
703+
Tests that nested fields work correctly with aggregation functions.
704+
"""
705+
self.mock_connect.return_value[self.dummy_connection_data["database"]].list_collection_names.return_value = [
706+
"clients"
707+
]
708+
709+
self.mock_connect.return_value[self.dummy_connection_data["database"]]["clients"].aggregate.return_value = [
710+
{
711+
"avg_margin": 0.191,
712+
}
713+
]
714+
715+
query = ast.Select(
716+
targets=[
717+
ast.Function(
718+
op="AVG",
719+
args=[ast.Identifier(parts=["financials", "profit_margin"])],
720+
alias=ast.Identifier(parts=["avg_margin"]),
721+
)
722+
],
723+
from_table=ast.Identifier("clients"),
724+
)
725+
726+
response = self.handler.query(query)
727+
728+
assert isinstance(response, Response)
729+
self.assertEqual(response.type, RESPONSE_TYPE.TABLE)
730+
731+
df = response.data_frame
732+
self.assertEqual(len(df), 1)
733+
self.assertEqual(df.columns.tolist(), ["avg_margin"])
734+
self.assertAlmostEqual(df["avg_margin"].tolist()[0], 0.191, places=3)
735+
736+
def test_query_group_by_with_nested_aggregation(self):
737+
"""
738+
Test GROUP BY with aggregation on nested field.
739+
Tests that nested fields work correctly with GROUP BY and aggregation.
740+
"""
741+
self.mock_connect.return_value[self.dummy_connection_data["database"]].list_collection_names.return_value = [
742+
"clients"
743+
]
744+
745+
self.mock_connect.return_value[self.dummy_connection_data["database"]]["clients"].aggregate.return_value = [
746+
{
747+
"industry": "technology",
748+
"avg_margin": 0.18,
749+
},
750+
{
751+
"industry": "finance",
752+
"avg_margin": 0.22,
753+
},
754+
{
755+
"industry": "healthcare",
756+
"avg_margin": 0.15,
757+
},
758+
]
759+
760+
query = ast.Select(
761+
targets=[
762+
ast.Identifier(parts=["industry"]),
763+
ast.Function(
764+
op="AVG",
765+
args=[ast.Identifier(parts=["financials", "profit_margin"])],
766+
alias=ast.Identifier(parts=["avg_margin"]),
767+
),
768+
],
769+
from_table=ast.Identifier("clients"),
770+
group_by=[ast.Identifier(parts=["industry"])],
771+
)
772+
773+
response = self.handler.query(query)
774+
775+
assert isinstance(response, Response)
776+
self.assertEqual(response.type, RESPONSE_TYPE.TABLE)
777+
778+
df = response.data_frame
779+
self.assertEqual(len(df), 3)
780+
self.assertEqual(df.columns.tolist(), ["industry", "avg_margin"])
781+
self.assertEqual(df["industry"].tolist(), ["technology", "finance", "healthcare"])
782+
self.assertEqual(df["avg_margin"].tolist(), [0.18, 0.22, 0.15])
783+
569784

570785
if __name__ == "__main__":
571786
unittest.main()

0 commit comments

Comments
 (0)