Skip to content

Commit 218ee16

Browse files
committed
Implement feedback
1 parent e88d17e commit 218ee16

3 files changed

Lines changed: 48 additions & 95 deletions

File tree

Lines changed: 44 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from integration.conftest import CollectionFactory
4-
from weaviate.classes.query import DiversitySelection
4+
from weaviate.classes.query import Diversity
55
from weaviate.collections.classes.config import Configure, DataType, Property
66
from weaviate.collections.classes.data import DataObject
77

@@ -27,83 +27,34 @@ def _create_clustered_collection(collection_factory: CollectionFactory):
2727
return collection
2828

2929

30-
def test_near_vector_diversity_pure_relevance(
31-
collection_factory: CollectionFactory,
32-
) -> None:
33-
"""balance=1.0 -> MMR degenerates to pure relevance (same as no diversity)."""
34-
collection = _create_clustered_collection(collection_factory)
35-
36-
baseline = collection.query.near_vector(near_vector=[1.0, 0.0, 0.0], limit=3).objects
37-
diverse = collection.query.near_vector(
38-
near_vector=[1.0, 0.0, 0.0],
39-
diversity_selection=DiversitySelection.mmr(limit=3, balance=1.0),
40-
).objects
30+
def test_near_vector_diversity_selection(collection_factory: CollectionFactory) -> None:
31+
"""Verify that the client passes diversity_selection to the server correctly.
4132
42-
assert [o.properties["text"] for o in baseline] == [o.properties["text"] for o in diverse]
43-
44-
45-
def test_near_vector_diversity_pure_diversity(
46-
collection_factory: CollectionFactory,
47-
) -> None:
48-
"""balance=0.0 -> MMR picks maximally diverse results (one per cluster)."""
33+
Two orthogonal assertions — server-side logic (MMR itself) is out of scope:
34+
- ``balance`` reaches the server: balance=0.0 produces a different UUID ordering than balance=1.0
35+
- ``limit`` reaches the server: len(result) == mmr_limit
36+
"""
4937
collection = _create_clustered_collection(collection_factory)
38+
mmr_limit = 3
5039

51-
result = collection.query.near_vector(
40+
balance_0 = collection.query.near_vector(
5241
near_vector=[1.0, 0.0, 0.0],
53-
diversity_selection=DiversitySelection.mmr(limit=3, balance=0.0),
54-
)
55-
texts = {o.properties["text"] for o in result.objects}
56-
assert len(texts) == 3
57-
# Pure diversity should pick one from each cluster (a*, b*, c*)
58-
clusters = {t[0] for t in texts}
59-
assert clusters == {"a", "b", "c"}
60-
61-
62-
def test_near_vector_diversity_with_mmr_class(
63-
collection_factory: CollectionFactory,
64-
) -> None:
65-
"""Direct MMR class construction (DiversitySelection.MMR) also works, not just the factory."""
66-
collection = _create_clustered_collection(collection_factory)
67-
result = collection.query.near_vector(
42+
diversity_selection=Diversity.mmr(limit=mmr_limit, balance=0.0),
43+
).objects
44+
balance_1 = collection.query.near_vector(
6845
near_vector=[1.0, 0.0, 0.0],
69-
diversity_selection=DiversitySelection.MMR(limit=3, balance=0.0),
70-
)
71-
clusters = {o.properties["text"][0] for o in result.objects}
72-
assert clusters == {"a", "b", "c"}
73-
74-
75-
def test_near_object_diversity(collection_factory: CollectionFactory) -> None:
76-
"""near_object supports diversity selection."""
77-
collection = _create_clustered_collection(collection_factory)
78-
anchor = next(iter(collection.query.fetch_objects().objects)).uuid
79-
80-
result = collection.query.near_object(
81-
near_object=anchor,
82-
diversity_selection=DiversitySelection.mmr(limit=3, balance=0.0),
83-
)
84-
assert len(result.objects) == 3
85-
clusters = {o.properties["text"][0] for o in result.objects}
86-
assert len(clusters) == 3
87-
88-
89-
def test_diversity_cannot_be_instantiated() -> None:
90-
"""Test that direct instantiation of the DiversitySelection factory fails."""
91-
with pytest.raises(TypeError):
92-
DiversitySelection()
93-
46+
diversity_selection=Diversity.mmr(limit=mmr_limit, balance=1.0),
47+
).objects
9448

95-
def test_diversity_mmr_only_limit(collection_factory: CollectionFactory) -> None:
96-
"""MMR accepts just a limit (balance defaults to server-side value)."""
97-
collection = _create_clustered_collection(collection_factory)
98-
result = collection.query.near_vector(
99-
near_vector=[1.0, 0.0, 0.0],
100-
diversity_selection=DiversitySelection.mmr(limit=2),
101-
)
102-
assert len(result.objects) == 2
49+
# mmr_limit reaches the server → result count equals it
50+
assert len(balance_0) == mmr_limit
51+
assert len(balance_1) == mmr_limit
52+
# balance reaches the server → different ordering
53+
assert [o.uuid for o in balance_0] != [o.uuid for o in balance_1]
10354

10455

