Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions weaviate/classes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
NearVector,
QueryNested,
QueryReference,
Boost,
Rerank,
Sort,
TargetVectors,
Expand All @@ -36,6 +37,7 @@
"QueryNested",
"QueryReference",
"NearVector",
"Boost",
"Rerank",
"Sort",
"TargetVectors",
Expand Down
194 changes: 194 additions & 0 deletions weaviate/collections/classes/grpc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum, auto
from typing import (
Any,
Expand Down Expand Up @@ -242,6 +243,199 @@ class Rerank(_WeaviateInput):
query: Optional[str] = Field(default=None)



@dataclass
class _DecayFunction:
    """Internal wire-format representation of a distance-based decay scoring function.

    All distance values are pre-converted to server-format strings (see
    `_decay_value_to_str`); `Boost.decay()` is the only constructor call site.
    """

    # Property name to compute distance from.
    property: str
    # Origin point as a string; "" means "let the server pick its default" (set by Boost.decay()).
    origin: str
    # Distance from origin at which the score equals decay_value.
    scale: str
    # Distance around origin that still receives full score; None defers to the server default.
    offset: Optional[str] = None
    # Decay curve name ("exp", "gauss", or "linear" per _BoostCurve); None defers to the server.
    curve: Optional[str] = None
    # Score at `scale` distance from origin; None defers to the server default (0.5 per docs).
    decay_value: Optional[float] = None


@dataclass
class _PropertyValueFunction:
    """Internal representation of rank scoring driven directly by a numeric property's value."""

    # Property name whose value is used as the ranking signal.
    property: str
    # Score modifier name ("none", "log1p", or "sqrt" per _BoostModifier); None defers to the server.
    modifier: Optional[str] = None


@dataclass
class _BoostCondition:
    """One scoring condition inside a rank configuration.

    Exactly one of `filter`, `decay`, or `property_value` is expected to be set
    by the `Boost` factory methods; the others remain None.
    """

    # Filter condition to match against; typed Any to avoid a circular import of FilterReturn.
    filter: Optional[Any] = None  # FilterReturn
    # Distance-based decay scoring configuration.
    decay: Optional[_DecayFunction] = None
    # Direct property-value scoring configuration.
    property_value: Optional[_PropertyValueFunction] = None
    # Per-condition blending weight; only populated when conditions are combined via Boost.blend().
    weight: Optional[float] = None


@dataclass
class _Boost:
    """Internal container for a complete rank configuration passed with a query."""

    # One or more scoring conditions to apply (multiple when built via Boost.blend()).
    conditions: List[_BoostCondition]
    # Overall blending weight [0,1] between primary search scores and rank scores; None = server default.
    weight: Optional[float] = None
    # Number of results to rescore (per Boost docs: default 100, max 10000); None = server default.
    depth: Optional[int] = None


def _decay_value_to_str(val: Union[str, int, float, timedelta, datetime]) -> str:
"""Convert a decay parameter value to the string format expected by the server."""
if isinstance(val, timedelta):
total_seconds = val.total_seconds()
if total_seconds >= 86400 and total_seconds % 86400 == 0:
return f"{int(total_seconds // 86400)}d"
if total_seconds >= 3600 and total_seconds % 3600 == 0:
return f"{int(total_seconds // 3600)}h"
if total_seconds >= 60 and total_seconds % 60 == 0:
return f"{int(total_seconds // 60)}m"
if total_seconds == int(total_seconds):
return f"{int(total_seconds)}s"
return f"{total_seconds}s"
if isinstance(val, datetime):
return val.isoformat()
return str(val)


class _BoostCurve(str, BaseEnum):
    """Decay curve type for distance-based rank scoring.

    Values are the wire-format strings the server expects; exposed to users
    as ``Boost.Curve``.
    """

    EXPONENTIAL = "exp"  # exponential fall-off from the origin (server default per Boost.decay docs)
    GAUSSIAN = "gauss"  # bell-shaped fall-off
    LINEAR = "linear"  # straight-line fall-off


class _BoostModifier(str, BaseEnum):
    """Score modifier for property-value rank scoring.

    Values are the wire-format strings the server expects; exposed to users
    as ``Boost.Modifier``.
    """

    NONE = "none"  # use the property value unmodified (server default per Boost.property docs)
    LOG1P = "log1p"  # log(1 + value) dampening
    SQRT = "sqrt"  # square-root dampening


class Boost:
    """Define soft-ranking conditions to boost or demote matching documents without excluding them.

    Use the static methods `filter()`, `decay()`, `property()`, and `blend()` to create
    rank configurations.
    """

    # Re-exported so callers can write Boost.Curve.GAUSSIAN / Boost.Modifier.LOG1P.
    Curve = _BoostCurve
    Modifier = _BoostModifier

    def __init__(self) -> None:
        # Pure namespace of factory methods; direct construction is a usage error.
        raise TypeError("Boost cannot be instantiated. Use the static methods to create a rank.")

    @staticmethod
    def filter(
        filter: Any,
        *,
        weight: Optional[float] = None,
        depth: Optional[int] = None,
    ) -> _Boost:
        """Boost or demote results matching a filter condition.

        Args:
            filter: The filter condition (same as used in `filters=` parameter).
            weight: Blending weight [0,1] controlling how much the rank affects final scores.
            depth: Number of results to rescore (default 100, max 10000). Higher values improve accuracy at the cost of performance.
        """
        return _Boost(conditions=[_BoostCondition(filter=filter)], weight=weight, depth=depth)

    @staticmethod
    def decay(
        property: str,
        *,
        origin: Optional[Union[str, int, float, datetime]] = None,
        scale: Union[str, int, float, timedelta],
        offset: Optional[Union[str, int, float, timedelta]] = None,
        curve: Optional[Union[_BoostCurve, str]] = None,
        decay_value: Optional[float] = None,
        weight: Optional[float] = None,
        depth: Optional[int] = None,
    ) -> _Boost:
        """Apply distance-based decay scoring from an origin value.

        Args:
            property: The property name to compute distance from.
            origin: The origin point. Use "now" for current time, a datetime for a specific time,
                or a numeric value for number properties. Defaults to "now" for date properties.
            scale: Distance from origin where score equals decay_value. Use timedelta for date
                properties (e.g. timedelta(days=7)) or a number for numeric properties. String
                shorthands like "7d", "24h" are also accepted.
            offset: Documents within this distance from origin get full score (default "0").
                Accepts the same types as scale.
            curve: Decay curve type: `Boost.Curve.EXPONENTIAL` (default), `Boost.Curve.GAUSSIAN`, or `Boost.Curve.LINEAR`.
            decay_value: Score at scale distance from origin (default 0.5).
            weight: Blending weight [0,1] controlling how much the rank affects final scores.
            depth: Number of results to rescore (default 100, max 10000). Higher values improve accuracy at the cost of performance.
        """
        # "" tells the server to apply its own default origin; offset stays None when unset.
        origin_str = "" if origin is None else _decay_value_to_str(origin)
        offset_str = None if offset is None else _decay_value_to_str(offset)
        curve_str = curve.value if isinstance(curve, _BoostCurve) else curve
        decay_fn = _DecayFunction(
            property=property,
            origin=origin_str,
            scale=_decay_value_to_str(scale),
            offset=offset_str,
            curve=curve_str,
            decay_value=decay_value,
        )
        return _Boost(conditions=[_BoostCondition(decay=decay_fn)], weight=weight, depth=depth)

    @staticmethod
    def property(
        name: str,
        *,
        modifier: Optional[Union[_BoostModifier, str]] = None,
        weight: Optional[float] = None,
        depth: Optional[int] = None,
    ) -> _Boost:
        """Rank by a numeric property's value directly.

        Args:
            name: The property name to use as a ranking signal.
            modifier: Score modifier: `Boost.Modifier.NONE` (default), `Boost.Modifier.LOG1P`, or `Boost.Modifier.SQRT`.
            weight: Blending weight [0,1] controlling how much the rank affects final scores.
            depth: Number of results to rescore (default 100, max 10000).
        """
        modifier_str = modifier.value if isinstance(modifier, _BoostModifier) else modifier
        condition = _BoostCondition(
            property_value=_PropertyValueFunction(property=name, modifier=modifier_str)
        )
        return _Boost(conditions=[condition], weight=weight, depth=depth)

    @staticmethod
    def blend(
        *ranks: _Boost,
        weight: Optional[float] = None,
        depth: Optional[int] = None,
    ) -> _Boost:
        """Combine multiple rank conditions with individual weights.

        When blending, each sub-rank's weight becomes a per-condition weight,
        and the `weight` parameter here controls the overall blending strength.

        Args:
            *ranks: Rank objects created via `Boost.filter()`, `Boost.decay()`, or `Boost.property()`.
            weight: Overall blending weight [0,1] for combining primary search and rank scores.
            depth: Number of results to rescore (default 100, max 10000). Higher values improve accuracy at the cost of performance.
        """
        merged: List[_BoostCondition] = []
        for rank in ranks:
            for condition in rank.conditions:
                # Push the sub-rank's overall weight down onto each of its conditions,
                # without clobbering a weight the condition already carries. Rebuild
                # rather than mutate so the caller's rank objects stay untouched.
                if condition.weight is None and rank.weight is not None:
                    condition = _BoostCondition(
                        filter=condition.filter,
                        decay=condition.decay,
                        property_value=condition.property_value,
                        weight=rank.weight,
                    )
                merged.append(condition)
        return _Boost(conditions=merged, weight=weight, depth=depth)


@dataclass
class BM25OperatorOptions:
# replace with ClassVar[base_search_pb2.SearchOperatorOptions.Operator] once python 3.10 is removed
Expand Down
51 changes: 51 additions & 0 deletions weaviate/collections/grpc/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_MetadataQuery,
_QueryReference,
_QueryReferenceMultiTarget,
_Boost,
_Sorting,
)
from weaviate.collections.classes.internal import (
Expand Down Expand Up @@ -121,6 +122,7 @@ def get(
return_references: Optional[REFERENCES] = None,
generative: Optional[_Generative] = None,
rerank: Optional[Rerank] = None,
boost: Optional[_Boost] = None,
) -> search_get_pb2.SearchRequest:
if self._validate_arguments:
_validate_input(_ValidateArgument([_Sorting, None], "sort", sort))
Expand All @@ -143,6 +145,7 @@ def get(
return_references=return_references,
generative=generative,
rerank=rerank,
boost=boost,
sort_by=sort_by,
)

Expand All @@ -166,6 +169,7 @@ def hybrid(
return_references: Optional[REFERENCES] = None,
generative: Optional[_Generative] = None,
rerank: Optional[Rerank] = None,
boost: Optional[_Boost] = None,
target_vector: Optional[TargetVectorJoinType] = None,
) -> search_get_pb2.SearchRequest:
return self.__create_request(
Expand All @@ -178,6 +182,7 @@ def hybrid(
return_references=return_references,
generative=generative,
rerank=rerank,
boost=boost,
autocut=autocut,
hybrid_search=self._parse_hybrid(
query,
Expand Down Expand Up @@ -207,6 +212,7 @@ def bm25(
return_references: Optional[REFERENCES] = None,
generative: Optional[_Generative] = None,
rerank: Optional[Rerank] = None,
boost: Optional[_Boost] = None,
) -> search_get_pb2.SearchRequest:
if self._validate_arguments:
_validate_input(
Expand All @@ -226,6 +232,7 @@ def bm25(
return_references=return_references,
generative=generative,
rerank=rerank,
boost=boost,
autocut=autocut,
bm25=(
base_search_pb2.BM25(
Expand Down Expand Up @@ -258,6 +265,7 @@ def near_vector(
group_by: Optional[_GroupBy] = None,
generative: Optional[_Generative] = None,
rerank: Optional[Rerank] = None,
boost: Optional[_Boost] = None,
target_vector: Optional[TargetVectorJoinType] = None,
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
Expand All @@ -272,6 +280,7 @@ def near_vector(
return_references=return_references,
generative=generative,
rerank=rerank,
boost=boost,
autocut=autocut,
group_by=group_by,
near_vector=self._parse_near_vector(
Expand All @@ -292,6 +301,7 @@ def near_object(
group_by: Optional[_GroupBy] = None,
generative: Optional[_Generative] = None,
rerank: Optional[Rerank] = None,
boost: Optional[_Boost] = None,
target_vector: Optional[TargetVectorJoinType] = None,
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
Expand All @@ -306,6 +316,7 @@ def near_object(
return_references=return_references,
generative=generative,
rerank=rerank,
boost=boost,
autocut=autocut,
group_by=group_by,
near_object=self._parse_near_object(near_object, certainty, distance, target_vector),
Expand All @@ -326,6 +337,7 @@ def near_text(
group_by: Optional[_GroupBy] = None,
generative: Optional[_Generative] = None,
rerank: Optional[Rerank] = None,
boost: Optional[_Boost] = None,
target_vector: Optional[TargetVectorJoinType] = None,
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
Expand All @@ -340,6 +352,7 @@ def near_text(
return_references=return_references,
generative=generative,
rerank=rerank,
boost=boost,
autocut=autocut,
group_by=group_by,
near_text=self._parse_near_text(
Expand All @@ -366,6 +379,7 @@ def near_media(
group_by: Optional[_GroupBy] = None,
generative: Optional[_Generative] = None,
rerank: Optional[Rerank] = None,
boost: Optional[_Boost] = None,
target_vector: Optional[TargetVectorJoinType] = None,
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
Expand All @@ -380,6 +394,7 @@ def near_media(
return_references=return_references,
generative=generative,
rerank=rerank,
boost=boost,
autocut=autocut,
group_by=group_by,
**self._parse_media(
Expand All @@ -402,6 +417,7 @@ def __create_request(
return_references: Optional[REFERENCES] = None,
generative: Optional[_Generative] = None,
rerank: Optional[Rerank] = None,
boost: Optional[_Boost] = None,
autocut: Optional[int] = None,
group_by: Optional[_GroupBy] = None,
near_vector: Optional[base_search_pb2.NearVector] = None,
Expand Down Expand Up @@ -495,6 +511,7 @@ def __create_request(
if rerank is not None
else None
),
boost=self.__boost_to_grpc(boost),
near_vector=near_vector,
sort_by=sort_by,
hybrid_search=hybrid_search,
Expand Down Expand Up @@ -523,6 +540,40 @@ def _metadata_to_grpc(self, metadata: _MetadataQuery) -> search_get_pb2.Metadata
vectors=metadata.vectors,
)

def __boost_to_grpc(
self, boost: Optional[_Boost]
) -> Optional[search_get_pb2.Boost]:
if boost is None:
return None
conditions = []
for cond in boost.conditions:
grpc_cond = search_get_pb2.BoostCondition(
filter=_FilterToGRPC.convert(cond.filter) if cond.filter is not None else None,
decay=(
search_get_pb2.DecayFunction(
path=[cond.decay.property],
origin=cond.decay.origin,
scale=cond.decay.scale,
offset=cond.decay.offset,
curve=cond.decay.curve,
decay_value=cond.decay.decay_value,
)
if cond.decay is not None
else None
),
property_value=(
search_get_pb2.PropertyValueFunction(
path=[cond.property_value.property],
modifier=cond.property_value.modifier,
)
if cond.property_value is not None
else None
),
weight=cond.weight,
)
conditions.append(grpc_cond)
return search_get_pb2.Boost(conditions=conditions, weight=boost.weight, depth=boost.depth)

def __resolve_property(self, prop: QueryNested) -> search_get_pb2.ObjectPropertiesRequest:
props = prop.properties if isinstance(prop.properties, list) else [prop.properties]
return search_get_pb2.ObjectPropertiesRequest(
Expand Down
Loading