Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2547,6 +2547,137 @@ def retrieve_online_documents_v2(
query_string,
)

def retrieve_online_documents_v3(
self,
features: List[str],
top_k: int,
embeddings: Optional[Dict[str, List[float]]] = None,
query_string: Optional[str] = None,
fusion_strategy: str = "AUTO",
signal_weights: Optional[Dict[str, float]] = None,
rrf_k: int = 60,
distance_metric: Optional[str] = None,
include_signal_scores: bool = False,
) -> OnlineResponse:
"""
Retrieve documents using multi-vector search with configurable fusion.

Args:
features: Feature references (e.g., ["doc_fv:title", "doc_fv:body"]).
top_k: Number of results to return.
embeddings: Map of vector field name to query vector. Required.
Single entry is equivalent to V2's embedding param.
Special case: the key "embedding" auto-resolves to the
FeatureView's single vector field, easing V2 to V3 migration.
FeatureViews with multiple vector fields require explicit names.
query_string: Text query. Ranking signal on Elasticsearch; logged + dropped
on Valkey (Valkey does not support text-as-ranking).
fusion_strategy: AUTO | RRF | WEIGHTED_LINEAR | VECTOR_ONLY.
signal_weights: Per-signal weights for WEIGHTED_LINEAR.
Keys are embedding field names and/or "bm25".
rrf_k: RRF rank constant (default 60).
distance_metric: Override distance metric.
include_signal_scores: When True, requests per-signal score
breakdowns for fusion strategies (RRF, WEIGHTED_LINEAR) at
additional latency cost. Currently a no-op — signal_scores
always follows the best-effort behavior documented in the V3
design, and the parameter is plumbed through so callers can
opt in today and automatically pick up the explain-based
path when it lands. Default False.
"""
if not embeddings:
raise ValueError(
"V3 requires at least one embedding. "
"For text-only search, use retrieve_online_documents_v2."
)

effective_strategy = fusion_strategy.upper()
valid_strategies = {"AUTO", "RRF", "WEIGHTED_LINEAR", "VECTOR_ONLY"}
if effective_strategy not in valid_strategies:
raise ValueError(
f"Unknown fusion_strategy '{fusion_strategy}'. "
f"Must be one of: {', '.join(sorted(valid_strategies))}"
)

if effective_strategy == "VECTOR_ONLY":
query_string = None

(
available_feature_views,
available_odfv_views,
) = utils._get_feature_views_to_use(
registry=self._registry,
project=self.project,
features=features,
allow_cache=True,
hide_dummy_entity=False,
)

feature_view_set = {f.split(":")[0] for f in features}
if len(feature_view_set) > 1:
raise ValueError("Document retrieval only supports a single feature view.")

requested_features = [
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
]

if not available_feature_views and not available_odfv_views:
raise ValueError(f"No feature view found for features {features}.")

if not available_feature_views:
available_feature_views.extend(available_odfv_views) # type: ignore[arg-type]

requested_feature_view = available_feature_views[0]

if isinstance(requested_feature_view, OnDemandFeatureView):
raise ValueError(
"V3 vector search is not supported on OnDemandFeatureViews. "
"Use a regular FeatureView with vector-indexed fields."
)

RESERVED_NAMES = {"final_score", "signal_scores"}
collisions = set(requested_features) & RESERVED_NAMES
if collisions:
raise ValueError(
f"Feature names {sorted(collisions)} are reserved by V3 and cannot be "
f"requested directly. Rename your feature view fields or use V2."
)

# "embedding" magic key: auto-resolve to the FeatureView's single vector field
if len(embeddings) == 1 and list(embeddings.keys()) == ["embedding"]:
vector_fields = {
f.name: f for f in requested_feature_view.features if f.vector_index
}
if len(vector_fields) == 0:
raise ValueError(
f"FeatureView '{requested_feature_view.name}' has no vector-indexed "
f"fields. Cannot perform vector search."
)
elif len(vector_fields) == 1:
actual_name = next(iter(vector_fields.keys()))
embeddings = {actual_name: embeddings["embedding"]}
else:
raise ValueError(
f"FeatureView '{requested_feature_view.name}' has multiple vector "
f"fields {sorted(vector_fields.keys())}. "
f"Specify the field name explicitly in embeddings."
)