105-
def test_near_text_diversity(collection_factory: CollectionFactory) -> None:
106-
"""near_text supports diversity selection via text2vec-contextionary."""
56+
def test_near_text_diversity_selection(collection_factory: CollectionFactory) -> None:
57+
"""Smoke test: diversity_selection kwarg is wired through the near_text entry point."""
10758
collection = collection_factory(
10859
properties=[Property(name="name", data_type=DataType.TEXT)],
10960
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
@@ -117,29 +68,25 @@ def test_near_text_diversity(collection_factory: CollectionFactory) -> None:
11768

11869
result = collection.query.near_text(
11970
query="fruit",
120-
diversity_selection=DiversitySelection.mmr(limit=3, balance=0.0),
71+
diversity_selection=Diversity.mmr(limit=3, balance=0.5),
12172
)
12273
assert len(result.objects) == 3
12374

12475

125-
def test_near_vector_balance_0_differs_from_balance_1(
126-
collection_factory: CollectionFactory,
127-
) -> None:
128-
"""Test that MMR balance=0 (pure diversity) produces a different ordering than balance=1."""
76+
def test_near_object_diversity_selection(collection_factory: CollectionFactory) -> None:
77+
"""Smoke test: diversity_selection kwarg is wired through the near_object entry point."""
12978
collection = _create_clustered_collection(collection_factory)
130-
balance_0 = collection.query.near_vector(
131-
near_vector=[1.0, 0.0, 0.0],
132-
diversity_selection=DiversitySelection.mmr(limit=3, balance=0.0),
133-
).objects
134-
balance_1 = collection.query.near_vector(
135-
near_vector=[1.0, 0.0, 0.0],
136-
diversity_selection=DiversitySelection.mmr(limit=3, balance=1.0),
137-
).objects
138-
assert [o.uuid for o in balance_0] != [o.uuid for o in balance_1]
79+
anchor = next(iter(collection.query.fetch_objects().objects)).uuid
80+
81+
result = collection.query.near_object(
82+
near_object=anchor,
83+
diversity_selection=Diversity.mmr(limit=3, balance=0.5),
84+
)
85+
assert len(result.objects) == 3
13986

14087

141-
def test_near_text_generate_diversity(collection_factory: CollectionFactory) -> None:
142-
"""Generate namespace (collection.generate.near_text) also supports diversity selection."""
88+
def test_generate_diversity_selection(collection_factory: CollectionFactory) -> None:
89+
"""Smoke test: diversity_selection kwarg is wired through the .generate namespace."""
14390
collection = collection_factory(
14491
properties=[Property(name="name", data_type=DataType.TEXT)],
14592
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
@@ -155,6 +102,16 @@ def test_near_text_generate_diversity(collection_factory: CollectionFactory) ->
155102
result = collection.generate.near_text(
156103
query="fruit",
157104
single_prompt="Describe {name}",
158-
diversity_selection=DiversitySelection.mmr(limit=3, balance=0.0),
105+
diversity_selection=Diversity.mmr(limit=3, balance=0.5),
159106
)
160107
assert len(result.objects) == 3
108+
109+
110+
def test_diversity_selection_api_surface() -> None:
111+
"""Test the public API surface of Diversity: factory guard + mmr factory method."""
112+
# Direct instantiation of the factory class fails
113+
with pytest.raises(TypeError):
114+
Diversity()
115+
116+
# Diversity.mmr() produces an MMR-configured selection object
117+
assert Diversity.mmr(limit=3, balance=0.5).limit == 3

weaviate/classes/query.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,7 @@
55
BM25OperatorFactory as BM25Operator,
66
)
77
from weaviate.collections.classes.grpc import (
8-
MMR,
9-
DiversitySelection,
8+
Diversity,
109
GroupBy,
1110
HybridFusion,
1211
HybridVector,
@@ -23,7 +22,7 @@
2322
from weaviate.collections.classes.types import GeoCoordinate
2423

2524
__all__ = [
26-
"DiversitySelection",
25+
"Diversity",
2726
"Filter",
2827
"FilterReturn",
2928
"GeoCoordinate",
@@ -32,7 +31,6 @@
3231
"HybridFusion",
3332
"HybridVector",
3433
"BM25Operator",
35-
"MMR",
3634
"MetadataQuery",
3735
"Metrics",
3836
"Move",

weaviate/collections/classes/grpc.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -281,16 +281,14 @@ class MMR:
281281
balance: Optional[float] = None
282282

283283

284-
class DiversitySelection:
284+
class Diversity:
285285
"""Use this factory class to apply diversity selection to search results via MMR."""
286286

287287
def __init__(self) -> None:
288288
raise TypeError(
289-
"DiversitySelection cannot be instantiated directly. Use DiversitySelection.mmr(...)."
289+
"Diversity cannot be instantiated directly. Use Diversity.mmr(...)."
290290
)
291291

292-
MMR = MMR
293-
294292
@staticmethod
295293
def mmr(limit: Optional[int] = None, balance: Optional[float] = None) -> MMR:
296294
"""Maximal Marginal Relevance diversity selection.

0 commit comments

Comments
 (0)