11import pytest
22
33from integration .conftest import CollectionFactory
4- from weaviate .classes .query import Diversity
4+ from weaviate .classes .query import DiversitySelection
55from weaviate .collections .classes .config import Configure , DataType , Property
66from weaviate .collections .classes .data import DataObject
77
@@ -36,7 +36,7 @@ def test_near_vector_diversity_pure_relevance(
3636 baseline = collection .query .near_vector (near_vector = [1.0 , 0.0 , 0.0 ], limit = 3 ).objects
3737 diverse = collection .query .near_vector (
3838 near_vector = [1.0 , 0.0 , 0.0 ],
39- diversity_selection = Diversity .mmr (limit = 3 , balance = 1.0 ),
39+ diversity_selection = DiversitySelection .mmr (limit = 3 , balance = 1.0 ),
4040 ).objects
4141
4242 assert [o .properties ["text" ] for o in baseline ] == [o .properties ["text" ] for o in diverse ]
@@ -50,7 +50,7 @@ def test_near_vector_diversity_pure_diversity(
5050
5151 result = collection .query .near_vector (
5252 near_vector = [1.0 , 0.0 , 0.0 ],
53- diversity_selection = Diversity .mmr (limit = 3 , balance = 0.0 ),
53+ diversity_selection = DiversitySelection .mmr (limit = 3 , balance = 0.0 ),
5454 )
5555 texts = {o .properties ["text" ] for o in result .objects }
5656 assert len (texts ) == 3
@@ -62,11 +62,11 @@ def test_near_vector_diversity_pure_diversity(
6262def test_near_vector_diversity_with_mmr_class (
6363 collection_factory : CollectionFactory ,
6464) -> None :
65- """Direct MMR class construction (Diversity .MMR) also works, not just the factory."""
65+ """Direct MMR class construction (DiversitySelection .MMR) also works, not just the factory."""
6666 collection = _create_clustered_collection (collection_factory )
6767 result = collection .query .near_vector (
6868 near_vector = [1.0 , 0.0 , 0.0 ],
69- diversity_selection = Diversity .MMR (limit = 3 , balance = 0.0 ),
69+ diversity_selection = DiversitySelection .MMR (limit = 3 , balance = 0.0 ),
7070 )
7171 clusters = {o .properties ["text" ][0 ] for o in result .objects }
7272 assert clusters == {"a" , "b" , "c" }
@@ -79,25 +79,25 @@ def test_near_object_diversity(collection_factory: CollectionFactory) -> None:
7979
8080 result = collection .query .near_object (
8181 near_object = anchor ,
82- diversity_selection = Diversity .mmr (limit = 3 , balance = 0.0 ),
82+ diversity_selection = DiversitySelection .mmr (limit = 3 , balance = 0.0 ),
8383 )
8484 assert len (result .objects ) == 3
8585 clusters = {o .properties ["text" ][0 ] for o in result .objects }
8686 assert len (clusters ) == 3
8787
8888
8989def test_diversity_cannot_be_instantiated () -> None :
90- """Diversity is a factory — direct instantiation should fail."""
90+ """DiversitySelection is a factory — direct instantiation should fail."""
9191 with pytest .raises (TypeError ):
92- Diversity ()
92+ DiversitySelection ()
9393
9494
9595def test_diversity_mmr_only_limit (collection_factory : CollectionFactory ) -> None :
9696 """MMR accepts just a limit (balance defaults to server-side value)."""
9797 collection = _create_clustered_collection (collection_factory )
9898 result = collection .query .near_vector (
9999 near_vector = [1.0 , 0.0 , 0.0 ],
100- diversity_selection = Diversity .mmr (limit = 2 ),
100+ diversity_selection = DiversitySelection .mmr (limit = 2 ),
101101 )
102102 assert len (result .objects ) == 2
103103
@@ -117,7 +117,7 @@ def test_near_text_diversity(collection_factory: CollectionFactory) -> None:
117117
118118 result = collection .query .near_text (
119119 query = "fruit" ,
120- diversity_selection = Diversity .mmr (limit = 3 , balance = 0.0 ),
120+ diversity_selection = DiversitySelection .mmr (limit = 3 , balance = 0.0 ),
121121 )
122122 assert len (result .objects ) == 3
123123
@@ -139,6 +139,6 @@ def test_near_text_generate_diversity(collection_factory: CollectionFactory) ->
139139 result = collection .generate .near_text (
140140 query = "fruit" ,
141141 single_prompt = "Describe {name}" ,
142- diversity_selection = Diversity .mmr (limit = 3 , balance = 0.0 ),
142+ diversity_selection = DiversitySelection .mmr (limit = 3 , balance = 0.0 ),
143143 )
144144 assert len (result .objects ) == 3
0 commit comments