Skip to content

Commit 7f4c031

Browse files
committed
Rename to DiversitySelection
1 parent edbcce8 commit 7f4c031

3 files changed

Lines changed: 17 additions & 15 deletions

File tree

integration/test_collection_diversity.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from integration.conftest import CollectionFactory
4-
from weaviate.classes.query import Diversity
4+
from weaviate.classes.query import DiversitySelection
55
from weaviate.collections.classes.config import Configure, DataType, Property
66
from weaviate.collections.classes.data import DataObject
77

@@ -36,7 +36,7 @@ def test_near_vector_diversity_pure_relevance(
3636
baseline = collection.query.near_vector(near_vector=[1.0, 0.0, 0.0], limit=3).objects
3737
diverse = collection.query.near_vector(
3838
near_vector=[1.0, 0.0, 0.0],
39-
diversity_selection=Diversity.mmr(limit=3, balance=1.0),
39+
diversity_selection=DiversitySelection.mmr(limit=3, balance=1.0),
4040
).objects
4141

4242
assert [o.properties["text"] for o in baseline] == [o.properties["text"] for o in diverse]
@@ -50,7 +50,7 @@ def test_near_vector_diversity_pure_diversity(
5050

5151
result = collection.query.near_vector(
5252
near_vector=[1.0, 0.0, 0.0],
53-
diversity_selection=Diversity.mmr(limit=3, balance=0.0),
53+
diversity_selection=DiversitySelection.mmr(limit=3, balance=0.0),
5454
)
5555
texts = {o.properties["text"] for o in result.objects}
5656
assert len(texts) == 3
@@ -62,11 +62,11 @@ def test_near_vector_diversity_pure_diversity(
6262
def test_near_vector_diversity_with_mmr_class(
6363
collection_factory: CollectionFactory,
6464
) -> None:
65-
"""Direct MMR class construction (Diversity.MMR) also works, not just the factory."""
65+
"""Direct MMR class construction (DiversitySelection.MMR) also works, not just the factory."""
6666
collection = _create_clustered_collection(collection_factory)
6767
result = collection.query.near_vector(
6868
near_vector=[1.0, 0.0, 0.0],
69-
diversity_selection=Diversity.MMR(limit=3, balance=0.0),
69+
diversity_selection=DiversitySelection.MMR(limit=3, balance=0.0),
7070
)
7171
clusters = {o.properties["text"][0] for o in result.objects}
7272
assert clusters == {"a", "b", "c"}
@@ -79,25 +79,25 @@ def test_near_object_diversity(collection_factory: CollectionFactory) -> None:
7979

8080
result = collection.query.near_object(
8181
near_object=anchor,
82-
diversity_selection=Diversity.mmr(limit=3, balance=0.0),
82+
diversity_selection=DiversitySelection.mmr(limit=3, balance=0.0),
8383
)
8484
assert len(result.objects) == 3
8585
clusters = {o.properties["text"][0] for o in result.objects}
8686
assert len(clusters) == 3
8787

8888

8989
def test_diversity_cannot_be_instantiated() -> None:
90-
"""Diversity is a factory — direct instantiation should fail."""
90+
"""DiversitySelection is a factory — direct instantiation should fail."""
9191
with pytest.raises(TypeError):
92-
Diversity()
92+
DiversitySelection()
9393

9494

9595
def test_diversity_mmr_only_limit(collection_factory: CollectionFactory) -> None:
9696
"""MMR accepts just a limit (balance defaults to server-side value)."""
9797
collection = _create_clustered_collection(collection_factory)
9898
result = collection.query.near_vector(
9999
near_vector=[1.0, 0.0, 0.0],
100-
diversity_selection=Diversity.mmr(limit=2),
100+
diversity_selection=DiversitySelection.mmr(limit=2),
101101
)
102102
assert len(result.objects) == 2
103103

@@ -117,7 +117,7 @@ def test_near_text_diversity(collection_factory: CollectionFactory) -> None:
117117

118118
result = collection.query.near_text(
119119
query="fruit",
120-
diversity_selection=Diversity.mmr(limit=3, balance=0.0),
120+
diversity_selection=DiversitySelection.mmr(limit=3, balance=0.0),
121121
)
122122
assert len(result.objects) == 3
123123

@@ -139,6 +139,6 @@ def test_near_text_generate_diversity(collection_factory: CollectionFactory) ->
139139
result = collection.generate.near_text(
140140
query="fruit",
141141
single_prompt="Describe {name}",
142-
diversity_selection=Diversity.mmr(limit=3, balance=0.0),
142+
diversity_selection=DiversitySelection.mmr(limit=3, balance=0.0),
143143
)
144144
assert len(result.objects) == 3

weaviate/classes/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
)
77
from weaviate.collections.classes.grpc import (
88
MMR,
9-
Diversity,
9+
DiversitySelection,
1010
GroupBy,
1111
HybridFusion,
1212
HybridVector,
@@ -23,7 +23,7 @@
2323
from weaviate.collections.classes.types import GeoCoordinate
2424

2525
__all__ = [
26-
"Diversity",
26+
"DiversitySelection",
2727
"Filter",
2828
"FilterReturn",
2929
"GeoCoordinate",

weaviate/collections/classes/grpc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,13 @@ class MMR:
281281
balance: Optional[float] = None
282282

283283

284-
class Diversity:
284+
class DiversitySelection:
285285
"""Use this factory class to apply diversity selection to search results via MMR."""
286286

287287
def __init__(self) -> None:
288-
raise TypeError("Diversity cannot be instantiated directly. Use Diversity.mmr(...).")
288+
raise TypeError(
289+
"DiversitySelection cannot be instantiated directly. Use DiversitySelection.mmr(...)."
290+
)
289291

290292
MMR = MMR
291293

0 commit comments

Comments
 (0)