Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions integration/test_collection_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ def test_aggregation_groupby_with_limit(collection_factory: CollectionFactory) -
assert res.groups[1].properties["text"].count == 1


def test_aggregation_groupby_no_results(collection_factory: CollectionFactory) -> None:
collection = collection_factory(properties=[Property(name="text", data_type=DataType.TEXT)])
res = collection.aggregate.over_all(
return_metrics=[Metrics("text").text(count=True)],
group_by=GroupByAggregate(prop="text", limit=2),
)
assert len(res.groups) == 0


@pytest.mark.parametrize(
"filter_",
[
Expand Down
9 changes: 3 additions & 6 deletions weaviate/collections/aggregations/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,17 @@ def _to_aggregate_result(
)

def _to_result(
self, response: aggregate_pb2.AggregateReply
self, is_groupby: bool, response: aggregate_pb2.AggregateReply
) -> Union[AggregateReturn, AggregateGroupByReturn]:
if response.HasField("single_result"):
if not is_groupby:
return AggregateReturn(
properties={
aggregation.property: self.__parse_property_grpc(aggregation)
for aggregation in response.single_result.aggregations.aggregations
},
total_count=response.single_result.objects_count,
)
if response.HasField("grouped_results"):
if is_groupby:
return AggregateGroupByReturn(
groups=[
AggregateGroup(
Expand All @@ -116,9 +116,6 @@ def _to_result(
for group in response.grouped_results.groups
]
)
else:
_Warnings.unknown_type_encountered(response.WhichOneof("result"))
return AggregateReturn(properties={}, total_count=None)

def __parse_grouped_by_value(
self, grouped_by: aggregate_pb2.AggregateReply.Group.GroupedBy
Expand Down
9 changes: 8 additions & 1 deletion weaviate/collections/aggregations/hybrid/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from weaviate.connect.v4 import ConnectionType
from weaviate.exceptions import WeaviateUnsupportedFeatureError
from weaviate.types import NUMBER
from weaviate.proto.v1 import aggregate_pb2


class _HybridExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
Expand Down Expand Up @@ -161,8 +162,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
limit=group_by.limit if group_by is not None else None,
objects_count=total_count,
)

def respGrpc(
res: aggregate_pb2.AggregateReply,
) -> Union[AggregateReturn, AggregateGroupByReturn]:
return self._to_result(group_by is not None, res)

return executor.execute(
response_callback=self._to_result,
response_callback=respGrpc,
method=self._connection.grpc_aggregate,
request=request,
)
9 changes: 8 additions & 1 deletion weaviate/collections/aggregations/near_image/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from weaviate.connect.v4 import ConnectionType
from weaviate.types import BLOB_INPUT, NUMBER
from weaviate.util import parse_blob
from weaviate.proto.v1 import aggregate_pb2


class _NearImageExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
Expand Down Expand Up @@ -144,8 +145,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
objects_count=total_count,
object_limit=object_limit,
)

def respGrpc(
res: aggregate_pb2.AggregateReply,
) -> Union[AggregateReturn, AggregateGroupByReturn]:
return self._to_result(group_by is not None, res)

return executor.execute(
response_callback=self._to_result,
response_callback=respGrpc,
method=self._connection.grpc_aggregate,
request=request,
)
9 changes: 8 additions & 1 deletion weaviate/collections/aggregations/near_object/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from weaviate.connect import executor
from weaviate.connect.v4 import ConnectionType
from weaviate.types import NUMBER, UUID
from weaviate.proto.v1 import aggregate_pb2


class _NearObjectExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
Expand Down Expand Up @@ -142,8 +143,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
objects_count=total_count,
object_limit=object_limit,
)

def respGrpc(
res: aggregate_pb2.AggregateReply,
) -> Union[AggregateReturn, AggregateGroupByReturn]:
return self._to_result(group_by is not None, res)

return executor.execute(
response_callback=self._to_result,
response_callback=respGrpc,
method=self._connection.grpc_aggregate,
request=request,
)
9 changes: 8 additions & 1 deletion weaviate/collections/aggregations/near_text/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from weaviate.connect import executor
from weaviate.connect.v4 import ConnectionType
from weaviate.types import NUMBER
from weaviate.proto.v1 import aggregate_pb2


class _NearTextExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
Expand Down Expand Up @@ -162,8 +163,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
objects_count=total_count,
object_limit=object_limit,
)

def respGrpc(
res: aggregate_pb2.AggregateReply,
) -> Union[AggregateReturn, AggregateGroupByReturn]:
return self._to_result(group_by is not None, res)

return executor.execute(
response_callback=self._to_result,
response_callback=respGrpc,
method=self._connection.grpc_aggregate,
request=request,
)
9 changes: 8 additions & 1 deletion weaviate/collections/aggregations/near_vector/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from weaviate.connect.v4 import ConnectionType
from weaviate.exceptions import WeaviateInvalidInputError
from weaviate.types import NUMBER
from weaviate.proto.v1 import aggregate_pb2


class _NearVectorExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
Expand Down Expand Up @@ -163,8 +164,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
objects_count=total_count,
object_limit=object_limit,
)

def respGrpc(
res: aggregate_pb2.AggregateReply,
) -> Union[AggregateReturn, AggregateGroupByReturn]:
return self._to_result(group_by is not None, res)

return executor.execute(
response_callback=self._to_result,
response_callback=respGrpc,
method=self._connection.grpc_aggregate,
request=request,
)
9 changes: 8 additions & 1 deletion weaviate/collections/aggregations/over_all/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from weaviate.collections.filters import _FilterToGRPC
from weaviate.connect import executor
from weaviate.connect.v4 import ConnectionType
from weaviate.proto.v1 import aggregate_pb2


class _OverAllExecutor(Generic[ConnectionType], _BaseExecutor[ConnectionType]):
Expand Down Expand Up @@ -105,8 +106,14 @@ def resp(res: dict) -> Union[AggregateReturn, AggregateGroupByReturn]:
limit=group_by.limit if group_by is not None else None,
objects_count=total_count,
)

def respGrpc(
res: aggregate_pb2.AggregateReply,
) -> Union[AggregateReturn, AggregateGroupByReturn]:
return self._to_result(group_by is not None, res)

return executor.execute(
response_callback=self._to_result,
response_callback=respGrpc,
method=self._connection.grpc_aggregate,
request=request,
)