Skip to content

Commit 3c53626

Browse files
committed
Add Boost API
1 parent 6dce105 commit 3c53626

26 files changed

Lines changed: 968 additions & 153 deletions

File tree

weaviate/classes/query.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
NearVector,
1515
QueryNested,
1616
QueryReference,
17+
Boost,
1718
Rerank,
1819
Sort,
1920
TargetVectors,
@@ -36,6 +37,7 @@
3637
"QueryNested",
3738
"QueryReference",
3839
"NearVector",
40+
"Boost",
3941
"Rerank",
4042
"Sort",
4143
"TargetVectors",

weaviate/collections/classes/grpc.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from datetime import datetime, timedelta
23
from enum import Enum, auto
34
from typing import (
45
Any,
@@ -242,6 +243,199 @@ class Rerank(_WeaviateInput):
242243
query: Optional[str] = Field(default=None)
243244

244245

246+
247+
@dataclass
248+
class _DecayFunction:
249+
property: str
250+
origin: str
251+
scale: str
252+
offset: Optional[str] = None
253+
curve: Optional[str] = None
254+
decay_value: Optional[float] = None
255+
256+
257+
@dataclass
258+
class _PropertyValueFunction:
259+
property: str
260+
modifier: Optional[str] = None
261+
262+
263+
@dataclass
class _BoostCondition:
    """Internal container for one boost condition.

    The `Boost` factory methods populate exactly one of `filter`, `decay` or
    `property_value` per condition; the dataclass itself does not enforce this.
    """

    # Filter expression (FilterReturn) for filter-based boosting.
    filter: Optional[Any] = None
    # Parameters for decay-based boosting.
    decay: Optional[_DecayFunction] = None
    # Parameters for property-value-based boosting.
    property_value: Optional[_PropertyValueFunction] = None
    # Per-condition weight; assigned by `Boost.blend()` from the sub-rank's weight.
    weight: Optional[float] = None
269+
270+
271+
@dataclass
class _Boost:
    """Internal container for a complete boost (soft-ranking) configuration."""

    # The conditions that make up this boost, in the order they were supplied.
    conditions: List[_BoostCondition]
    # Optional overall blending weight.
    weight: Optional[float] = None
    # Optional number of results to rescore.
    depth: Optional[int] = None
276+
277+
278+
def _decay_value_to_str(val: Union[str, int, float, timedelta, datetime]) -> str:
279+
"""Convert a decay parameter value to the string format expected by the server."""
280+
if isinstance(val, timedelta):
281+
total_seconds = val.total_seconds()
282+
if total_seconds >= 86400 and total_seconds % 86400 == 0:
283+
return f"{int(total_seconds // 86400)}d"
284+
if total_seconds >= 3600 and total_seconds % 3600 == 0:
285+
return f"{int(total_seconds // 3600)}h"
286+
if total_seconds >= 60 and total_seconds % 60 == 0:
287+
return f"{int(total_seconds // 60)}m"
288+
if total_seconds == int(total_seconds):
289+
return f"{int(total_seconds)}s"
290+
return f"{total_seconds}s"
291+
if isinstance(val, datetime):
292+
return val.isoformat()
293+
return str(val)
294+
295+
296+
class _BoostCurve(str, BaseEnum):
    """Decay curve type for distance-based rank scoring.

    Each member's value is the string identifier passed on for the curve.
    """

    EXPONENTIAL = "exp"
    GAUSSIAN = "gauss"
    LINEAR = "linear"
302+
303+
304+
class _BoostModifier(str, BaseEnum):
    """Score modifier for property-value rank scoring.

    Each member's value is the string identifier passed on for the modifier.
    """

    NONE = "none"
    LOG1P = "log1p"
    SQRT = "sqrt"
310+
311+
312+
class Boost:
    """Define soft-ranking conditions to boost or demote matching documents without excluding them.

    Use the static methods `filter()`, `decay()`, and `property()` to create a single rank
    condition, and `blend()` to combine several of them into one rank configuration.
    This class is a namespace only and cannot be instantiated.
    """

    Curve = _BoostCurve
    Modifier = _BoostModifier

    def __init__(self) -> None:
        raise TypeError("Boost cannot be instantiated. Use the static methods to create a rank.")

    @staticmethod
    def filter(
        filter: Any,
        *,
        weight: Optional[float] = None,
        depth: Optional[int] = None,
    ) -> _Boost:
        """Boost or demote results matching a filter condition.

        Args:
            filter: The filter condition (same as used in `filters=` parameter).
            weight: Blending weight [0,1] controlling how much the rank affects final scores.
            depth: Number of results to rescore (default 100, max 10000). Higher values improve accuracy at the cost of performance.

        Returns:
            A rank configuration holding a single filter condition.
        """
        return _Boost(conditions=[_BoostCondition(filter=filter)], weight=weight, depth=depth)

    @staticmethod
    def decay(
        property: str,
        *,
        origin: Optional[Union[str, int, float, datetime]] = None,
        scale: Union[str, int, float, timedelta],
        offset: Optional[Union[str, int, float, timedelta]] = None,
        curve: Optional[Union[_BoostCurve, str]] = None,
        decay_value: Optional[float] = None,
        weight: Optional[float] = None,
        depth: Optional[int] = None,
    ) -> _Boost:
        """Apply distance-based decay scoring from an origin value.

        Args:
            property: The property name to compute distance from.
            origin: The origin point. Use "now" for current time, a datetime for a specific time,
                or a numeric value for number properties. Defaults to "now" for date properties.
                When omitted, an empty origin string is sent.
            scale: Distance from origin where score equals decay_value. Use timedelta for date
                properties (e.g. timedelta(days=7)) or a number for numeric properties. String
                shorthands like "7d", "24h" are also accepted.
            offset: Documents within this distance from origin get full score (default "0").
                Accepts the same types as scale.
            curve: Decay curve type: `Boost.Curve.EXPONENTIAL` (default), `Boost.Curve.GAUSSIAN`, or `Boost.Curve.LINEAR`.
                A plain string is passed through unchanged.
            decay_value: Score at scale distance from origin (default 0.5).
            weight: Blending weight [0,1] controlling how much the rank affects final scores.
            depth: Number of results to rescore (default 100, max 10000). Higher values improve accuracy at the cost of performance.

        Returns:
            A rank configuration holding a single decay condition.
        """
        return _Boost(
            conditions=[
                _BoostCondition(
                    decay=_DecayFunction(
                        property=property,
                        # An omitted origin is represented by the empty string.
                        origin=_decay_value_to_str(origin) if origin is not None else "",
                        scale=_decay_value_to_str(scale),
                        offset=_decay_value_to_str(offset) if offset is not None else None,
                        # Enum members are reduced to their string value.
                        curve=curve.value if isinstance(curve, _BoostCurve) else curve,
                        decay_value=decay_value,
                    )
                )
            ],
            weight=weight,
            depth=depth,
        )

    @staticmethod
    def property(
        name: str,
        *,
        modifier: Optional[Union[_BoostModifier, str]] = None,
        weight: Optional[float] = None,
        depth: Optional[int] = None,
    ) -> _Boost:
        """Rank by a numeric property's value directly.

        Args:
            name: The property name to use as a ranking signal.
            modifier: Score modifier: `Boost.Modifier.NONE` (default), `Boost.Modifier.LOG1P`, or `Boost.Modifier.SQRT`.
                A plain string is passed through unchanged.
            weight: Blending weight [0,1] controlling how much the rank affects final scores.
            depth: Number of results to rescore (default 100, max 10000). Higher values improve accuracy at the cost of performance.

        Returns:
            A rank configuration holding a single property-value condition.
        """
        return _Boost(
            conditions=[
                _BoostCondition(
                    property_value=_PropertyValueFunction(
                        property=name,
                        # Enum members are reduced to their string value.
                        modifier=modifier.value if isinstance(modifier, _BoostModifier) else modifier,
                    )
                )
            ],
            weight=weight,
            depth=depth,
        )

    @staticmethod
    def blend(
        *ranks: _Boost,
        weight: Optional[float] = None,
        depth: Optional[int] = None,
    ) -> _Boost:
        """Combine multiple rank conditions with individual weights.

        When blending, each sub-rank's weight becomes a per-condition weight,
        and the `weight` parameter here controls the overall blending strength.
        A condition that already carries its own weight keeps it. The `depth`
        of the individual sub-ranks is not carried over; only the `depth`
        given here is used.

        Args:
            *ranks: Rank objects created via `Boost.filter()`, `Boost.decay()`, or `Boost.property()`.
            weight: Overall blending weight [0,1] for combining primary search and rank scores.
            depth: Number of results to rescore (default 100, max 10000). Higher values improve accuracy at the cost of performance.

        Returns:
            A rank configuration containing the conditions of all `ranks`, in order.
        """
        conditions: List[_BoostCondition] = []
        for rank in ranks:
            for cond in rank.conditions:
                if cond.weight is None and rank.weight is not None:
                    # Copy instead of mutating so the caller's sub-rank stays untouched.
                    cond = _BoostCondition(
                        filter=cond.filter,
                        decay=cond.decay,
                        property_value=cond.property_value,
                        weight=rank.weight,
                    )
                conditions.append(cond)
        return _Boost(conditions=conditions, weight=weight, depth=depth)
437+
438+
245439
@dataclass
246440
class BM25OperatorOptions:
247441
# replace with ClassVar[base_search_pb2.SearchOperatorOptions.Operator] once python 3.10 is removed

weaviate/collections/grpc/query.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
_MetadataQuery,
3535
_QueryReference,
3636
_QueryReferenceMultiTarget,
37+
_Boost,
3738
_Sorting,
3839
)
3940
from weaviate.collections.classes.internal import (
@@ -121,6 +122,7 @@ def get(
121122
return_references: Optional[REFERENCES] = None,
122123
generative: Optional[_Generative] = None,
123124
rerank: Optional[Rerank] = None,
125+
boost: Optional[_Boost] = None,
124126
) -> search_get_pb2.SearchRequest:
125127
if self._validate_arguments:
126128
_validate_input(_ValidateArgument([_Sorting, None], "sort", sort))
@@ -143,6 +145,7 @@ def get(
143145
return_references=return_references,
144146
generative=generative,
145147
rerank=rerank,
148+
boost=boost,
146149
sort_by=sort_by,
147150
)
148151

@@ -166,6 +169,7 @@ def hybrid(
166169
return_references: Optional[REFERENCES] = None,
167170
generative: Optional[_Generative] = None,
168171
rerank: Optional[Rerank] = None,
172+
boost: Optional[_Boost] = None,
169173
target_vector: Optional[TargetVectorJoinType] = None,
170174
) -> search_get_pb2.SearchRequest:
171175
return self.__create_request(
@@ -178,6 +182,7 @@ def hybrid(
178182
return_references=return_references,
179183
generative=generative,
180184
rerank=rerank,
185+
boost=boost,
181186
autocut=autocut,
182187
hybrid_search=self._parse_hybrid(
183188
query,
@@ -207,6 +212,7 @@ def bm25(
207212
return_references: Optional[REFERENCES] = None,
208213
generative: Optional[_Generative] = None,
209214
rerank: Optional[Rerank] = None,
215+
boost: Optional[_Boost] = None,
210216
) -> search_get_pb2.SearchRequest:
211217
if self._validate_arguments:
212218
_validate_input(
@@ -226,6 +232,7 @@ def bm25(
226232
return_references=return_references,
227233
generative=generative,
228234
rerank=rerank,
235+
boost=boost,
229236
autocut=autocut,
230237
bm25=(
231238
base_search_pb2.BM25(
@@ -258,6 +265,7 @@ def near_vector(
258265
group_by: Optional[_GroupBy] = None,
259266
generative: Optional[_Generative] = None,
260267
rerank: Optional[Rerank] = None,
268+
boost: Optional[_Boost] = None,
261269
target_vector: Optional[TargetVectorJoinType] = None,
262270
return_metadata: Optional[_MetadataQuery] = None,
263271
return_properties: Union[PROPERTIES, bool, None] = None,
@@ -272,6 +280,7 @@ def near_vector(
272280
return_references=return_references,
273281
generative=generative,
274282
rerank=rerank,
283+
boost=boost,
275284
autocut=autocut,
276285
group_by=group_by,
277286
near_vector=self._parse_near_vector(
@@ -292,6 +301,7 @@ def near_object(
292301
group_by: Optional[_GroupBy] = None,
293302
generative: Optional[_Generative] = None,
294303
rerank: Optional[Rerank] = None,
304+
boost: Optional[_Boost] = None,
295305
target_vector: Optional[TargetVectorJoinType] = None,
296306
return_metadata: Optional[_MetadataQuery] = None,
297307
return_properties: Union[PROPERTIES, bool, None] = None,
@@ -306,6 +316,7 @@ def near_object(
306316
return_references=return_references,
307317
generative=generative,
308318
rerank=rerank,
319+
boost=boost,
309320
autocut=autocut,
310321
group_by=group_by,
311322
near_object=self._parse_near_object(near_object, certainty, distance, target_vector),
@@ -326,6 +337,7 @@ def near_text(
326337
group_by: Optional[_GroupBy] = None,
327338
generative: Optional[_Generative] = None,
328339
rerank: Optional[Rerank] = None,
340+
boost: Optional[_Boost] = None,
329341
target_vector: Optional[TargetVectorJoinType] = None,
330342
return_metadata: Optional[_MetadataQuery] = None,
331343
return_properties: Union[PROPERTIES, bool, None] = None,
@@ -340,6 +352,7 @@ def near_text(
340352
return_references=return_references,
341353
generative=generative,
342354
rerank=rerank,
355+
boost=boost,
343356
autocut=autocut,
344357
group_by=group_by,
345358
near_text=self._parse_near_text(
@@ -366,6 +379,7 @@ def near_media(
366379
group_by: Optional[_GroupBy] = None,
367380
generative: Optional[_Generative] = None,
368381
rerank: Optional[Rerank] = None,
382+
boost: Optional[_Boost] = None,
369383
target_vector: Optional[TargetVectorJoinType] = None,
370384
return_metadata: Optional[_MetadataQuery] = None,
371385
return_properties: Union[PROPERTIES, bool, None] = None,
@@ -380,6 +394,7 @@ def near_media(
380394
return_references=return_references,
381395
generative=generative,
382396
rerank=rerank,
397+
boost=boost,
383398
autocut=autocut,
384399
group_by=group_by,
385400
**self._parse_media(
@@ -402,6 +417,7 @@ def __create_request(
402417
return_references: Optional[REFERENCES] = None,
403418
generative: Optional[_Generative] = None,
404419
rerank: Optional[Rerank] = None,
420+
boost: Optional[_Boost] = None,
405421
autocut: Optional[int] = None,
406422
group_by: Optional[_GroupBy] = None,
407423
near_vector: Optional[base_search_pb2.NearVector] = None,
@@ -495,6 +511,7 @@ def __create_request(
495511
if rerank is not None
496512
else None
497513
),
514+
boost=self.__boost_to_grpc(boost),
498515
near_vector=near_vector,
499516
sort_by=sort_by,
500517
hybrid_search=hybrid_search,
@@ -523,6 +540,40 @@ def _metadata_to_grpc(self, metadata: _MetadataQuery) -> search_get_pb2.Metadata
523540
vectors=metadata.vectors,
524541
)
525542

543+
def __boost_to_grpc(self, boost: Optional[_Boost]) -> Optional[search_get_pb2.Boost]:
    """Translate an internal `_Boost` configuration into its gRPC message.

    Args:
        boost: The boost configuration, or `None` when no boosting was requested.

    Returns:
        The `search_get_pb2.Boost` message, or `None` if `boost` is `None`.
    """
    if boost is None:
        return None

    grpc_conditions = []
    for condition in boost.conditions:
        grpc_filter = None
        if condition.filter is not None:
            grpc_filter = _FilterToGRPC.convert(condition.filter)

        grpc_decay = None
        if condition.decay is not None:
            decay = condition.decay
            grpc_decay = search_get_pb2.DecayFunction(
                path=[decay.property],
                origin=decay.origin,
                scale=decay.scale,
                offset=decay.offset,
                curve=decay.curve,
                decay_value=decay.decay_value,
            )

        grpc_property_value = None
        if condition.property_value is not None:
            property_value = condition.property_value
            grpc_property_value = search_get_pb2.PropertyValueFunction(
                path=[property_value.property],
                modifier=property_value.modifier,
            )

        grpc_conditions.append(
            search_get_pb2.BoostCondition(
                filter=grpc_filter,
                decay=grpc_decay,
                property_value=grpc_property_value,
                weight=condition.weight,
            )
        )

    return search_get_pb2.Boost(
        conditions=grpc_conditions, weight=boost.weight, depth=boost.depth
    )
576+
526577
def __resolve_property(self, prop: QueryNested) -> search_get_pb2.ObjectPropertiesRequest:
527578
props = prop.properties if isinstance(prop.properties, list) else [prop.properties]
528579
return search_get_pb2.ObjectPropertiesRequest(

0 commit comments

Comments
 (0)