provider = self._get_provider()
return self._retrieve_from_online_store_v3(
provider,
requested_feature_view,
requested_features,
embeddings,
top_k,
query_string,
effective_strategy,
signal_weights,
rrf_k,
distance_metric,
include_signal_scores,
)

def _retrieve_from_online_store(
self,
provider: Provider,
Expand Down Expand Up @@ -2691,6 +2822,89 @@ def _retrieve_from_online_store_v2(

return OnlineResponse(online_features_response)

def _retrieve_from_online_store_v3(
self,
provider: Provider,
table: FeatureView,
requested_features: List[str],
embeddings: Dict[str, List[float]],
top_k: int,
query_string: Optional[str],
fusion_strategy: str,
signal_weights: Optional[Dict[str, float]],
rrf_k: int,
distance_metric: Optional[str],
include_signal_scores: bool,
) -> OnlineResponse:
documents = provider.retrieve_online_documents_v3(
config=self.config,
table=table,
requested_features=requested_features,
embeddings=embeddings,
top_k=top_k,
query_string=query_string,
fusion_strategy=fusion_strategy,
signal_weights=signal_weights,
rrf_k=rrf_k,
distance_metric=distance_metric,
include_signal_scores=include_signal_scores,
)

entity_key_dict: Dict[str, List[ValueProto]] = {}
datevals, list_of_feature_dicts = [], []
for row_ts, entity_key, feature_dict in documents: # type: ignore[misc]
datevals.append(row_ts)
list_of_feature_dicts.append(feature_dict)
if entity_key:
for key, value in zip(entity_key.join_keys, entity_key.entity_values):
python_value = value
if key not in entity_key_dict:
entity_key_dict[key] = []
entity_key_dict[key].append(python_value)

features_to_request: List[str] = requested_features + [
"final_score",
"signal_scores",
]

if not datevals:
online_features_response = GetOnlineFeaturesResponse(results=[])
for _ in features_to_request:
field = online_features_response.results.add()
field.values.extend([])
field.statuses.extend([])
field.event_timestamps.extend([])
online_features_response.metadata.feature_names.val.extend(
features_to_request
)
return OnlineResponse(online_features_response)

output_len = len(datevals)
idxs = tuple([i] for i in range(output_len))

feature_data = utils._convert_rows_to_protobuf(
requested_features=features_to_request,
read_rows=list(zip(datevals, list_of_feature_dicts)),
)

online_features_response = GetOnlineFeaturesResponse(results=[])
utils._populate_response_from_feature_data(
feature_data=feature_data,
indexes=idxs,
online_features_response=online_features_response,
full_feature_names=False,
requested_features=features_to_request,
table=table,
output_len=output_len,
)

utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data=entity_key_dict,
)

return OnlineResponse(online_features_response)

def _lazy_init_go_server(self):
"""Lazily initialize self._go_server if it hasn't been initialized before."""
from feast.embedded_go.online_features_service import (
Expand Down
4 changes: 0 additions & 4 deletions sdk/python/feast/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,6 @@ def __init__(
else:
features.append(field)

assert len([f for f in features if f.vector_index]) < 2, (
f"Only one vector feature is allowed per feature view. Please update {self.name}."
)

# TODO(felixwang9817): Add more robust validation of features.
if self.batch_source is not None:
cols = [field.name for field in schema]
Expand Down
18 changes: 18 additions & 0 deletions sdk/python/feast/infra/online_stores/_signal_scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import json
from typing import Dict

from feast.protos.feast.types.Value_pb2 import Value as ValueProto


def encode_signal_scores(scores: Dict[str, float]) -> ValueProto:
"""Encode a signal_scores dict as a JSON string in ValueProto."""
val = ValueProto()
val.string_val = json.dumps(scores, separators=(",", ":"), sort_keys=True)
return val


def decode_signal_scores(value: ValueProto) -> Dict[str, float]:
"""Decode a signal_scores ValueProto back to a dict."""
if not value.HasField("string_val") or not value.string_val:
return {}
return json.loads(value.string_val)
Loading
Loading