Skip to content

Commit cccff14

Browse files
trengrj and dirkkul authored
feat: Add Boost API (#2030)
* Add Boost API
* Add typing for filter
* Review feedback
* Improve docstrings
* Enforce enum for Boost curves and modifiers
* Rank -> boost naming
* Stop Boost.blend() mutating in place
* Explicitly set 'now' for time decay
* Misc fixes
* Add BoostReturn
* Fix validation of blend
* Review feedback, property -> numeric_property and other changes

---------

Co-authored-by: Dirk Kulawiak <dirk@semi.technology>
1 parent 5a4b983 commit cccff14

53 files changed

Lines changed: 1694 additions & 130 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
import pytest
2+
3+
from integration.conftest import CollectionFactory
4+
from weaviate.classes.query import Boost, Filter, MetadataQuery
5+
from weaviate.collections.classes.config import Configure, DataType, Property
6+
from weaviate.exceptions import WeaviateInvalidInputError
7+
from weaviate.collections.classes.data import DataObject
8+
9+
10+
def _create_collection(collection_factory: CollectionFactory):
    """Build and populate the collection shared by the boost tests.

    The collection declares one text property, three numeric properties
    (``price``, ``rating``, ``count``) and one date property (``created``).
    It is configured with no vectorizer and a flat vector index, then filled
    with five objects that carry explicit 3-dimensional vectors.

    Args:
        collection_factory: Fixture used to create the collection.

    Returns:
        The populated collection.

    Note:
        Calls ``pytest.skip`` when the connected server reports a version
        lower than 1.38.0, which skips the calling test.
    """
    schema = [
        Property(name="text", data_type=DataType.TEXT),
        Property(name="price", data_type=DataType.NUMBER),
        Property(name="rating", data_type=DataType.NUMBER),
        Property(name="count", data_type=DataType.INT),
        Property(name="created", data_type=DataType.DATE),
    ]
    collection = collection_factory(
        properties=schema,
        vectorizer_config=Configure.Vectorizer.none(),
        vector_index_config=Configure.VectorIndex.flat(),
    )
    if collection._connection._weaviate_version.is_lower_than(1, 38, 0):
        pytest.skip("Boost requires Weaviate >= 1.38.0")

    # One row per object: (text, price, rating, count, created, vector).
    rows = [
        ("cheap good", 10.0, 4.9, 1000, "2024-01-01T00:00:00Z", [1.0, 0.0, 0.0]),
        ("cheap bad", 10.0, 2.0, 5, "2020-01-01T00:00:00Z", [0.9, 0.1, 0.0]),
        ("expensive good", 500.0, 4.8, 500, "2023-06-01T00:00:00Z", [0.0, 1.0, 0.0]),
        ("expensive bad", 500.0, 1.5, 2, "2019-01-01T00:00:00Z", [0.0, 0.9, 0.1]),
        ("mid range", 50.0, 3.5, 100, "2022-01-01T00:00:00Z", [0.0, 0.0, 1.0]),
    ]
    objects = [
        DataObject(
            properties={
                "text": text,
                "price": price,
                "rating": rating,
                "count": count,
                "created": created,
            },
            vector=vector,
        )
        for text, price, rating, count, created, vector in rows
    ]
    collection.data.insert_many(objects)
    return collection
80+
81+
82+
def test_boost_filter(collection_factory: CollectionFactory) -> None:
    """A filter boost re-orders near-vector results relative to an unboosted query.

    Runs the same near-vector query twice, once without a boost and once
    boosting objects whose ``rating`` is >= 4.0, and checks that both queries
    return all five fixture objects but in a different order.
    """
    collection = _create_collection(collection_factory)
    query_vector = [1.0, 0.0, 0.0]

    baseline = collection.query.near_vector(
        near_vector=query_vector,
        limit=5,
        return_metadata=MetadataQuery(distance=True),
    ).objects

    boosted = collection.query.near_vector(
        near_vector=query_vector,
        limit=5,
        boost=Boost.filter(
            Filter.by_property("rating").greater_or_equal(4.0),
            weight=1.0,
        ),
        return_metadata=MetadataQuery(distance=True),
    ).objects

    baseline_ids = [o.uuid for o in baseline]
    boosted_ids = [o.uuid for o in boosted]

    # Without this check an empty or truncated baseline would make the
    # ordering comparison below pass vacuously.
    assert len(baseline) == 5
    assert len(boosted) == 5
    # Both queries return every object in the five-object fixture, so they
    # must cover the same set of objects; only the order may differ.
    assert set(baseline_ids) == set(boosted_ids)
    # The boost should change the ordering compared to baseline
    assert baseline_ids != boosted_ids
105+
106+
107+
def test_boost_numeric_decay(collection_factory: CollectionFactory) -> None:
    """A linear numeric-decay boost on ``price`` still returns all five objects."""
    collection = _create_collection(collection_factory)

    price_decay = Boost.numeric_decay(
        "price",
        origin=50.0,
        scale=20.0,
        curve=Boost.Curve.LINEAR,
        decay=0.5,
        weight=1.0,
    )
    response = collection.query.near_vector(
        near_vector=[1.0, 0.0, 0.0],
        limit=5,
        boost=price_decay,
        return_metadata=MetadataQuery(distance=True),
    )

    assert len(response.objects) == 5
126+
127+
128+
def test_boost_time_decay(collection_factory: CollectionFactory) -> None:
    """An exponential time-decay boost on ``created`` still returns all five objects."""
    collection = _create_collection(collection_factory)

    recency = Boost.time_decay(
        "created",
        origin="2024-01-01T00:00:00Z",
        scale="365d",
        curve=Boost.Curve.EXPONENTIAL,
        decay=0.3,
        weight=1.0,
    )
    response = collection.query.near_vector(
        near_vector=[1.0, 0.0, 0.0],
        limit=5,
        boost=recency,
        return_metadata=MetadataQuery(distance=True),
    )

    assert len(response.objects) == 5
147+
148+
149+
def test_boost_property_value(collection_factory: CollectionFactory) -> None:
    """A numeric-property boost on ``count`` with LOG1P still returns all five objects."""
    collection = _create_collection(collection_factory)

    popularity = Boost.numeric_property(
        "count",
        modifier=Boost.Modifier.LOG1P,
        weight=1.0,
    )
    response = collection.query.near_vector(
        near_vector=[1.0, 0.0, 0.0],
        limit=5,
        boost=popularity,
        return_metadata=MetadataQuery(distance=True),
    )

    assert len(response.objects) == 5
165+
166+
167+
def test_boost_blend(collection_factory: CollectionFactory) -> None:
    """A blend of a filter boost and a numeric-decay boost still returns all five objects."""
    collection = _create_collection(collection_factory)

    well_rated = Boost.filter(
        Filter.by_property("rating").greater_or_equal(4.0),
        weight=2.0,
    )
    near_target_price = Boost.numeric_decay(
        "price",
        origin=30.0,
        scale=100.0,
        curve=Boost.Curve.EXPONENTIAL,
    )
    combined = Boost.blend([well_rated, near_target_price], weight=0.8)

    response = collection.query.near_vector(
        near_vector=[1.0, 0.0, 0.0],
        limit=5,
        boost=combined,
        return_metadata=MetadataQuery(distance=True),
    )

    assert len(response.objects) == 5
193+
194+
195+
def test_boost_with_depth(collection_factory: CollectionFactory) -> None:
    """A filter boost given an explicit ``depth`` still returns all five objects."""
    collection = _create_collection(collection_factory)

    deep_boost = Boost.filter(
        Filter.by_property("rating").greater_or_equal(4.0),
        weight=1.0,
        depth=100,
    )
    response = collection.query.near_vector(
        near_vector=[1.0, 0.0, 0.0],
        limit=5,
        boost=deep_boost,
        return_metadata=MetadataQuery(distance=True),
    )

    assert len(response.objects) == 5
211+
212+
213+
def test_boost_bm25(collection_factory: CollectionFactory) -> None:
    """A BM25 query for "cheap" with a filter boost returns at least one object."""
    collection = _create_collection(collection_factory)

    well_rated = Boost.filter(
        Filter.by_property("rating").greater_or_equal(4.0),
        weight=1.0,
    )
    response = collection.query.bm25(
        query="cheap",
        limit=5,
        boost=well_rated,
        return_metadata=MetadataQuery(score=True),
    )

    assert len(response.objects) >= 1
228+
229+
230+
def test_boost_hybrid(collection_factory: CollectionFactory) -> None:
    """A hybrid query for "cheap" with a filter boost returns at least one object."""
    collection = _create_collection(collection_factory)

    affordable = Boost.filter(
        Filter.by_property("price").less_than(100.0),
        weight=0.6,
    )
    response = collection.query.hybrid(
        query="cheap",
        vector=[1.0, 0.0, 0.0],
        limit=5,
        boost=affordable,
        return_metadata=MetadataQuery(score=True),
    )

    assert len(response.objects) >= 1
246+
247+
248+
def test_boost_api_surface() -> None:
    """Exercise the public ``Boost`` API without a server.

    Checks that ``Boost`` cannot be instantiated directly, and that the
    ``filter`` and ``blend`` factories expose the expected ``conditions``,
    ``weight`` and ``depth`` on what they return.
    """
    # Direct construction is rejected; only the factories may be used.
    with pytest.raises(TypeError):
        Boost()

    x_equals_y = Filter.by_property("x").equal("y")

    # A single factory call yields exactly one condition.
    single = Boost.filter(x_equals_y, weight=0.5)
    assert len(single.conditions) == 1
    assert single.weight == 0.5

    # Blending a list keeps one condition per sub-boost.
    parts = [
        Boost.filter(Filter.by_property("x").equal("y"), weight=1.0),
        Boost.numeric_property("z", modifier=Boost.Modifier.LOG1P),
    ]
    blended = Boost.blend(parts, weight=0.8, depth=200)
    assert len(blended.conditions) == 2
    assert blended.weight == 0.8
    assert blended.depth == 200

    # blend() also accepts a single boost
    wrapped = Boost.blend(Boost.filter(Filter.by_property("x").equal("y")), weight=0.5)
    assert len(wrapped.conditions) == 1
    assert wrapped.weight == 0.5
277+
278+
279+
def test_boost_blend_rejects_sub_boost_depth() -> None:
    """``Boost.blend`` raises ``WeaviateInvalidInputError`` for a sub-boost with ``depth`` set."""
    sub_boost_with_depth = Boost.numeric_property("count", depth=500)

    with pytest.raises(WeaviateInvalidInputError):
        Boost.blend(sub_boost_with_depth, depth=100)
286+
287+
288+
def test_boost_default_curve_is_unspecified() -> None:
    """Decay boosts built without ``curve`` store ``None`` (sent as UNSPECIFIED on the wire)."""
    numeric = Boost.numeric_decay("price", origin=50.0, scale=20.0)
    numeric_condition = numeric.conditions[0]
    assert numeric_condition.numeric_decay.curve is None

    temporal = Boost.time_decay("created", scale="7d")
    temporal_condition = temporal.conditions[0]
    assert temporal_condition.time_decay.curve is None
295+
296+
297+
def test_boost_default_modifier_is_unspecified() -> None:
    """A property boost built without ``modifier`` stores ``None`` (sent as UNSPECIFIED on the wire)."""
    boost = Boost.numeric_property("count")
    first_condition = boost.conditions[0]
    assert first_condition.property_value.modifier is None

weaviate/classes/query.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
BM25OperatorFactory as BM25Operator,
66
)
77
from weaviate.collections.classes.grpc import (
8+
Boost,
9+
BoostReturn,
810
Diversity,
911
GroupBy,
1012
HybridFusion,
@@ -38,6 +40,8 @@
3840
"QueryNested",
3941
"QueryReference",
4042
"NearVector",
43+
"Boost",
44+
"BoostReturn",
4145
"Rerank",
4246
"Sort",
4347
"TargetVectors",

0 commit comments

Comments
 (0)