Skip to content

Commit bd48025

Browse files
authored
Merge pull request #1667 from weaviate/groupByreturn
Fix groupby returns without results
2 parents 7b588c5 + 1c5f393 commit bd48025

8 files changed

Lines changed: 60 additions & 12 deletions

File tree

integration/test_collection_aggregate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,15 @@ def test_aggregation_groupby_with_limit(collection_factory: CollectionFactory) -
9999
assert res.groups[1].properties["text"].count == 1
100100

101101

102+
def test_aggregation_groupby_no_results(collection_factory: CollectionFactory) -> None:
103+
collection = collection_factory(properties=[Property(name="text", data_type=DataType.TEXT)])
104+
res = collection.aggregate.over_all(
105+
return_metrics=[Metrics("text").text(count=True)],
106+
group_by=GroupByAggregate(prop="text", limit=2),
107+
)
108+
assert len(res.groups) == 0
109+
110+
102111
@pytest.mark.parametrize(
103112
"filter_",
104113
[

weaviate/collections/aggregations/base_executor.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,17 +92,17 @@ def _to_aggregate_result(
9292
)
9393

9494
def _to_result(
95-
self, response: aggregate_pb2.AggregateReply
95+
self, is_groupby: bool, response: aggregate_pb2.AggregateReply
9696
) -> Union[AggregateReturn, AggregateGroupByReturn]:
97-
if response.HasField("single_result"):
97+
if not is_groupby:
9898
return AggregateReturn(
9999
properties={
100100
aggregation.property: self.__parse_property_grpc(aggregation)
101101
for aggregation in response.single_result.aggregations.aggregations
102102
},
103103
total_count=response.single_result.objects_count,
104104
)
105-
if response.HasField("grouped_results"):
105+
if is_groupby:
106106
return AggregateGroupByReturn(
107107
groups=[
108108
AggregateGroup(
@@ -116,9 +116,6 @@ def _to_result(
116116
for group in response.grouped_results.groups
117117
]
118118
)
119-
else:
120-
_Warnings.unknown_type_encountered(response.WhichOneof("result"))
121-
return AggregateReturn(properties={}, total_count=None)
122119

123120
def __parse_grouped_by_value(
124121
self, grouped_by: aggregate_pb2.AggregateReply.Group.GroupedBy

weaviate/collections/aggregations/hybrid/executor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from weaviate.connect.v4 import ConnectionType
1414
from weaviate.exceptions import WeaviateUnsupportedFeatureError
1515
from weaviate.types import NUMBER
16+
from weaviate.proto.v1 import aggregate_pb2
1617

1718

1819
class _HybridExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
@@ -161,8 +162,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
161162
limit=group_by.limit if group_by is not None else None,
162163
objects_count=total_count,
163164
)
165+
166+
def respGrpc(
167+
res: aggregate_pb2.AggregateReply,
168+
) -> Union[AggregateReturn, AggregateGroupByReturn]:
169+
return self._to_result(group_by is not None, res)
170+
164171
return executor.execute(
165-
response_callback=self._to_result,
172+
response_callback=respGrpc,
166173
method=self._connection.grpc_aggregate,
167174
request=request,
168175
)

weaviate/collections/aggregations/near_image/executor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from weaviate.connect.v4 import ConnectionType
1414
from weaviate.types import BLOB_INPUT, NUMBER
1515
from weaviate.util import parse_blob
16+
from weaviate.proto.v1 import aggregate_pb2
1617

1718

1819
class _NearImageExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
@@ -144,8 +145,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
144145
objects_count=total_count,
145146
object_limit=object_limit,
146147
)
148+
149+
def respGrpc(
150+
res: aggregate_pb2.AggregateReply,
151+
) -> Union[AggregateReturn, AggregateGroupByReturn]:
152+
return self._to_result(group_by is not None, res)
153+
147154
return executor.execute(
148-
response_callback=self._to_result,
155+
response_callback=respGrpc,
149156
method=self._connection.grpc_aggregate,
150157
request=request,
151158
)

weaviate/collections/aggregations/near_object/executor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from weaviate.connect import executor
1313
from weaviate.connect.v4 import ConnectionType
1414
from weaviate.types import NUMBER, UUID
15+
from weaviate.proto.v1 import aggregate_pb2
1516

1617

1718
class _NearObjectExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
@@ -142,8 +143,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
142143
objects_count=total_count,
143144
object_limit=object_limit,
144145
)
146+
147+
def respGrpc(
148+
res: aggregate_pb2.AggregateReply,
149+
) -> Union[AggregateReturn, AggregateGroupByReturn]:
150+
return self._to_result(group_by is not None, res)
151+
145152
return executor.execute(
146-
response_callback=self._to_result,
153+
response_callback=respGrpc,
147154
method=self._connection.grpc_aggregate,
148155
request=request,
149156
)

weaviate/collections/aggregations/near_text/executor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from weaviate.connect import executor
1414
from weaviate.connect.v4 import ConnectionType
1515
from weaviate.types import NUMBER
16+
from weaviate.proto.v1 import aggregate_pb2
1617

1718

1819
class _NearTextExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
@@ -162,8 +163,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
162163
objects_count=total_count,
163164
object_limit=object_limit,
164165
)
166+
167+
def respGrpc(
168+
res: aggregate_pb2.AggregateReply,
169+
) -> Union[AggregateReturn, AggregateGroupByReturn]:
170+
return self._to_result(group_by is not None, res)
171+
165172
return executor.execute(
166-
response_callback=self._to_result,
173+
response_callback=respGrpc,
167174
method=self._connection.grpc_aggregate,
168175
request=request,
169176
)

weaviate/collections/aggregations/near_vector/executor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from weaviate.connect.v4 import ConnectionType
1818
from weaviate.exceptions import WeaviateInvalidInputError
1919
from weaviate.types import NUMBER
20+
from weaviate.proto.v1 import aggregate_pb2
2021

2122

2223
class _NearVectorExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
@@ -163,8 +164,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
163164
objects_count=total_count,
164165
object_limit=object_limit,
165166
)
167+
168+
def respGrpc(
169+
res: aggregate_pb2.AggregateReply,
170+
) -> Union[AggregateReturn, AggregateGroupByReturn]:
171+
return self._to_result(group_by is not None, res)
172+
166173
return executor.execute(
167-
response_callback=self._to_result,
174+
response_callback=respGrpc,
168175
method=self._connection.grpc_aggregate,
169176
request=request,
170177
)

weaviate/collections/aggregations/over_all/executor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from weaviate.collections.filters import _FilterToGRPC
1212
from weaviate.connect import executor
1313
from weaviate.connect.v4 import ConnectionType
14+
from weaviate.proto.v1 import aggregate_pb2
1415

1516

1617
class _OverAllExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
@@ -105,8 +106,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
105106
limit=group_by.limit if group_by is not None else None,
106107
objects_count=total_count,
107108
)
109+
110+
def respGrpc(
111+
res: aggregate_pb2.AggregateReply,
112+
) -> Union[AggregateReturn, AggregateGroupByReturn]:
113+
return self._to_result(group_by is not None, res)
114+
108115
return executor.execute(
109-
response_callback=self._to_result,
116+
response_callback=respGrpc,
110117
method=self._connection.grpc_aggregate,
111118
request=request,
112119
)

0 commit comments

Comments
 (0)