def retrieve_online_documents_v3(
    self,
    features: List[str],
    top_k: int,
    embeddings: Optional[Dict[str, List[float]]] = None,
    query_string: Optional[str] = None,
    fusion_strategy: str = "AUTO",
    signal_weights: Optional[Dict[str, float]] = None,
    rrf_k: int = 60,
    distance_metric: Optional[str] = None,
    include_signal_scores: bool = False,
) -> "OnlineResponse":
    """Retrieve documents via multi-vector search with a configurable fusion step.

    Args:
        features: Feature references such as ``["doc_fv:title", "doc_fv:body"]``.
        top_k: Maximum number of documents to return.
        embeddings: Required mapping of vector field name to query vector.
            A single entry behaves like V2's ``embedding`` argument. The
            special key ``"embedding"`` is resolved to the FeatureView's
            sole vector field (V2 -> V3 migration aid); views with several
            vector fields require explicit field names.
        query_string: Optional text query. Used as a ranking signal on
            Elasticsearch; logged and dropped on Valkey.
        fusion_strategy: One of AUTO, RRF, WEIGHTED_LINEAR, VECTOR_ONLY.
        signal_weights: Per-signal weights for WEIGHTED_LINEAR. Keys are
            embedding field names and/or "bm25".
        rrf_k: RRF rank constant (default 60).
        distance_metric: Optional distance-metric override.
        include_signal_scores: Reserved opt-in flag for per-signal score
            breakdowns; currently plumbed through as a no-op so callers can
            opt in ahead of the explain-based path. Default False.

    Returns:
        OnlineResponse containing the retrieved documents.

    Raises:
        ValueError: Missing embeddings, unknown fusion strategy, more than
            one feature view, reserved feature names requested, or a failed
            "embedding" key resolution.
    """
    if not embeddings:
        raise ValueError(
            "V3 requires at least one embedding. "
            "For text-only search, use retrieve_online_documents_v2."
        )

    strategy = fusion_strategy.upper()
    known = {"AUTO", "RRF", "WEIGHTED_LINEAR", "VECTOR_ONLY"}
    if strategy not in known:
        raise ValueError(
            f"Unknown fusion_strategy '{fusion_strategy}'. "
            f"Must be one of: {', '.join(sorted(known))}"
        )

    # VECTOR_ONLY explicitly discards any text signal.
    if strategy == "VECTOR_ONLY":
        query_string = None

    fvs, odfvs = utils._get_feature_views_to_use(
        registry=self._registry,
        project=self.project,
        features=features,
        allow_cache=True,
        hide_dummy_entity=False,
    )

    if len({ref.split(":")[0] for ref in features}) > 1:
        raise ValueError("Document retrieval only supports a single feature view.")

    requested_features = [
        ref.split(":")[1]
        for ref in features
        if isinstance(ref, str) and ":" in ref
    ]

    if not fvs and not odfvs:
        raise ValueError(f"No feature view found for features {features}.")

    if not fvs:
        fvs.extend(odfvs)  # type: ignore[arg-type]

    view = fvs[0]

    if isinstance(view, OnDemandFeatureView):
        raise ValueError(
            "V3 vector search is not supported on OnDemandFeatureViews. "
            "Use a regular FeatureView with vector-indexed fields."
        )

    # "final_score" / "signal_scores" are synthesized by V3 itself.
    reserved = {"final_score", "signal_scores"}
    clashes = reserved & set(requested_features)
    if clashes:
        raise ValueError(
            f"Feature names {sorted(clashes)} are reserved by V3 and cannot be "
            f"requested directly. Rename your feature view fields or use V2."
        )

    # Convenience alias: a lone "embedding" key maps onto the view's single
    # vector field, mirroring the V2 call shape.
    if list(embeddings) == ["embedding"]:
        indexed = {f.name: f for f in view.features if f.vector_index}
        if not indexed:
            raise ValueError(
                f"FeatureView '{view.name}' has no vector-indexed "
                f"fields. Cannot perform vector search."
            )
        if len(indexed) > 1:
            raise ValueError(
                f"FeatureView '{view.name}' has multiple vector "
                f"fields {sorted(indexed.keys())}. "
                f"Specify the field name explicitly in embeddings."
            )
        embeddings = {next(iter(indexed)): embeddings["embedding"]}

    return self._retrieve_from_online_store_v3(
        self._get_provider(),
        view,
        requested_features,
        embeddings,
        top_k,
        query_string,
        strategy,
        signal_weights,
        rrf_k,
        distance_metric,
        include_signal_scores,
    )
def _retrieve_from_online_store_v3(
    self,
    provider: "Provider",
    table: "FeatureView",
    requested_features: List[str],
    embeddings: Dict[str, List[float]],
    top_k: int,
    query_string: Optional[str],
    fusion_strategy: str,
    signal_weights: Optional[Dict[str, float]],
    rrf_k: int,
    distance_metric: Optional[str],
    include_signal_scores: bool,
) -> "OnlineResponse":
    """Run a V3 document search through the provider and build the response.

    Forwards all V3 parameters to the provider, then converts the returned
    (timestamp, entity_key, feature_dict) rows into a GetOnlineFeaturesResponse,
    always appending the synthesized "final_score" and "signal_scores" columns.
    """
    documents = provider.retrieve_online_documents_v3(
        config=self.config,
        table=table,
        requested_features=requested_features,
        embeddings=embeddings,
        top_k=top_k,
        query_string=query_string,
        fusion_strategy=fusion_strategy,
        signal_weights=signal_weights,
        rrf_k=rrf_k,
        distance_metric=distance_metric,
        include_signal_scores=include_signal_scores,
    )

    timestamps = []
    feature_dicts = []
    entity_columns: Dict[str, List["ValueProto"]] = {}
    for row_ts, entity_key, feature_dict in documents:  # type: ignore[misc]
        timestamps.append(row_ts)
        feature_dicts.append(feature_dict)
        if not entity_key:
            continue
        for join_key, raw_value in zip(entity_key.join_keys, entity_key.entity_values):
            entity_columns.setdefault(join_key, []).append(raw_value)

    # V3 always adds the fusion score columns to the requested set.
    features_to_request: List[str] = requested_features + [
        "final_score",
        "signal_scores",
    ]

    if not timestamps:
        # Zero hits: emit an empty but well-formed response.
        empty = GetOnlineFeaturesResponse(results=[])
        for _ in features_to_request:
            field = empty.results.add()
            field.values.extend([])
            field.statuses.extend([])
            field.event_timestamps.extend([])
        empty.metadata.feature_names.val.extend(features_to_request)
        return OnlineResponse(empty)

    output_len = len(timestamps)
    idxs = tuple([i] for i in range(output_len))

    feature_data = utils._convert_rows_to_protobuf(
        requested_features=features_to_request,
        read_rows=list(zip(timestamps, feature_dicts)),
    )

    response = GetOnlineFeaturesResponse(results=[])
    utils._populate_response_from_feature_data(
        feature_data=feature_data,
        indexes=idxs,
        online_features_response=response,
        full_feature_names=False,
        requested_features=features_to_request,
        table=table,
        output_len=output_len,
    )
    utils._populate_result_rows_from_columnar(
        online_features_response=response,
        data=entity_columns,
    )
    return OnlineResponse(response)
def encode_signal_scores(scores: Dict[str, float]) -> "ValueProto":
    """Encode a signal_scores dict as a compact, key-sorted JSON string ValueProto."""
    encoded = ValueProto()
    encoded.string_val = json.dumps(scores, separators=(",", ":"), sort_keys=True)
    return encoded


def decode_signal_scores(value: "ValueProto") -> Dict[str, float]:
    """Decode a signal_scores ValueProto back to a dict (empty when unset/blank)."""
    if value.HasField("string_val") and value.string_val:
        return json.loads(value.string_val)
    return {}
serialize_entity_key, +) +from feast.infra.online_stores._signal_scores import encode_signal_scores from feast.infra.online_stores.helpers import _mmh3, _redis_key, _redis_key_prefix from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.protos.feast.types.Value_pb2 import FloatList +from feast.protos.feast.types.Value_pb2 import ( + Value as ValueProto, +) from feast.repo_config import FeastConfigBaseModel from feast.sorted_feature_view import SortedFeatureView +from feast.types import Array, Float64 from feast.value_type import ValueType try: from valkey import Valkey from valkey import asyncio as valkey_asyncio from valkey.cluster import ClusterNode, ValkeyCluster + from valkey.commands.search.field import TagField, VectorField + from valkey.commands.search.indexDefinition import IndexDefinition, IndexType + from valkey.commands.search.query import Query from valkey.sentinel import Sentinel except ImportError as e: from feast.errors import FeastExtrasDependencyImportError @@ -58,6 +71,91 @@ logger = logging.getLogger(__name__) +def _get_vector_index_name( + project: str, feature_view_name: str, feature_name: str +) -> str: + """Generate Valkey Search index name for a vector field.""" + return f"{project}_{feature_view_name}_{feature_name}_vidx" + + +def _get_valkey_vector_type(feast_dtype) -> str: + """ + Map Feast dtype to Valkey vector TYPE parameter. + + Valkey Search only supports FLOAT32 vectors. Float64 arrays will be + converted to float32 during serialization. + + Args: + feast_dtype: Feast data type (e.g., Array(Float32)) + + Returns: + Valkey vector type string: always "FLOAT32" + """ + if feast_dtype == Array(Float64): + logger.warning( + "Valkey Search only supports FLOAT32 vectors. " + "Float64 data will be converted to float32 (possible precision loss)." 
+ ) + return "FLOAT32" + + +def _serialize_vector_to_bytes(val: ValueProto, field: Field) -> bytes: + """ + Serialize a vector ValueProto to raw float32 bytes for Valkey storage. + + Vector fields must be stored as raw bytes (not protobuf serialized) to be + compatible with Valkey Search FT.SEARCH queries. Valkey only supports + FLOAT32, so float64 data is converted to float32. + + Args: + val: The ValueProto containing the vector data + field: The Field metadata for dtype and dimension information + + Returns: + Raw float32 bytes in the format expected by Valkey vector search + + Raises: + ValueError: If vector type is unsupported or dimension mismatches + """ + if val.HasField("float_list_val"): + vector = np.array(val.float_list_val.val, dtype=np.float32) + elif val.HasField("double_list_val"): + # Convert float64 to float32 (Valkey only supports float32) + vector = np.array(val.double_list_val.val, dtype=np.float32) + else: + raise ValueError( + f"Unsupported vector type for field {field.name}. " + f"Expected float_list_val or double_list_val." + ) + + # Validate dimension matches expected + if field.vector_length > 0 and len(vector) != field.vector_length: + raise ValueError( + f"Vector dimension mismatch for field {field.name}: " + f"expected {field.vector_length}, got {len(vector)}" + ) + + return vector.tobytes() + + +def _deserialize_vector_from_bytes(raw_bytes: bytes, field: Field) -> ValueProto: + """ + Deserialize raw vector bytes back to ValueProto. + + Valkey stores all vectors as float32, so we always deserialize as float32 + regardless of the original field dtype. 
def _deserialize_vector_from_bytes(raw_bytes: bytes, field: "Field") -> "ValueProto":
    """Deserialize raw float32 vector bytes back into a ValueProto.

    Valkey stores every vector as float32, so the payload is always read as
    float32 regardless of the field's declared dtype.

    Args:
        raw_bytes: Raw float32 bytes read from Valkey.
        field: Field metadata (unused; kept for API symmetry with
            ``_serialize_vector_to_bytes``).

    Returns:
        ValueProto carrying float_list_val (always float32).
    """
    floats = np.frombuffer(raw_bytes, dtype=np.float32).tolist()
    return ValueProto(float_list_val=FloatList(val=floats))
def _drop_vector_index_if_exists(
    self,
    client: "Union[Valkey, ValkeyCluster]",
    project: str,
    table: "FeatureView",
) -> None:
    """Drop the Valkey Search index of every vector field on ``table``.

    A missing index ("unknown index") is logged and ignored; any other
    FT.DROPINDEX failure propagates. Documents themselves are kept
    (``delete_documents=False``).
    """
    for field in (f for f in table.features if f.vector_index):
        index_name = _get_vector_index_name(project, table.name, field.name)
        try:
            client.ft(index_name).dropindex(delete_documents=False)
            logger.info(f"Dropped vector index {index_name}")
        except ResponseError as exc:
            if "unknown index" not in str(exc).lower():
                raise
            logger.debug(
                f"Vector index {index_name} does not exist, skipping drop"
            )
def _create_vector_index_if_not_exists(
    self,
    client: "Union[Valkey, ValkeyCluster]",
    config: "RepoConfig",
    table: "FeatureView",
    vector_fields: "Dict[str, Field]",
) -> None:
    """Create a Valkey Search index for each vector field that lacks one.

    Issues FT.CREATE with a VECTOR field (FLAT or HNSW per config) plus a
    ``__project__`` TAG field used for project filtering in hybrid queries.
    One index is created per vector field.

    Args:
        client: Valkey client.
        config: Feast repo configuration.
        table: Feature view owning the vector fields.
        vector_fields: Mapping of vector field name to Field object.

    Raises:
        ValueError: A field is vector-indexed but has no vector_length.
        ResponseError: FT.CREATE failed for a reason other than the index
            already existing (e.g., search module not loaded).
    """
    online_store_config = config.online_store
    assert isinstance(online_store_config, EGValkeyOnlineStoreConfig)

    # All per-field indexes share one HASH-prefix definition.
    key_prefix = _redis_key_prefix(table.join_keys)
    definition = IndexDefinition(prefix=[key_prefix], index_type=IndexType.HASH)

    for field_name, field in vector_fields.items():
        index_name = _get_vector_index_name(config.project, table.name, field_name)

        try:
            client.ft(index_name).info()
            logger.debug(f"Vector index {index_name} already exists")
            continue
        except ResponseError:
            pass  # not there yet — fall through and create it

        if field.vector_length <= 0:
            raise ValueError(
                f"Field {field_name} has vector_index=True but vector_length is not set. "
                f"vector_length must be > 0 for vector indexing."
            )

        attributes = {
            "TYPE": _get_valkey_vector_type(field.dtype),  # always FLOAT32
            "DIM": field.vector_length,
            "DISTANCE_METRIC": field.vector_search_metric or "COSINE",
        }

        algorithm = online_store_config.vector_index_algorithm
        if algorithm == "HNSW":
            attributes["M"] = online_store_config.vector_index_hnsw_m
            attributes["EF_CONSTRUCTION"] = (
                online_store_config.vector_index_hnsw_ef_construction
            )
            attributes["EF_RUNTIME"] = (
                online_store_config.vector_index_hnsw_ef_runtime
            )

        try:
            client.ft(index_name).create_index(
                fields=[
                    VectorField(field_name, algorithm, attributes),
                    TagField("__project__"),
                ],
                definition=definition,
            )
            logger.info(f"Created vector index {index_name} for field {field_name}")
        except ResponseError as exc:
            if "already exists" in str(exc).lower():
                # Lost a create race — treat as success.
                logger.debug(f"Vector index {index_name} already exists")
                continue
            logger.error(
                f"Failed to create vector index {index_name}: {exc}. "
                f"Ensure Valkey Search module is loaded."
            )
            raise
+ ) + raise + def online_write_batch( self, config: RepoConfig, @@ -307,6 +546,7 @@ def online_write_batch( feature_view = table.name ts_key = f"_ts:{feature_view}" keys = [] + # Track all ZSET keys touched in this batch for TTL cleanup & trimming zsets_to_cleanup: set[Tuple[bytes, bytes]] = ( set() @@ -448,6 +688,15 @@ def online_write_batch( ) raise else: + # Identify vector fields (only for regular FeatureViews, not SortedFeatureView) + vector_fields = {f.name: f for f in table.features if f.vector_index} + + # Create vector index if needed (only on first write with vector fields) + if vector_fields: + self._create_vector_index_if_not_exists( + client, config, table, vector_fields + ) + # check if a previous record under the key bin exists # TODO: investigate if check and set is a better approach rather than pulling all entity ts and then setting # it may be significantly slower but avoids potential (rare) race conditions @@ -463,9 +712,12 @@ def online_write_batch( # flattening the list of lists. 
`hmget` does the lookup assuming a list of keys in the key bin prev_event_timestamps = [i[0] for i in prev_event_timestamps] - for valkey_key_bin, prev_event_time, (_, values, timestamp, _) in zip( - keys, prev_event_timestamps, data - ): + for valkey_key_bin, prev_event_time, ( + entity_key, + values, + timestamp, + _, + ) in zip(keys, prev_event_timestamps, data): event_time_seconds = int(utils.make_tzaware(timestamp).timestamp()) # ignore if event_timestamp is before the event features that are currently in the feature store @@ -482,10 +734,24 @@ def online_write_batch( ts.seconds = event_time_seconds entity_hset = dict() entity_hset[ts_key] = ts.SerializeToString() + # Store project and entity key for vector search + entity_hset["__project__"] = project.encode() + entity_hset["__entity_key__"] = serialize_entity_key( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ) for feature_name, val in values.items(): - f_key = _mmh3(f"{feature_view}:{feature_name}") - entity_hset[f_key] = val.SerializeToString() + if feature_name in vector_fields: + # Vector field: store with ORIGINAL name and RAW bytes + vector_bytes = _serialize_vector_to_bytes( + val, vector_fields[feature_name] + ) + entity_hset[feature_name] = vector_bytes + else: + # Non-vector field: store with mmh3 hash and protobuf serialization + f_key = _mmh3(f"{feature_view}:{feature_name}") + entity_hset[f_key] = val.SerializeToString() pipe.hset(valkey_key_bin, mapping=entity_hset) @@ -580,28 +846,53 @@ def _generate_hset_keys_for_features( self, feature_view: FeatureView, requested_features: Optional[List[str]] = None, - ) -> Tuple[List[str], List[str]]: + ) -> Tuple[List[str], List[str], Dict[str, Field]]: + """ + Generate HSET keys for feature retrieval. 
+ + Returns: + Tuple of (feature_names, hset_keys, vector_fields dict) + """ if not requested_features: requested_features = [f.name for f in feature_view.features] - hset_keys = [_mmh3(f"{feature_view.name}:{k}") for k in requested_features] + vector_fields = {f.name: f for f in feature_view.features if f.vector_index} + + hset_keys = [] + for feature_name in requested_features: + if feature_name in vector_fields: + # Vector field: use original name + hset_keys.append(feature_name) + else: + # Non-vector: use mmh3 hash + hset_keys.append(_mmh3(f"{feature_view.name}:{feature_name}")) ts_key = f"_ts:{feature_view.name}" hset_keys.append(ts_key) - requested_features.append(ts_key) + requested_features = list(requested_features) + [ts_key] - return requested_features, hset_keys + return requested_features, hset_keys, vector_fields def _convert_valkey_values_to_protobuf( self, valkey_values: List[List[ByteString]], - feature_view: str, + feature_view: FeatureView, requested_features: List[str], + vector_fields: Dict[str, Field], ): + """ + Convert Valkey values back to protobuf, handling vector fields. 
def _convert_valkey_values_to_protobuf(
    self,
    valkey_values: "List[List[ByteString]]",
    feature_view: "FeatureView",
    requested_features: List[str],
    vector_fields: "Dict[str, Field]",
):
    """Convert raw Valkey HMGET rows into (timestamp, feature protos) tuples.

    Args:
        valkey_values: One list of raw byte values per entity row.
        feature_view: Feature view the rows belong to.
        requested_features: Feature names aligned with each row's values.
        vector_fields: Vector field name -> Field, for raw-bytes decoding.
    """
    return [
        self._get_features_for_entity(
            row_values, feature_view, requested_features, vector_fields
        )
        for row_values in valkey_values
    ]
requested_features: List[str], + vector_fields: Dict[str, Field], ) -> Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]: + """ + Parse features for a single entity, handling vector deserialization. + + Args: + values: Raw bytes from Valkey + feature_view: Feature view object + requested_features: List of feature names (includes _ts key) + vector_fields: Dict of field name to Field for vector fields (O(1) lookup) + """ res_val = dict(zip(requested_features, values)) res_ts = Timestamp() - ts_val = res_val.pop(f"_ts:{feature_view}") + ts_val = res_val.pop(f"_ts:{feature_view.name}") if ts_val: res_ts.ParseFromString(bytes(ts_val)) res = {} for feature_name, val_bin in res_val.items(): - val = ValueProto() - if val_bin: + if not val_bin: + res[feature_name] = ValueProto() + continue + + if feature_name in vector_fields: + # Vector field: deserialize from raw bytes + field = vector_fields[feature_name] + val = _deserialize_vector_from_bytes(bytes(val_bin), field) + else: + # Regular field: parse protobuf + val = ValueProto() val.ParseFromString(bytes(val_bin)) + res[feature_name] = val if not res: @@ -686,3 +997,420 @@ def _get_features_for_entity( else: timestamp = datetime.fromtimestamp(res_ts.seconds, tz=timezone.utc) return timestamp, res + + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embedding: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Retrieve documents using vector similarity search from Valkey. 
+ + Args: + config: Feast configuration object + table: FeatureView to search + requested_features: List of feature names to return + embedding: Query embedding vector + top_k: Number of results to return + distance_metric: Optional override for distance metric (COSINE, L2, IP) + query_string: Not supported in V1 (reserved for future BM25 search) + + Returns: + List of tuples containing (timestamp, entity_key, features_dict) + """ + if embedding is None: + raise ValueError("embedding must be provided for vector search") + + if query_string is not None: + raise NotImplementedError( + "Keyword search (query_string) is not yet supported for Valkey. " + "Only vector similarity search is available." + ) + + online_store_config = config.online_store + assert isinstance(online_store_config, EGValkeyOnlineStoreConfig) + + client = self._get_client(online_store_config) + project = config.project + + # Find the vector field to search against + vector_field = self._get_vector_field_for_search(table, requested_features) + if vector_field is None: + raise ValueError( + f"No vector field found in FeatureView {table.name}. " + "Ensure the FeatureView has a field with vector_index=True." 
+ ) + + # Determine distance metric + metric = distance_metric or vector_field.vector_search_metric or "COSINE" + + # Serialize query embedding to bytes + embedding_bytes = self._serialize_embedding_for_search(embedding, vector_field) + + # Build and execute FT.SEARCH query + index_name = _get_vector_index_name(project, table.name, vector_field.name) + search_results = self._execute_vector_search( + client=client, + index_name=index_name, + project=project, + vector_field_name=vector_field.name, + embedding_bytes=embedding_bytes, + top_k=top_k, + metric=metric, + ) + + if not search_results: + return [] + + # Fetch features for each result using pipeline HMGET + return self._fetch_features_for_search_results( + client=client, + config=config, + table=table, + requested_features=requested_features, + search_results=search_results, + ) + + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + V3 document retrieval on Valkey backend. + + Valkey supports a subset of V3 features: + - Single embedding only (multi-embedding raises ValueError) + - AUTO and VECTOR_ONLY fusion strategies (others raise ValueError) + - query_string is silently dropped with a warning (Valkey cannot + use text as a ranking signal) + + Returns the same tuple shape as V2 with final_score and signal_scores + added to the feature dict. final_score is the raw Valkey + ``__distance__`` from ``FT.SEARCH KNN`` — lower = better for all + supported metrics (COSINE returns ``1 - cos``, L2 returns squared + distance, IP returns ``1 - inner_product``). 
def retrieve_online_documents_v3(
    self,
    config: "RepoConfig",
    table: "FeatureView",
    requested_features: List[str],
    embeddings: Dict[str, List[float]],
    top_k: int,
    query_string: Optional[str] = None,
    fusion_strategy: str = "AUTO",
    signal_weights: Optional[Dict[str, float]] = None,
    rrf_k: int = 60,
    distance_metric: Optional[str] = None,
    include_signal_scores: bool = False,
) -> "List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[Dict[str, ValueProto]]]]":
    """V3 document retrieval on the Valkey backend.

    Valkey supports a subset of V3: exactly one embedding; AUTO or
    VECTOR_ONLY fusion (RRF / WEIGHTED_LINEAR raise); query_string is
    logged and dropped (Valkey cannot rank on text).

    The returned tuples match V2's shape, with "final_score" and
    "signal_scores" added to the feature dict. final_score is the raw
    Valkey ``__distance__`` from the KNN search — lower = better for all
    supported metrics — which is the opposite orientation of
    Elasticsearch's relevance-style final_score.

    ``include_signal_scores`` is accepted but currently a no-op;
    signal_scores is populated best-effort with the single vector signal.
    """
    del include_signal_scores  # reserved; no explain-based path yet

    strategy = fusion_strategy.upper()
    valid_strategies = {"AUTO", "RRF", "WEIGHTED_LINEAR", "VECTOR_ONLY"}
    if strategy not in valid_strategies:
        raise ValueError(
            f"Unknown fusion_strategy '{fusion_strategy}'. "
            f"Valid options: {sorted(valid_strategies)}"
        )

    if not embeddings:
        raise ValueError(
            "V3 requires at least one embedding. "
            "Pass embeddings={field_name: vector}."
        )

    if len(embeddings) > 1:
        raise ValueError(
            "Multi-vector fusion requires the Elasticsearch backend. "
            "Valkey supports single-vector search only. "
            "Use a single embedding or switch to the Elasticsearch online store."
        )

    if strategy in ("RRF", "WEIGHTED_LINEAR"):
        raise ValueError(
            f"Fusion strategy '{strategy}' is not supported on Valkey. "
            "Use fusion_strategy='AUTO' or 'VECTOR_ONLY', "
            "or switch to the Elasticsearch backend for fusion support."
        )

    if query_string is not None and strategy != "VECTOR_ONLY":
        logger.warning(
            "query_string is being dropped — Valkey backend does not support "
            "text search as a ranking signal. To use text as a ranking signal, "
            "switch to the Elasticsearch backend."
        )

    embed_key, embed_vector = next(iter(embeddings.items()))

    base_results = self.retrieve_online_documents_v2(
        config=config,
        table=table,
        requested_features=requested_features,
        embedding=embed_vector,
        top_k=top_k,
        distance_metric=distance_metric,
        query_string=None,  # Valkey does not support query_string
    )

    enriched = []
    for timestamp, entity_key_proto, feature_dict in base_results:
        if feature_dict is None:
            enriched.append((timestamp, entity_key_proto, None))
            continue

        # Re-badge V2's "distance" as V3's "final_score" and record it as
        # the single vector signal.
        distance_val = feature_dict.pop("distance", None)
        signal_scores: Dict[str, float] = {}
        if distance_val is not None and distance_val.HasField("double_val"):
            feature_dict["final_score"] = distance_val
            signal_scores[f"vec_{embed_key}"] = distance_val.double_val

        feature_dict["signal_scores"] = encode_signal_scores(signal_scores)
        enriched.append((timestamp, entity_key_proto, feature_dict))

    return enriched
+ ) + + embed_key, embed_vector = next(iter(embeddings.items())) + + v2_results = self.retrieve_online_documents_v2( + config=config, + table=table, + requested_features=requested_features, + embedding=embed_vector, + top_k=top_k, + distance_metric=distance_metric, + query_string=None, # Valkey does not support query_string + ) + + v3_results: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + for timestamp, entity_key_proto, feature_dict in v2_results: + if feature_dict is None: + v3_results.append((timestamp, entity_key_proto, None)) + continue + + distance_val = feature_dict.pop("distance", None) + if distance_val is not None and distance_val.HasField("double_val"): + feature_dict["final_score"] = distance_val + signal_scores = {f"vec_{embed_key}": distance_val.double_val} + else: + signal_scores = {} + + feature_dict["signal_scores"] = encode_signal_scores(signal_scores) + v3_results.append((timestamp, entity_key_proto, feature_dict)) + + return v3_results + + def _get_vector_field_for_search( + self, + table: FeatureView, + requested_features: Optional[List[str]], + ) -> Optional[Field]: + """Find the vector field to use for search.""" + vector_fields = [f for f in table.features if f.vector_index] + + if not vector_fields: + return None + + # If requested_features specified, prefer a vector field from that list + if requested_features: + # Convert to set for O(1) lookup instead of O(n) list search + requested_set = set(requested_features) + for f in vector_fields: + if f.name in requested_set: + return f + + # Default to first vector field + return vector_fields[0] + + def _serialize_embedding_for_search( + self, + embedding: List[float], + vector_field: Field, + ) -> bytes: + """Serialize query embedding to bytes matching the field's dtype.""" + # Validate embedding dimension matches field configuration + if len(embedding) != vector_field.vector_length: + raise ValueError( + f"Embedding 
dimension {len(embedding)} does not match " + f"vector field '{vector_field.name}' dimension {vector_field.vector_length}" + ) + + if vector_field.dtype == Array(Float64): + return np.array(embedding, dtype=np.float64).tobytes() + else: + # Default to float32 + return np.array(embedding, dtype=np.float32).tobytes() + + def _execute_vector_search( + self, + client: Union[Valkey, ValkeyCluster], + index_name: str, + project: str, + vector_field_name: str, + embedding_bytes: bytes, + top_k: int, + metric: str, + ) -> List[Tuple[bytes, float]]: + """ + Execute FT.SEARCH with KNN query. + + Returns: + List of (doc_key, distance) tuples + """ + # Escape special characters in project name for tag filter. + # In Valkey Search tag queries, characters like - . @ need backslash escaping. + escaped_project = project + for ch in r'\-.@+~<>{}[]^":|!*()': + escaped_project = escaped_project.replace(ch, f"\\{ch}") + + query_str = ( + f"(@__project__:{{{escaped_project}}})" + f"=>[KNN {top_k} @{vector_field_name} $vec AS __distance__]" + ) + + # KNN results are already sorted by distance (ascending) by the engine. + # No explicit SORTBY is needed — Valkey Search does not support SORTBY + # with KNN queries. + query = ( + Query(query_str).return_fields("__distance__").paging(0, top_k).dialect(2) + ) + + try: + results = client.ft(index_name).search( + query, + query_params={"vec": embedding_bytes}, + ) + except ResponseError as e: + if "no such index" in str(e).lower(): + raise ValueError( + f"Vector index '{index_name}' does not exist. " + "Ensure data has been materialized with 'feast materialize'." 
+ ) + raise + + # Parse results: extract doc keys and distances + search_results = [] + for doc in results.docs: + doc_key = doc.id.encode() if isinstance(doc.id, str) else doc.id + # Default to inf (worst distance) if __distance__ is missing + # 0.0 would incorrectly indicate a perfect match + distance = float(getattr(doc, "__distance__", float("inf"))) + search_results.append((doc_key, distance)) + + return search_results + + def _fetch_features_for_search_results( + self, + client: Union[Valkey, ValkeyCluster], + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + search_results: List[Tuple[bytes, float]], + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Fetch features for search results using pipeline HMGET. + + This is the second step of two-step retrieval: + 1. FT.SEARCH returns doc keys and distances + 2. HMGET fetches the actual feature values + """ + # Pre-compute mappings once (avoid repeated dict/hash operations in loops) + vector_fields_dict = {f.name: f for f in table.features if f.vector_index} + + # Build feature_name -> hset_key mapping and hset_keys list in single pass + feature_to_hset_key: Dict[str, Any] = {} + hset_keys = [] + for feature_name in requested_features: + if feature_name in vector_fields_dict: + hset_key = feature_name + else: + hset_key = _mmh3(f"{table.name}:{feature_name}") + feature_to_hset_key[feature_name] = hset_key + hset_keys.append(hset_key) + + # Add timestamp and entity key + ts_key = f"_ts:{table.name}" + hset_keys.append(ts_key) + hset_keys.append("__entity_key__") + + # Extract doc_keys and distances in single pass + doc_keys = [] + distances = {} + for doc_key, dist in search_results: + doc_keys.append(doc_key) + distances[doc_key] = dist + + # Pipeline HMGET for all results (single round-trip to Valkey) + with client.pipeline(transaction=False) as pipe: + for doc_key in doc_keys: + key_str = doc_key.decode() if 
isinstance(doc_key, bytes) else doc_key + pipe.hmget(key_str, hset_keys) + fetched_values = pipe.execute() + + # Pre-fetch serialization version once + entity_key_serialization_version = config.entity_key_serialization_version + + # Build result list + results: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + + for doc_key, values in zip(doc_keys, fetched_values): + # Parse values into dict + val_dict = dict(zip(hset_keys, values)) + + # Parse timestamp + timestamp = None + ts_val = val_dict.get(ts_key) + if ts_val: + ts_proto = Timestamp() + ts_proto.ParseFromString(bytes(ts_val)) + timestamp = datetime.fromtimestamp(ts_proto.seconds, tz=timezone.utc) + + # Parse entity key + entity_key_proto = None + entity_key_bytes = val_dict.get("__entity_key__") + if entity_key_bytes: + entity_key_proto = deserialize_entity_key( + bytes(entity_key_bytes), + entity_key_serialization_version=entity_key_serialization_version, + ) + + # Build feature dict with pre-allocated capacity hint + feature_dict: Dict[str, ValueProto] = {} + + # Add distance as a feature + distance_proto = ValueProto() + distance_proto.double_val = distances[doc_key] + feature_dict["distance"] = distance_proto + + # Parse requested features using pre-computed mappings + for feature_name in requested_features: + hset_key = feature_to_hset_key[feature_name] + val_bin = val_dict.get(hset_key) + + if not val_bin: + feature_dict[feature_name] = ValueProto() + continue + + if feature_name in vector_fields_dict: + # Vector field: deserialize from raw bytes + feature_dict[feature_name] = _deserialize_vector_from_bytes( + bytes(val_bin), vector_fields_dict[feature_name] + ) + else: + # Regular field: parse protobuf + val = ValueProto() + val.ParseFromString(bytes(val_bin)) + feature_dict[feature_name] = val + + results.append((timestamp, entity_key_proto, feature_dict)) + + return results diff --git 
a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py index 7e8e533281d..0895a144625 100644 --- a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py @@ -3,11 +3,13 @@ import base64 import json import logging +import math from collections import defaultdict from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from elasticsearch import Elasticsearch, helpers +from pydantic import model_validator from feast import Entity, FeatureView, RepoConfig from feast.infra.key_encoding_utils import ( @@ -15,6 +17,7 @@ get_list_val_str, serialize_entity_key, ) +from feast.infra.online_stores._signal_scores import encode_signal_scores from feast.infra.online_stores.online_store import OnlineStore from feast.infra.online_stores.vector_store import VectorStoreConfig from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto @@ -26,6 +29,8 @@ to_naive_utc, ) +logger = logging.getLogger(__name__) + class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): """ @@ -46,6 +51,110 @@ class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): # The number of rows to write in a single batch write_batch_size: Optional[int] = 40 + # Quantization / index_options configuration + vector_index_type: Optional[str] = None + # One of: "hnsw", "int8_hnsw", "int4_hnsw", "bbq_hnsw", + # "flat", "int8_flat", "int4_flat", "bbq_flat" + # None = use ES default (hnsw for <8.x, int8_hnsw for 9.0+) + + # HNSW tuning parameters (only apply to HNSW index types) + hnsw_m: Optional[int] = None # Neighbor connections (ES default: 16) + hnsw_ef_construction: Optional[int] = ( + None # Build-time candidates 
(ES default: 100) + ) + + # Rescore configuration for quantized indices only (int4/int8/bbq) + rescore_oversample: Optional[float] = ( + None # Must be (1.0, 10.0) exclusive; None to disable + ) + + # Query method toggle + use_native_knn: bool = False # False = script_score (backward compatible) + # True = native knn query (faster, approximate) + + # KNN query tuning + knn_num_candidates_multiplier: Optional[float] = ( + None # Default: 2.0; num_candidates = top_k * multiplier (must be >= 1.0) + ) + + @model_validator(mode="after") + def validate_quantization_config(self): + """Validate quantization configuration constraints.""" + # Validate vector_index_type is a known value + valid_index_types = { + "hnsw", + "int8_hnsw", + "int4_hnsw", + "bbq_hnsw", + "flat", + "int8_flat", + "int4_flat", + "bbq_flat", + } + if ( + self.vector_index_type is not None + and self.vector_index_type not in valid_index_types + ): + raise ValueError( + f"vector_index_type must be one of {valid_index_types}, got {self.vector_index_type}" + ) + + # Validate rescore_oversample range and constraints + # ES requires: (1.0, 10.0) exclusive, per https://www.elastic.co/docs/reference/elasticsearch/mapping-reference/dense-vector + if self.rescore_oversample is not None: + if self.rescore_oversample <= 1.0 or self.rescore_oversample >= 10.0: + raise ValueError( + f"rescore_oversample must be in the range (1.0, 10.0) exclusive, " + f"got {self.rescore_oversample}" + ) + + # Validate rescore_oversample only applies to quantized indices + quantized_types = { + "int8_hnsw", + "int4_hnsw", + "bbq_hnsw", + "int8_flat", + "int4_flat", + "bbq_flat", + } + if ( + self.vector_index_type is None + or self.vector_index_type not in quantized_types + ): + raise ValueError( + f"rescore_oversample can only be used with quantized index types {quantized_types}, " + f"got vector_index_type={self.vector_index_type}" + ) + + # Validate HNSW parameters only apply to HNSW index types + hnsw_types = {"hnsw", 
"int8_hnsw", "int4_hnsw", "bbq_hnsw"} + if (self.hnsw_m is not None or self.hnsw_ef_construction is not None) and ( + self.vector_index_type is not None + and self.vector_index_type not in hnsw_types + ): + raise ValueError( + f"hnsw_m and hnsw_ef_construction only apply to HNSW index types {hnsw_types}, " + f"got vector_index_type='{self.vector_index_type}'" + ) + + # Validate HNSW parameter ranges (basic sanity only; ES enforces its own limits) + if self.hnsw_m is not None and self.hnsw_m < 1: + raise ValueError(f"hnsw_m must be >= 1, got {self.hnsw_m}") + + if self.hnsw_ef_construction is not None and self.hnsw_ef_construction < 1: + raise ValueError( + f"hnsw_ef_construction must be >= 1, got {self.hnsw_ef_construction}" + ) + + # Validate knn_num_candidates_multiplier range (must be >= 1.0) + if self.knn_num_candidates_multiplier is not None: + if self.knn_num_candidates_multiplier < 1.0: + raise ValueError( + f"knn_num_candidates_multiplier must be >= 1.0, got {self.knn_num_candidates_multiplier}" + ) + + return self + class ElasticSearchOnlineStore(OnlineStore): _client: Optional[Elasticsearch] = None @@ -93,6 +202,8 @@ def online_write_batch( ], progress: Optional[Callable[[int], Any]], ) -> None: + vector_field = _get_feature_view_vector_field_metadata(table) + vector_field_name = vector_field.name if vector_field else None insert_values = [] grouped_docs: dict[str, dict[str, Any]] = defaultdict( lambda: { @@ -115,27 +226,31 @@ def online_write_batch( doc_key = f"{encoded_entity_key}_{timestamp}" for feature_name, value in values.items(): - doc = _encode_feature_value(value) + doc = _encode_feature_value( + value, is_vector=(feature_name == vector_field_name) + ) grouped_docs[doc_key]["features"][feature_name] = doc grouped_docs[doc_key]["timestamp"] = timestamp grouped_docs[doc_key]["created_ts"] = created_ts grouped_docs[doc_key]["entity_key"] = encoded_entity_key - insert_values = [ - { - "entity_key": document["entity_key"], - "timestamp": 
document["timestamp"], - "created_ts": document["created_ts"], - **(document["features"] or {}), - } - for document in grouped_docs.values() - ] + insert_values = [ + { + "entity_key": document["entity_key"], + "timestamp": document["timestamp"], + "created_ts": document["created_ts"], + **(document["features"] or {}), + } + for document in grouped_docs.values() + ] batch_size = config.online_store.write_batch_size for i in range(0, len(insert_values), batch_size): batch = insert_values[i : i + batch_size] actions = self._bulk_batch_actions(table, batch) helpers.bulk(self._get_client(config), actions, refresh="wait_for") + if progress: + progress(len(batch)) def online_read( self, @@ -159,6 +274,7 @@ def online_read( includes.append("*") body = { + "size": len(encoded_entity_keys), "_source": {"includes": includes, "excludes": ["*.vector_value"]}, "query": { "bool": {"filter": [{"terms": {"entity_key": encoded_entity_keys}}]} @@ -167,12 +283,16 @@ def online_read( response = self._get_client(config).search(index=table.name, body=body) - results = [] + # Build a lookup dict keyed by entity_key to preserve input order + entity_key_to_result: Dict[ + str, Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]] + ] = {} for hit in response["hits"]["hits"]: source = hit["_source"] + entity_key_val = source.get("entity_key") timestamp = source.get("timestamp") - timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S") + timestamp = datetime.fromisoformat(timestamp) features: Dict[str, ValueProto] = {} @@ -193,7 +313,15 @@ def online_read( f"Failed to parse feature '{feature_name}' from hit: {e}" ) - results.append((timestamp, features if features else None)) + entity_key_to_result[entity_key_val] = ( + timestamp, + features if features else None, + ) + + # Return results in the same order as input entity_keys + results: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] + for encoded_key in encoded_entity_keys: + 
results.append(entity_key_to_result.get(encoded_key, (None, None))) return results @@ -205,10 +333,19 @@ def create_index(self, config: RepoConfig, table: FeatureView): config: Feast repo configuration object. table: FeatureView table for which the index needs to be created. """ - vector_field_length = getattr( - _get_feature_view_vector_field_metadata(table), "vector_length", 512 + vector_field_length = ( + getattr(_get_feature_view_vector_field_metadata(table), "vector_length", 0) + or 512 ) + # Validate vector_field_length is positive + if vector_field_length <= 0: + raise ValueError( + f"vector_field_length must be > 0, got {vector_field_length} for table '{table.name}'" + ) + + vector_mapping = _build_vector_mapping(config, vector_field_length, table.name) + index_mapping = { "dynamic_templates": [ { @@ -220,12 +357,7 @@ def create_index(self, config: RepoConfig, table: FeatureView): "properties": { "feature_value": {"type": "binary"}, "value_text": {"type": "text"}, - "vector_value": { - "type": "dense_vector", - "dims": vector_field_length, - "index": True, - "similarity": config.online_store.similarity, - }, + "vector_value": vector_mapping, }, }, } @@ -238,10 +370,11 @@ def create_index(self, config: RepoConfig, table: FeatureView): }, } - self._get_client(config).indices.create( - index=table.name, - mappings=index_mapping, - ) + client = self._get_client(config) + if not client.indices.exists(index=table.name): + client.indices.create(index=table.name, mappings=index_mapping) + else: + logger.info(f"Index '{table.name}' already exists; skipping creation. 
") def update( self, @@ -252,11 +385,28 @@ def update( entities_to_keep: Sequence[Entity], partial: bool, ): - # implement the update method + client = self._get_client(config) + + # Cache existing indices to reduce API calls + all_table_names = [t.name for t in tables_to_delete] + [ + t.name for t in tables_to_keep + ] + existing_indices: Set[str] = set() + for table_name in all_table_names: + if client.indices.exists(index=table_name): + existing_indices.add(table_name) + + # Delete data from indices that should be removed for table in tables_to_delete: - self._get_client(config).delete_by_query(index=table.name) + if table.name in existing_indices: + client.delete_by_query( + index=table.name, body={"query": {"match_all": {}}} + ) + + # Create indices for tables that should be kept for table in tables_to_keep: - self.create_index(config, table) + if table.name not in existing_indices: + self.create_index(config, table) def teardown( self, @@ -265,9 +415,17 @@ def teardown( entities: Sequence[Entity], ): project = config.project + client = self._get_client(config) try: + # Cache existing indices to reduce API calls + existing_indices: Set[str] = set() for table in tables: - self._get_client(config).indices.delete(index=table.name) + if client.indices.exists(index=table.name): + existing_indices.add(table.name) + + # Delete all existing indices + for table_name in existing_indices: + client.indices.delete(index=table_name) except Exception as e: logging.exception(f"Error deleting index in project {project}: {e}") raise @@ -299,21 +457,48 @@ def retrieve_online_documents( Optional[ValueProto], ] ] = [] + vector_field = _get_feature_view_vector_field_metadata(table) vector_field_path = ( - config.online_store.vector_field_path or "embedding.vector_value" + f"{vector_field.name}.vector_value" + if vector_field + else config.online_store.vector_field_path or "embedding.vector_value" ) - query = { - "script_score": { - "query": { - "bool": {"filter": [{"exists": 
{"field": vector_field_path}}]} - }, - "script": { - "source": f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0", - "params": {"query_vector": embedding}, - }, + + # Build query based on use_native_knn config + body: Dict[str, Any] = {"size": top_k, "_source": True} + + if config.online_store.use_native_knn: + # Native knn query (fast, approximate) + # Uses the similarity metric configured in the index mapping + multiplier = config.online_store.knn_num_candidates_multiplier or 2.0 + num_candidates: int = max(top_k, math.ceil(top_k * multiplier)) + + knn_query: Dict[str, Any] = { + "field": vector_field_path, + "query_vector": embedding, + "k": top_k, + "num_candidates": num_candidates, + } + + if config.online_store.rescore_oversample is not None: + knn_query["rescore_vector"] = { + "oversample": config.online_store.rescore_oversample + } + + body["knn"] = knn_query + else: + # Legacy script_score query (slow, exact, backward compatible) + body["query"] = { + "script_score": { + "query": { + "bool": {"filter": [{"exists": {"field": vector_field_path}}]} + }, + "script": { + "source": f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0", + "params": {"query_vector": embedding}, + }, + } } - } - body = {"size": top_k, "_source": True, "query": query} response = self._get_client(config).search(index=table.name, body=body) rows = response["hits"]["hits"][0:top_k] for row in rows: @@ -322,7 +507,7 @@ def retrieve_online_documents( distance = row["_score"] timestamp_str = source.get("timestamp") - timestamp = datetime.strptime(timestamp_str, "%Y-%m-%dT%H:%M:%S.%f") + timestamp = datetime.fromisoformat(timestamp_str) for feature_name in requested_features: feature_data = source.get(feature_name, {}) @@ -366,12 +551,12 @@ def retrieve_online_documents_v2( Optional[Dict[str, ValueProto]], ] ] = [] - if not config.online_store.vector_enabled: - raise ValueError("Vector search is not enabled in the online store config") - if embedding is 
None and query_string is None: raise ValueError("Either embedding or query_string must be provided") + if embedding is not None and not config.online_store.vector_enabled: + raise ValueError("Vector search is not enabled in the online store config") + es_index = table.name body: Dict[str, Any] = { "size": top_k, @@ -384,49 +569,115 @@ def retrieve_online_documents_v2( body["_source"] = source_fields if embedding: - similarity = (distance_metric or config.online_store.similarity).lower() + vector_field = _get_feature_view_vector_field_metadata(table) vector_field_path = ( - config.online_store.vector_field_path or "embedding.vector_value" + f"{vector_field.name}.vector_value" + if vector_field + else config.online_store.vector_field_path or "embedding.vector_value" ) - if similarity == "cosine": - script = f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0" - elif similarity == "dot_product": - script = f"dotProduct(params.query_vector, '{vector_field_path}')" - elif similarity in ("l2", "l2_norm", "euclidean"): - script = f"1 / (1 + l2norm(params.query_vector, '{vector_field_path}'))" - else: - raise ValueError( - f"Unsupported similarity/distance_metric: {similarity}" + similarity = ( + distance_metric + or ( + vector_field.vector_search_metric + if vector_field and vector_field.vector_search_metric + else None ) + or config.online_store.similarity + ).lower() + + # Determine query method: native knn or script_score + use_native_knn = config.online_store.use_native_knn + + if use_native_knn: + # Native knn query (fast, approximate) + # Uses the similarity metric configured in the index mapping + # Validate that the requested similarity is supported + if similarity not in ( + "cosine", + "dot_product", + "l2", + "l2_norm", + "euclidean", + ): + raise ValueError( + f"Unsupported similarity for native knn: {similarity}" + ) + + # Calculate num_candidates for approximate nearest neighbor search + multiplier = 
config.online_store.knn_num_candidates_multiplier or 2.0 + num_candidates: int = max(top_k, math.ceil(top_k * multiplier)) + + knn_clause: Dict[str, Any] = { + "field": vector_field_path, + "query_vector": embedding, + "k": top_k, + "num_candidates": num_candidates, + } + + if config.online_store.rescore_oversample is not None: + knn_clause["rescore_vector"] = { + "oversample": config.online_store.rescore_oversample + } + else: + # Legacy script_score query (slow, exact, backward compatible) + if similarity == "cosine": + script = f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0" + elif similarity == "dot_product": + script = f"dotProduct(params.query_vector, '{vector_field_path}')" + elif similarity in ("l2", "l2_norm", "euclidean"): + script = ( + f"1 / (1 + l2norm(params.query_vector, '{vector_field_path}'))" + ) + else: + raise ValueError( + f"Unsupported similarity/distance_metric: {similarity}" + ) - # Hybrid search + # Build query based on search type and query method + # Hybrid search (embedding + keyword) if embedding and query_string: - body["query"] = { - "script_score": { - "query": { - "bool": { - "must": [ - {"query_string": {"query": f'"{query_string}"'}}, - {"exists": {"field": vector_field_path}}, - ] - } - }, - "script": { - "source": script, - "params": {"query_vector": embedding}, - }, + if use_native_knn: + # Native knn with query filter + body["knn"] = knn_clause + body["query"] = {"query_string": {"query": f'"{query_string}"'}} + else: + # Legacy script_score with keyword filter + body["query"] = { + "script_score": { + "query": { + "bool": { + "must": [ + {"query_string": {"query": f'"{query_string}"'}}, + {"exists": {"field": vector_field_path}}, + ] + } + }, + "script": { + "source": script, + "params": {"query_vector": embedding}, + }, + } } - } # Vector search only elif embedding: - body["query"] = { - "script_score": { - "query": { - "bool": {"filter": [{"exists": {"field": vector_field_path}}]} - }, - "script": 
{"source": script, "params": {"query_vector": embedding}}, + if use_native_knn: + # Native knn query + body["knn"] = knn_clause + else: + # Legacy script_score + body["query"] = { + "script_score": { + "query": { + "bool": { + "filter": [{"exists": {"field": vector_field_path}}] + } + }, + "script": { + "source": script, + "params": {"query_vector": embedding}, + }, + } } - } # Keyword search only elif query_string: body["query"] = {"query_string": {"query": f'"{query_string}"'}} @@ -441,7 +692,7 @@ def retrieve_online_documents_v2( entity_key_serialization_version=config.entity_key_serialization_version, ) timestamp = row["_source"]["timestamp"] - timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timestamp = datetime.fromisoformat(timestamp) # Create feature dict with all requested features feature_dict = {"distance": _to_value_proto(float(row["_score"]))} @@ -460,6 +711,269 @@ def retrieve_online_documents_v2( result.append((timestamp, entity_key_proto, feature_dict)) return result + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + V3 document retrieval on Elasticsearch backend. + + Uses the ES retriever API (ES 8.14+) for all query types: single-signal + kNN, multi-signal RRF, and weighted linear fusion. + + Reserved output fields (always present in each result's feature_dict): + - ``final_score``: ES _score (higher = better). For single-signal this + is the raw kNN score; for fusion it is the rank-based composite score. 
+ - ``signal_scores``: JSON-encoded Dict[str, float] with per-signal + scores when available, empty dict for fused results (ES does not + expose per-retriever scores after fusion). + + Reserved parameters (accepted but currently unused): + - ``distance_metric``: V3-ES always uses the metric configured in the + index mapping; this param is reserved for future per-query override. + - ``include_signal_scores``: No-op today. ``signal_scores`` follows + best-effort behavior — populated for single-signal queries, empty + for RRF/WEIGHTED_LINEAR fusion (ES does not expose per-retriever + scores after fusion). Reserved for a future ES-explain path that + will populate the breakdown for fusion strategies at extra latency + cost. + """ + del distance_metric + del include_signal_scores + + valid_strategies = {"AUTO", "RRF", "WEIGHTED_LINEAR", "VECTOR_ONLY"} + effective_strategy = fusion_strategy.upper() + if effective_strategy not in valid_strategies: + raise ValueError( + f"Unknown fusion_strategy '{fusion_strategy}'. " + f"Valid options: {sorted(valid_strategies)}" + ) + + if not embeddings: + raise ValueError( + "V3 requires at least one embedding. " + "Pass embeddings={field_name: vector}." + ) + + if not config.online_store.vector_enabled: + raise ValueError("Vector search is not enabled in the online store config.") + + if effective_strategy == "VECTOR_ONLY": + query_string = None + + # Normalize empty/whitespace query_string to None + if query_string is not None and not query_string.strip(): + query_string = None + + # Validate embedding keys against FeatureView schema + vector_fields = {f.name: f for f in table.features if f.vector_index} + for key in embeddings: + if key not in vector_fields: + available = sorted(vector_fields.keys()) + if not available: + raise ValueError( + f"FeatureView '{table.name}' has no vector-indexed fields. " + f"Cannot perform vector search." 
+ ) + raise ValueError( + f"Embedding key '{key}' does not match any vector-indexed field " + f"in FeatureView '{table.name}'. " + f"Available vector fields: {available}" + ) + + # Build retrievers: one kNN per embedding, optional BM25 for query_string + retrievers_with_names: List[Tuple[str, Dict[str, Any]]] = [] + for field_name, vec in embeddings.items(): + knn_retriever: Dict[str, Any] = { + "knn": { + "field": f"{field_name}.vector_value", + "query_vector": vec, + } + } + retrievers_with_names.append((field_name, knn_retriever)) + + has_text_signal = query_string is not None + if has_text_signal: + text_retriever: Dict[str, Any] = { + "standard": {"query": {"query_string": {"query": query_string}}} + } + retrievers_with_names.append(("bm25", text_retriever)) + + is_single_signal = len(retrievers_with_names) == 1 + + if is_single_signal and effective_strategy in ("RRF", "WEIGHTED_LINEAR"): + logger.warning( + "Only one signal present — fusion_strategy '%s' has no effect. " + "The query will execute as a single-signal retrieval.", + effective_strategy, + ) + + # Set inner k based on signal count + multiplier = ( + getattr(config.online_store, "knn_num_candidates_multiplier", 2.0) or 2.0 + ) + if is_single_signal: + inner_k = top_k + else: + inner_k = min(max(top_k * 10, 100), 1000) + num_candidates = max(inner_k, math.ceil(inner_k * multiplier)) + + rescore_oversample = config.online_store.rescore_oversample + for _, retriever in retrievers_with_names: + if "knn" in retriever: + retriever["knn"]["k"] = inner_k + retriever["knn"]["num_candidates"] = num_candidates + if rescore_oversample is not None: + retriever["knn"]["rescore_vector"] = { + "oversample": rescore_oversample + } + + # Resolve execution mode + if is_single_signal: + execution_mode = "single" + elif effective_strategy == "WEIGHTED_LINEAR": + execution_mode = "linear" + else: + execution_mode = "rrf" + + # Validate WEIGHTED_LINEAR signal coverage + if execution_mode == "linear": + expected_signals = 
{name for name, _ in retrievers_with_names} + provided = set(signal_weights.keys()) if signal_weights else set() + missing = expected_signals - provided + if missing: + raise ValueError( + f"WEIGHTED_LINEAR fusion missing weights for signals: " + f"{sorted(missing)}. Provide a weight for each signal: " + f"embedding field names and/or 'bm25'." + ) + + # Compose query body + retrievers = [r for _, r in retrievers_with_names] + composite_key_name = _get_composite_key_name(table) + source_fields = requested_features.copy() + source_fields += ["entity_key", "timestamp"] + source_fields += composite_key_name + + if execution_mode == "single": + top_retriever = retrievers[0] + elif execution_mode == "rrf": + top_retriever = {"rrf": {"retrievers": retrievers, "rank_constant": rrf_k}} + else: + assert signal_weights is not None + weighted = [] + for signal_name, retriever in retrievers_with_names: + weight = signal_weights[signal_name] + weighted.append({"retriever": retriever, "weight": weight}) + top_retriever = {"linear": {"retrievers": weighted}} + + body: Dict[str, Any] = { + "retriever": top_retriever, + "size": top_k, + "_source": source_fields, + } + + response = self._get_client(config).search(index=table.name, body=body) + + # Parse results + result: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + + rows = response["hits"]["hits"][:top_k] + for row in rows: + entity_key = row["_source"]["entity_key"] + entity_key_proto = deserialize_entity_key( + base64.b64decode(entity_key), + entity_key_serialization_version=config.entity_key_serialization_version, + ) + timestamp = datetime.fromisoformat(row["_source"]["timestamp"]) + + feature_dict: Dict[str, ValueProto] = {} + feature_dict["final_score"] = _to_value_proto(float(row["_score"])) + + signal_scores: Dict[str, float] = {} + if is_single_signal: + embed_key = next(iter(embeddings.keys())) + signal_scores[f"vec_{embed_key}"] = float(row["_score"]) + 
+ feature_dict["signal_scores"] = encode_signal_scores(signal_scores) + + join_key_values = _extract_join_keys(entity_key_proto) + feature_dict.update(join_key_values) + + for feature in requested_features: + if feature in ("final_score", "signal_scores"): + continue + value = row["_source"].get(feature, None) + if value is not None: + feature_dict[feature] = _to_value_proto(value) + + result.append((timestamp, entity_key_proto, feature_dict)) + + return result + + +def _build_vector_mapping( + config: RepoConfig, vector_field_length: int, table_name: str +) -> Dict[str, Any]: + """ + Build the dense_vector mapping for an Elasticsearch index, including + quantization index_options when configured. + """ + # Validate dimension-based quantization constraints + if config.online_store.vector_index_type: + index_type = config.online_store.vector_index_type + if "int4" in index_type and vector_field_length % 2 != 0: + raise ValueError( + f"int4 quantization ('{index_type}') requires even number of dimensions, " + f"got {vector_field_length} for table '{table_name}'. " + f"See https://www.elastic.co/docs/reference/elasticsearch/mapping-reference/dense-vector" + ) + if "bbq" in index_type and vector_field_length < 64: + raise ValueError( + f"bbq quantization ('{index_type}') requires >= 64 dimensions, " + f"got {vector_field_length} for table '{table_name}'. 
" + f"See https://www.elastic.co/docs/reference/elasticsearch/mapping-reference/dense-vector" + ) + + vector_mapping: Dict[str, Any] = { + "type": "dense_vector", + "dims": vector_field_length, + "index": True, + "similarity": config.online_store.similarity, + } + + if config.online_store.vector_index_type: + index_options: Dict[str, Any] = {"type": config.online_store.vector_index_type} + if config.online_store.hnsw_m is not None: + index_options["m"] = config.online_store.hnsw_m + if config.online_store.hnsw_ef_construction is not None: + index_options["ef_construction"] = config.online_store.hnsw_ef_construction + vector_mapping["index_options"] = index_options + + return vector_mapping + def _to_value_proto(value: Any) -> ValueProto: """ @@ -468,37 +982,55 @@ def _to_value_proto(value: Any) -> ValueProto: val_proto = ValueProto() if isinstance(value, ValueProto): return value - if isinstance(value, float): + # Check bool before int/float since bool is a subclass of int in Python + if isinstance(value, bool): + val_proto.bool_val = value + elif isinstance(value, float): val_proto.float_val = value - elif isinstance(value, str): - val_proto.string_val = value elif isinstance(value, int): val_proto.int64_val = value - elif isinstance(value, bool): - val_proto.bool_val = value - elif isinstance(value, list) and all(isinstance(v, float) for v in value): - val_proto.float_list_val.val.extend(value) - elif isinstance(value, dict) and "feature_value" in value: - try: - raw_bytes = base64.b64decode(value["feature_value"]) - val_proto.ParseFromString(raw_bytes) - except Exception as e: - raise ValueError(f"Failed to decode feature_value from dict: {e}") + elif isinstance(value, str): + val_proto.string_val = value + elif isinstance(value, list): + if not value: + val_proto.float_list_val.val.extend(value) + elif all(isinstance(v, float) for v in value): + val_proto.float_list_val.val.extend(value) + elif all(isinstance(v, int) for v in value): + 
def _to_value_proto(value: Any) -> ValueProto:
    """
    Convert a plain Python value into a Feast ``ValueProto``.

    ``bool`` is checked before ``int``/``float`` because ``bool`` is a
    subclass of ``int``. Empty lists are stored as an empty float list.
    Dicts must carry a base64 ``feature_value`` payload.

    Raises:
        ValueError: For unsupported or mixed-type inputs, or when a dict
            payload cannot be decoded.
    """
    if isinstance(value, ValueProto):
        return value

    proto = ValueProto()
    if isinstance(value, bool):
        proto.bool_val = value
    elif isinstance(value, float):
        proto.float_val = value
    elif isinstance(value, int):
        proto.int64_val = value
    elif isinstance(value, str):
        proto.string_val = value
    elif isinstance(value, list):
        if not value or all(isinstance(v, float) for v in value):
            proto.float_list_val.val.extend(value)
        elif all(isinstance(v, int) for v in value):
            proto.int64_list_val.val.extend(value)
        else:
            raise ValueError(f"List contains mixed or unsupported types: {value}")
    elif isinstance(value, dict):
        # Guard clause first: only encoded-feature dicts are accepted.
        if "feature_value" not in value:
            raise ValueError(f"Dict missing 'feature_value' key: {value}")
        try:
            proto.ParseFromString(base64.b64decode(value["feature_value"]))
        except Exception as e:
            raise ValueError(f"Failed to decode feature_value from dict: {e}")
    else:
        raise ValueError(
            f"Unsupported type for ValueProto: {type(value).__name__} (value: {value})"
        )
    return proto
def retrieve_online_documents_v3(
    self,
    config: RepoConfig,
    table: FeatureView,
    requested_features: List[str],
    embeddings: Dict[str, List[float]],
    top_k: int,
    query_string: Optional[str] = None,
    fusion_strategy: str = "AUTO",
    signal_weights: Optional[Dict[str, float]] = None,
    rrf_k: int = 60,
    distance_metric: Optional[str] = None,
    include_signal_scores: bool = False,
) -> List[
    Tuple[
        Optional[datetime],
        Optional[EntityKeyProto],
        Optional[Dict[str, ValueProto]],
    ]
]:
    """
    Retrieve documents via multi-vector (V3) search with configurable fusion.

    Base-class default: V3 retrieval is optional for online stores, so this
    implementation unconditionally raises; concrete stores with multi-vector
    support override it.

    Args:
        config: The repo configuration for the current feature store.
        table: The feature view to search against.
        requested_features: Feature names to return for each hit.
        embeddings: Map of vector field name to query vector.
        top_k: Number of results to return.
        query_string: Optional text query (store-specific semantics).
        fusion_strategy: Fusion mode, e.g. "AUTO", "RRF", "WEIGHTED_LINEAR".
        signal_weights: Per-signal weights for weighted fusion.
        rrf_k: RRF rank constant (default 60).
        distance_metric: Optional override of the distance metric.
        include_signal_scores: Whether to request per-signal score breakdowns.

    Returns:
        List of (event timestamp, entity key, feature dict) tuples.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError(
        f"Online store {self.__class__.__name__} does not support "
        f"V3 document retrieval"
    )
def retrieve_online_documents_v3(
    self,
    config: "RepoConfig",
    table: "FeatureView",
    requested_features: "Optional[List[str]]",
    embeddings: "Dict[str, List[float]]",
    top_k: int,
    query_string: "Optional[str]" = None,
    fusion_strategy: str = "AUTO",
    signal_weights: "Optional[Dict[str, float]]" = None,
    rrf_k: int = 60,
    distance_metric: "Optional[str]" = None,
    include_signal_scores: bool = False,
) -> "List":
    """
    Delegate V3 multi-vector document retrieval to the configured online store.

    Returns an empty list when no online store is configured; otherwise
    forwards every argument positionally to the store's
    ``retrieve_online_documents_v3``.
    """
    if not self.online_store:
        return []
    return self.online_store.retrieve_online_documents_v3(
        config,
        table,
        requested_features,
        embeddings,
        top_k,
        query_string,
        fusion_strategy,
        signal_weights,
        rrf_k,
        distance_metric,
        include_signal_scores,
    )
fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + pass + @abstractmethod def validate_data_source( self, diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 19a5d8b3158..c5575ef2231 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -83,8 +83,8 @@ "hazelcast": "feast.infra.online_stores.hazelcast_online_store.hazelcast_online_store.HazelcastOnlineStore", "ikv": "feast.infra.online_stores.ikv_online_store.ikv.IKVOnlineStore", "eg-milvus": "feast.expediagroup.vectordb.eg_milvus_online_store.EGMilvusOnlineStore", - "elasticsearch": "feast.expediagroup.vectordb.elasticsearch_online_store.ElasticsearchOnlineStore", - # "elasticsearch": "feast.infra.online_stores.elasticsearch_online_store.elasticsearch.ElasticSearchOnlineStore", + # "elasticsearch": "feast.expediagroup.vectordb.elasticsearch_online_store.ElasticsearchOnlineStore", + "elasticsearch": "feast.infra.online_stores.elasticsearch_online_store.elasticsearch.ElasticSearchOnlineStore", "remote": "feast.infra.online_stores.remote.RemoteOnlineStore", "singlestore": "feast.infra.online_stores.singlestore_online_store.singlestore.SingleStoreOnlineStore", "qdrant": "feast.infra.online_stores.qdrant_online_store.qdrant.QdrantOnlineStore", diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 8b076af92d4..af55266c682 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -1421,10 +1421,6 @@ def _get_feature_view_vector_field_metadata( feature_view, ) -> Optional[Field]: vector_fields = [field for field in feature_view.schema if field.vector_index] - if len(vector_fields) > 1: - raise ValueError( - f"Feature view {feature_view.name} has multiple vector fields. 
Only one vector field per feature view is supported." - ) if not vector_fields: return None return vector_fields[0] diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index f8396acc2df..4706ec9f49f 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -184,6 +184,28 @@ def retrieve_online_documents_v2( ]: return [] + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + return [] + def validate_data_source( self, config: RepoConfig, diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 02d3b593bc9..908f88de317 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -81,6 +81,9 @@ from tests.integration.feature_repos.universal.online_store.dynamodb import ( DynamoDBOnlineStoreCreator, ) +from tests.integration.feature_repos.universal.online_store.elasticsearch import ( + ElasticSearchOnlineStoreCreator, +) from tests.integration.feature_repos.universal.online_store.milvus import ( MilvusOnlineStoreCreator, ) @@ -155,6 +158,7 @@ str, Tuple[Union[str, Dict[Any, Any]], Optional[Type[OnlineStoreCreator]]] ] = { "sqlite": ({"type": "sqlite"}, None), + "elasticsearch": ({"type": "elasticsearch"}, ElasticSearchOnlineStoreCreator), # uncomment below once Milvus implementation is complete # "milvus": ({"type": "milvus"}, MilvusOnlineStoreCreator), } diff --git 
a/sdk/python/tests/integration/feature_repos/universal/online_store/elasticsearch.py b/sdk/python/tests/integration/feature_repos/universal/online_store/elasticsearch.py index 1e8088a997e..2fdc66cf6e1 100644 --- a/sdk/python/tests/integration/feature_repos/universal/online_store/elasticsearch.py +++ b/sdk/python/tests/integration/feature_repos/universal/online_store/elasticsearch.py @@ -20,7 +20,6 @@ def create_online_store(self) -> Dict[str, Any]: "host": "localhost", "type": "elasticsearch", "port": self.container.get_exposed_port(9200), - "vector_length": 2, "similarity": "cosine", } diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 3d9390eaa45..7a0756b1d6e 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -1167,7 +1167,12 @@ def test_retrieve_online_documents_v2(environment, fake_document_data): name="item_embeddings", entities=[item], schema=[ - Field(name="embedding", dtype=Array(Float32), vector_index=True), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=2, + ), Field(name="text_field", dtype=String), Field(name="category", dtype=String), Field(name="item_id", dtype=Int64), diff --git a/sdk/python/tests/unit/infra/online_store/test_valkey.py b/sdk/python/tests/unit/infra/online_store/test_valkey.py index 02e9cb0cbdb..ed911bcb92a 100644 --- a/sdk/python/tests/unit/infra/online_store/test_valkey.py +++ b/sdk/python/tests/unit/infra/online_store/test_valkey.py @@ -1,22 +1,36 @@ import time from datetime import datetime, timedelta, timezone +import numpy as np import pytest from valkey import Valkey -from feast import Entity, Field, FileSource, RepoConfig, ValueType +from feast import Entity, FeatureView, Field, FileSource, RepoConfig, ValueType from feast.infra.online_stores.eg_valkey import ( 
class TestVectorIndexName:
    """Tests for the _get_vector_index_name helper function."""

    def test_get_vector_index_name(self):
        """Index name generation follows <project>_<view>_<field>_vidx."""
        name = _get_vector_index_name("my_project", "item_embeddings", "embedding")
        assert name == "my_project_item_embeddings_embedding_vidx"

    def test_get_vector_index_name_with_special_chars(self):
        """Underscores inside the component names are preserved verbatim."""
        name = _get_vector_index_name(
            "prod_project", "user_item_embeddings", "vec_field"
        )
        assert name == "prod_project_user_item_embeddings_vec_field_vidx"
class TestSerializeVectorToBytes:
    """Tests for _serialize_vector_to_bytes helper function.

    Vectors are always serialized as raw float32 bytes; float64 inputs are
    down-cast (Valkey only supports float32 vectors).
    """

    def test_serialize_vector_float32(self):
        """Test Float32 vector serialization to raw bytes."""
        val = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3, 0.4]))
        field = Field(
            name="embedding",
            dtype=Array(Float32),
            vector_index=True,
            vector_length=4,
        )

        result = _serialize_vector_to_bytes(val, field)

        # Raw numpy float32 buffer is the expected wire format.
        expected = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32).tobytes()
        assert result == expected

    def test_serialize_vector_float64_converts_to_float32(self):
        """Test Float64 vector is converted to float32 bytes (Valkey limitation)."""
        val = ValueProto(double_list_val=DoubleList(val=[0.1, 0.2, 0.3, 0.4]))
        field = Field(
            name="embedding",
            dtype=Array(Float64),
            vector_index=True,
            vector_length=4,
        )

        result = _serialize_vector_to_bytes(val, field)

        # Should be float32 bytes, not float64
        expected = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32).tobytes()
        assert result == expected

    def test_serialize_vector_dimension_mismatch(self):
        """Test error when vector dimension doesn't match expected length."""
        val = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3]))
        field = Field(
            name="embedding",
            dtype=Array(Float32),
            vector_index=True,
            vector_length=128,  # Expected 128, but vector has 3 elements
        )

        with pytest.raises(ValueError, match="dimension mismatch"):
            _serialize_vector_to_bytes(val, field)

    def test_serialize_vector_unsupported_type(self):
        """Test error when vector type is not float_list or double_list."""
        val = ValueProto(int32_val=123)  # Not a list type
        field = Field(
            name="embedding",
            dtype=Array(Float32),
            vector_index=True,
            vector_length=4,
        )

        with pytest.raises(ValueError, match="Unsupported vector type"):
            _serialize_vector_to_bytes(val, field)

    def test_serialize_vector_no_length_validation_when_zero(self):
        """Test that vector_length=0 skips dimension validation."""
        val = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3]))
        field = Field(
            name="embedding",
            dtype=Array(Float32),
            vector_index=True,
            vector_length=0,  # No validation
        )

        # Should not raise
        result = _serialize_vector_to_bytes(val, field)
        assert len(result) == 3 * 4  # 3 floats * 4 bytes each
class TestVectorConfigOptions:
    """Tests for vector-related configuration options."""

    def test_default_vector_config_values(self):
        """Defaults: HNSW algorithm, m=16, ef_construction=200, ef_runtime=10."""
        config = EGValkeyOnlineStoreConfig()

        assert config.vector_index_algorithm == "HNSW"
        assert config.vector_index_hnsw_m == 16
        assert config.vector_index_hnsw_ef_construction == 200
        assert config.vector_index_hnsw_ef_runtime == 10

    def test_vector_config_custom_values(self):
        """All four vector-index knobs are overridable at construction time."""
        overrides = dict(
            vector_index_algorithm="FLAT",
            vector_index_hnsw_m=32,
            vector_index_hnsw_ef_construction=400,
            vector_index_hnsw_ef_runtime=20,
        )
        config = EGValkeyOnlineStoreConfig(**overrides)

        for key, expected in overrides.items():
            assert getattr(config, key) == expected
requested_features=["scalar_feature"] + ) + ) + + # Non-vector field should use mmh3 hash + expected_hash = _mmh3(f"{feature_view_with_vector.name}:scalar_feature") + assert expected_hash in hset_keys + assert "scalar_feature" not in vector_fields + + def test_timestamp_key_appended(self, feature_view_with_vector): + """Test that timestamp key is always appended to hset keys.""" + store = EGValkeyOnlineStore() + + requested_features, hset_keys, vector_fields = ( + store._generate_hset_keys_for_features( + feature_view_with_vector, requested_features=["embedding"] + ) + ) + + ts_key = f"_ts:{feature_view_with_vector.name}" + assert ts_key in hset_keys + assert ts_key in requested_features + + def test_mixed_fields_correct_keys(self, feature_view_with_vector): + """Test that mixed vector and non-vector fields get correct keys.""" + store = EGValkeyOnlineStore() + + requested_features, hset_keys, vector_fields = ( + store._generate_hset_keys_for_features( + feature_view_with_vector, + requested_features=["embedding", "scalar_feature", "string_feature"], + ) + ) + + # Vector field uses original name + assert "embedding" in hset_keys + + # Non-vector fields use mmh3 hash + scalar_hash = _mmh3(f"{feature_view_with_vector.name}:scalar_feature") + string_hash = _mmh3(f"{feature_view_with_vector.name}:string_feature") + assert scalar_hash in hset_keys + assert string_hash in hset_keys + + # Only embedding should be in vector_fields (now a dict) + assert set(vector_fields.keys()) == {"embedding"} + + def test_no_requested_features_uses_all(self, feature_view_with_vector): + """Test that None requested_features returns all feature view features.""" + store = EGValkeyOnlineStore() + + requested_features, hset_keys, vector_fields = ( + store._generate_hset_keys_for_features( + feature_view_with_vector, requested_features=None + ) + ) + + # Should include all features from the feature view + # Features are: scalar_feature, string_feature, embedding (excluding entity_id which is 
class TestVectorFieldValidation:
    """Tests for vector field validation during index creation.

    Both cases mock the Valkey client (no real connection) and expect
    _create_vector_index_if_not_exists to reject an invalid vector_length
    with a ValueError before any index is created.
    """

    def test_vector_field_missing_vector_length_raises_error(
        self, valkey_online_store, repo_config_without_docker_connection_string
    ):
        """Test that vector field without vector_length raises ValueError."""
        from unittest.mock import MagicMock

        from valkey.exceptions import ResponseError

        # Create a FeatureView with vector field but no vector_length
        fv = FeatureView(
            name="test_missing_length",
            source=FileSource(
                name="test_source",
                path="test.parquet",
                timestamp_field="event_timestamp",
            ),
            entities=[Entity(name="item_id")],
            ttl=timedelta(days=1),
            schema=[
                Field(name="item_id", dtype=Int64),
                Field(
                    name="embedding",
                    dtype=Array(Float32),
                    vector_index=True,
                    # vector_length intentionally not set (defaults to 0)
                ),
            ],
        )

        # Get vector fields
        vector_fields = {f.name: f for f in fv.features if f.vector_index}

        # Mock client to avoid actual connection
        mock_client = MagicMock()
        # Simulate index doesn't exist (ResponseError is raised by valkey-py)
        mock_client.ft.return_value.info.side_effect = ResponseError("Unknown index")

        with pytest.raises(ValueError, match="vector_length"):
            valkey_online_store._create_vector_index_if_not_exists(
                mock_client,
                repo_config_without_docker_connection_string,
                fv,
                vector_fields,
            )

    def test_vector_field_with_negative_vector_length_raises_error(
        self, valkey_online_store, repo_config_without_docker_connection_string
    ):
        """Test that vector field with negative vector_length raises ValueError."""
        from unittest.mock import MagicMock

        from valkey.exceptions import ResponseError

        # Create a FeatureView with vector field with negative vector_length
        fv = FeatureView(
            name="test_negative_length",
            source=FileSource(
                name="test_source",
                path="test.parquet",
                timestamp_field="event_timestamp",
            ),
            entities=[Entity(name="item_id")],
            ttl=timedelta(days=1),
            schema=[
                Field(name="item_id", dtype=Int64),
                Field(
                    name="embedding",
                    dtype=Array(Float32),
                    vector_index=True,
                    vector_length=-1,
                ),
            ],
        )

        vector_fields = {f.name: f for f in fv.features if f.vector_index}

        mock_client = MagicMock()
        mock_client.ft.return_value.info.side_effect = ResponseError("Unknown index")

        with pytest.raises(ValueError, match="vector_length"):
            valkey_online_store._create_vector_index_if_not_exists(
                mock_client,
                repo_config_without_docker_connection_string,
                fv,
                vector_fields,
            )
def _create_feature_view_with_vector_field():
    """Create a FeatureView with a vector embedding field.

    Fixture for the Docker-backed integration tests below: one scalar
    feature (item_name) plus a 4-dimensional cosine-indexed embedding.
    """
    fv = FeatureView(
        name="item_embeddings",
        source=FileSource(
            name="item_source",
            path="test.parquet",
            timestamp_field="event_timestamp",
        ),
        entities=[Entity(name="item_id")],
        ttl=timedelta(days=1),
        schema=[
            Field(name="item_id", dtype=Int64),
            Field(name="item_name", dtype=String),
            Field(
                name="embedding",
                dtype=Array(Float32),
                vector_index=True,
                vector_length=4,
                vector_search_metric="COSINE",
            ),
        ],
    )
    return fv
vector embeddings.""" + now = datetime.now(tz=timezone.utc) + return [ + ( + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + { + "item_name": ValueProto(string_val="item_1"), + "embedding": ValueProto( + float_list_val=FloatList(val=[0.1, 0.2, 0.3, 0.4]) + ), + }, + now, + None, + ), + ( + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=2)], + ), + { + "item_name": ValueProto(string_val="item_2"), + "embedding": ValueProto( + float_list_val=FloatList(val=[0.5, 0.6, 0.7, 0.8]) + ), + }, + now, + None, + ), + ] + + +@pytest.mark.docker +def test_valkey_online_write_batch_with_vector_field( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test writing a FeatureView with vector field stores data correctly.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # Write data - note: index creation will fail without Search module, + # but the write itself should work for storage verification + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + # If Search module is not available, index creation will fail + # This is expected with basic Valkey container + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + # Verify data was stored + redis_client = _make_redis_client(repo_config) + + entity_key = EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ) + valkey_key_bin = _redis_key( + repo_config.project, + entity_key, + entity_key_serialization_version=repo_config.entity_key_serialization_version, + ) + + stored_data = redis_client.hgetall(valkey_key_bin) + + # Verify vector field is stored with original name (not hashed) + assert b"embedding" in stored_data + + # Verify non-vector field is stored with mmh3 hash 
+ item_name_key = _mmh3(f"{feature_view.name}:item_name") + assert item_name_key in stored_data + + # Verify vector bytes can be deserialized + embedding_bytes = stored_data[b"embedding"] + vector = np.frombuffer(embedding_bytes, dtype=np.float32) + np.testing.assert_array_almost_equal(vector, [0.1, 0.2, 0.3, 0.4], decimal=5) + + # Verify __project__ is stored for vector search filtering + assert b"__project__" in stored_data + # Should be stored as string (valkey-py encodes to bytes, but value should match project) + assert stored_data[b"__project__"] == repo_config.project.encode() + + # Verify __entity_key__ is stored for entity key retrieval + assert b"__entity_key__" in stored_data + + +@pytest.mark.docker +def test_valkey_online_read_with_vector_field( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test reading a FeatureView with vector field deserializes correctly.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # Write data first + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + # Read data back + entity_keys = [ + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=2)], + ), + ] + + results = valkey_online_store.online_read( + config=repo_config, + table=feature_view, + entity_keys=entity_keys, + ) + + # Verify results + assert len(results) == 2 + + # Check first entity + ts1, features1 = results[0] + assert ts1 is not None + assert "embedding" in features1 + assert "item_name" in features1 + + # Verify vector values + embedding1 = features1["embedding"] + assert embedding1.HasField("float_list_val") + 
np.testing.assert_array_almost_equal( + embedding1.float_list_val.val, [0.1, 0.2, 0.3, 0.4], decimal=5 + ) + + # Check second entity + ts2, features2 = results[1] + embedding2 = features2["embedding"] + np.testing.assert_array_almost_equal( + embedding2.float_list_val.val, [0.5, 0.6, 0.7, 0.8], decimal=5 + ) + + +@pytest.mark.docker +def test_valkey_online_read_with_requested_features_vector_only( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test reading only the vector field using requested_features parameter.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # Write data first + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + entity_keys = [ + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + ] + + # Request only the vector field + results = valkey_online_store.online_read( + config=repo_config, + table=feature_view, + entity_keys=entity_keys, + requested_features=["embedding"], + ) + + assert len(results) == 1 + ts, features = results[0] + + # Should only have the embedding feature + assert "embedding" in features + assert "item_name" not in features + + # Verify vector values + embedding = features["embedding"] + assert embedding.HasField("float_list_val") + np.testing.assert_array_almost_equal( + embedding.float_list_val.val, [0.1, 0.2, 0.3, 0.4], decimal=5 + ) + + +@pytest.mark.docker +def test_valkey_online_read_with_requested_features_non_vector_only( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test reading only non-vector fields using requested_features parameter.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # 
Write data first + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + entity_keys = [ + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + ] + + # Request only the non-vector field + results = valkey_online_store.online_read( + config=repo_config, + table=feature_view, + entity_keys=entity_keys, + requested_features=["item_name"], + ) + + assert len(results) == 1 + ts, features = results[0] + + # Should only have the item_name feature + assert "item_name" in features + assert "embedding" not in features + + # Verify string value + assert features["item_name"].string_val == "item_1" + + +@pytest.mark.docker +def test_valkey_online_read_with_requested_features_mixed( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test reading a mix of vector and non-vector fields using requested_features.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # Write data first + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + entity_keys = [ + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=2)], + ), + ] + + # Request both vector and non-vector fields + results = valkey_online_store.online_read( + config=repo_config, + table=feature_view, + entity_keys=entity_keys, + requested_features=["embedding", "item_name"], + ) + + assert len(results) == 1 + ts, features = results[0] + + # Should have both features + assert "embedding" in features + assert "item_name" 
in features + + # Verify vector values + embedding = features["embedding"] + np.testing.assert_array_almost_equal( + embedding.float_list_val.val, [0.5, 0.6, 0.7, 0.8], decimal=5 + ) + + # Verify string value + assert features["item_name"].string_val == "item_2" + + +class TestGetVectorFieldForSearch: + """Tests for _get_vector_field_for_search helper method.""" + + @pytest.fixture + def feature_view_with_vector(self): + """Create a FeatureView with vector field for testing.""" + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + ], + ) + + @pytest.fixture + def feature_view_no_vector(self): + """Create a FeatureView without vector fields.""" + return FeatureView( + name="test_fv_no_vector", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + ], + ) + + def test_returns_vector_field_from_requested_features( + self, feature_view_with_vector + ): + """Test that vector field is returned when in requested_features.""" + store = EGValkeyOnlineStore() + result = store._get_vector_field_for_search( + feature_view_with_vector, + requested_features=["embedding", "scalar_feature"], + ) + assert result is not None + assert result.name == "embedding" + + def test_returns_first_vector_field_when_not_in_requested( + self, feature_view_with_vector + ): + """Test that first vector field is returned when not in requested_features.""" + 
store = EGValkeyOnlineStore() + result = store._get_vector_field_for_search( + feature_view_with_vector, requested_features=["scalar_feature"] + ) + assert result is not None + assert result.name == "embedding" + + def test_returns_none_for_no_vector_fields(self, feature_view_no_vector): + """Test that None is returned when no vector fields exist.""" + store = EGValkeyOnlineStore() + result = store._get_vector_field_for_search( + feature_view_no_vector, requested_features=["scalar_feature"] + ) + assert result is None + + +class TestSerializeEmbeddingForSearch: + """Tests for _serialize_embedding_for_search helper method.""" + + @pytest.fixture + def float32_vector_field(self): + """Create a Float32 vector field.""" + return Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ) + + @pytest.fixture + def float64_vector_field(self): + """Create a Float64 vector field.""" + return Field( + name="embedding", + dtype=Array(Float64), + vector_index=True, + vector_length=4, + ) + + def test_serializes_to_float32_bytes(self, float32_vector_field): + """Test that embedding is serialized to float32 bytes.""" + store = EGValkeyOnlineStore() + embedding = [0.1, 0.2, 0.3, 0.4] + result = store._serialize_embedding_for_search(embedding, float32_vector_field) + + # Verify it's bytes + assert isinstance(result, bytes) + + # Verify length (4 floats * 4 bytes each = 16 bytes) + assert len(result) == 16 + + # Verify values can be deserialized back + arr = np.frombuffer(result, dtype=np.float32) + np.testing.assert_array_almost_equal(arr, embedding, decimal=5) + + def test_serializes_to_float64_bytes(self, float64_vector_field): + """Test that embedding is serialized to float64 bytes for Float64 fields.""" + store = EGValkeyOnlineStore() + embedding = [0.1, 0.2, 0.3, 0.4] + result = store._serialize_embedding_for_search(embedding, float64_vector_field) + + # Verify it's bytes + assert isinstance(result, bytes) + + # Verify length (4 doubles * 8 
bytes each = 32 bytes) + assert len(result) == 32 + + # Verify values can be deserialized back + arr = np.frombuffer(result, dtype=np.float64) + np.testing.assert_array_almost_equal(arr, embedding, decimal=10) + + def test_raises_error_on_dimension_mismatch(self, float32_vector_field): + """Test that ValueError is raised when embedding dimension doesn't match field.""" + store = EGValkeyOnlineStore() + # Field expects 4 dimensions, but we provide 3 + embedding = [0.1, 0.2, 0.3] + with pytest.raises(ValueError, match="dimension .* does not match"): + store._serialize_embedding_for_search(embedding, float32_vector_field) + + +class TestRetrieveOnlineDocumentsV2Validation: + """Tests for retrieve_online_documents_v2 input validation.""" + + @pytest.fixture + def feature_view_with_vector(self): + """Create a FeatureView with vector field for testing.""" + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + ], + ) + + @pytest.fixture + def feature_view_no_vector(self): + """Create a FeatureView without vector fields.""" + return FeatureView( + name="test_fv_no_vector", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + ], + ) + + @pytest.fixture + def repo_config(self): + """Create a minimal RepoConfig for testing.""" + return RepoConfig( + project="test_project", + provider="local", + registry="test_registry.db", + online_store=EGValkeyOnlineStoreConfig( + 
type="eg-valkey", + connection_string="localhost:6379", + ), + entity_key_serialization_version=3, + ) + + def test_raises_error_when_embedding_is_none( + self, repo_config, feature_view_with_vector + ): + """Test that ValueError is raised when embedding is None.""" + store = EGValkeyOnlineStore() + with pytest.raises(ValueError, match="embedding must be provided"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=None, + top_k=10, + ) + + def test_raises_error_when_query_string_provided( + self, repo_config, feature_view_with_vector + ): + """Test that NotImplementedError is raised when query_string is provided.""" + store = EGValkeyOnlineStore() + with pytest.raises(NotImplementedError, match="Keyword search"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=[0.1, 0.2, 0.3, 0.4], + top_k=10, + query_string="test query", + ) + + def test_raises_error_when_no_vector_field( + self, repo_config, feature_view_no_vector + ): + """Test that ValueError is raised when FeatureView has no vector fields.""" + store = EGValkeyOnlineStore() + with pytest.raises(ValueError, match="No vector field found"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_no_vector, + requested_features=["scalar_feature"], + embedding=[0.1, 0.2, 0.3, 0.4], + top_k=10, + ) + + def test_raises_error_when_dimension_mismatch( + self, repo_config, feature_view_with_vector + ): + """Test that ValueError is raised when embedding dimension doesn't match field.""" + store = EGValkeyOnlineStore() + # feature_view_with_vector has vector_length=4, so 3-dim embedding should fail + with pytest.raises(ValueError, match="Embedding dimension .* does not match"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + 
embedding=[0.1, 0.2, 0.3], # Wrong dimension (3 instead of 4) + top_k=10, + ) + + def test_raises_error_when_index_does_not_exist( + self, repo_config, feature_view_with_vector + ): + """Test that ValueError is raised when vector index doesn't exist.""" + from unittest.mock import MagicMock, patch + + from valkey.exceptions import ResponseError + + store = EGValkeyOnlineStore() + + # Mock the client to simulate "no such index" error + mock_client = MagicMock() + mock_client.ft.return_value.search.side_effect = ResponseError("no such index") + + with patch.object(store, "_get_client", return_value=mock_client): + with pytest.raises(ValueError, match="does not exist.*materialize"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=[0.1, 0.2, 0.3, 0.4], + top_k=10, + ) + + +class TestExecuteVectorSearch: + """Tests for _execute_vector_search helper method.""" + + @pytest.fixture + def store(self): + return EGValkeyOnlineStore() + + def test_project_name_with_hyphen_is_escaped(self, store): + """Test that project names with hyphens are backslash-escaped in queries.""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="my-project", # Hyphen in project name + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + mock_client.ft.return_value.search.assert_called_once() + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # Hyphen should be backslash-escaped to prevent interpretation as negation + assert r"my\-project" in query.query_string() + + def test_project_name_with_double_quote_is_escaped(self, store): + """Test that double quotes in project names are 
backslash-escaped.""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project='my"project', # Double quote in project name + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + mock_client.ft.return_value.search.assert_called_once() + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # Double quote should be backslash-escaped + assert r"\"" in query.query_string() + + def test_no_sortby_in_knn_query(self, store): + """Test that KNN queries do not use SORTBY (engine sorts by distance automatically).""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="test_project", + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # KNN results are sorted by the engine; no explicit SORTBY should be set + assert query._sortby is None + + def test_default_distance_is_infinity_not_zero(self, store): + """Test that missing __distance__ defaults to infinity, not 0.0.""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_doc = MagicMock() + mock_doc.id = "test_key" + # Simulate missing __distance__ attribute + del mock_doc.__distance__ + + mock_result = MagicMock() + mock_result.docs = [mock_doc] + mock_client.ft.return_value.search.return_value = mock_result + + results = store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="test_project", + 
vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + # Distance should default to infinity, not 0.0 + # 0.0 would incorrectly indicate a perfect match + assert len(results) == 1 + doc_key, distance = results[0] + assert distance == float("inf") + + +class TestRetrieveOnlineDocumentsV3Validation: + """Tests for retrieve_online_documents_v3 input validation on Valkey.""" + + @pytest.fixture + def store(self): + return EGValkeyOnlineStore() + + @pytest.fixture + def feature_view_with_vector(self): + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + Field(name="title", dtype=String), + ], + ) + + @pytest.fixture + def feature_view_multi_vector(self): + return FeatureView( + name="test_fv_multi", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="title_vec", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + Field( + name="body_vec", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + ], + ) + + @pytest.fixture + def repo_config(self): + return RepoConfig( + project="test_project", + provider="local", + registry="test_registry.db", + online_store=EGValkeyOnlineStoreConfig( + type="eg-valkey", + connection_string="localhost:6379", + ), + entity_key_serialization_version=3, + ) + + def test_empty_embeddings_raises( + self, store, repo_config, 
feature_view_with_vector + ): + with pytest.raises(ValueError, match="at least one embedding"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={}, + top_k=5, + ) + + def test_multi_embedding_raises( + self, store, repo_config, feature_view_multi_vector + ): + with pytest.raises(ValueError, match="single-vector search only"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_multi_vector, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + + def test_rrf_strategy_raises(self, store, repo_config, feature_view_with_vector): + with pytest.raises(ValueError, match="not supported on Valkey"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="RRF", + ) + + def test_weighted_linear_strategy_raises( + self, store, repo_config, feature_view_with_vector + ): + with pytest.raises(ValueError, match="not supported on Valkey"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="WEIGHTED_LINEAR", + signal_weights={"embedding": 1.0}, + ) + + def test_unknown_fusion_strategy_raises( + self, store, repo_config, feature_view_with_vector + ): + with pytest.raises(ValueError, match="Unknown fusion_strategy"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="BOGUS", + ) + + def test_auto_strategy_accepted(self, store, repo_config, feature_view_with_vector): + """AUTO should not raise — it delegates to V2.""" + from 
unittest.mock import patch + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"distance": ValueProto(double_val=0.5)}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="AUTO", + ) + assert len(results) == 1 + + def test_vector_only_strategy_accepted( + self, store, repo_config, feature_view_with_vector + ): + from unittest.mock import patch + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"distance": ValueProto(double_val=0.3)}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="VECTOR_ONLY", + ) + assert len(results) == 1 + + def test_query_string_warns_and_dropped( + self, store, repo_config, feature_view_with_vector + ): + """query_string should trigger a logger.warning and be passed as None to V2.""" + from unittest.mock import patch + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"distance": ValueProto(double_val=0.5)}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ) as mock_v2: + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string="test query", + 
fusion_strategy="AUTO", + ) + # Verify query_string=None was passed to V2 + call_kwargs = mock_v2.call_args[1] + assert call_kwargs.get("query_string") is None + + def test_include_signal_scores_accepted( + self, store, repo_config, feature_view_with_vector + ): + from unittest.mock import patch + + mock_results = [] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + include_signal_scores=False, + ) + assert results == [] + + +class TestRetrieveOnlineDocumentsV3ResponseTransform: + """Tests for V3 response transformation on Valkey (V2→V3 wrapper).""" + + @pytest.fixture + def store(self): + return EGValkeyOnlineStore() + + @pytest.fixture + def feature_view(self): + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + Field(name="title", dtype=String), + ], + ) + + @pytest.fixture + def repo_config(self): + return RepoConfig( + project="test_project", + provider="local", + registry="test_registry.db", + online_store=EGValkeyOnlineStoreConfig( + type="eg-valkey", + connection_string="localhost:6379", + ), + entity_key_serialization_version=3, + ) + + def test_distance_renamed_to_final_score(self, store, repo_config, feature_view): + from unittest.mock import patch + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + { + "distance": ValueProto(double_val=0.25), + "title": 
ValueProto(string_val="hello"), + }, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert len(results) == 1 + ts, ek, feat_dict = results[0] + assert "final_score" in feat_dict + assert "distance" not in feat_dict + assert feat_dict["final_score"].double_val == pytest.approx(0.25) + + def test_signal_scores_populated(self, store, repo_config, feature_view): + from unittest.mock import patch + + from feast.infra.online_stores._signal_scores import decode_signal_scores + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"distance": ValueProto(double_val=0.5)}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + feat_dict = results[0][2] + assert "signal_scores" in feat_dict + scores = decode_signal_scores(feat_dict["signal_scores"]) + assert "vec_embedding" in scores + assert scores["vec_embedding"] == pytest.approx(0.5) + + def test_none_feature_dict_passthrough(self, store, repo_config, feature_view): + from unittest.mock import patch + + mock_results = [ + (datetime(2024, 1, 1), None, None), + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert len(results) == 1 + assert results[0][2] is None + + def test_missing_distance_no_final_score(self, store, 
repo_config, feature_view): + """If V2 returns no distance, signal_scores should be empty.""" + from unittest.mock import patch + + from feast.infra.online_stores._signal_scores import decode_signal_scores + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"title": ValueProto(string_val="test")}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + feat_dict = results[0][2] + scores = decode_signal_scores(feat_dict["signal_scores"]) + assert scores == {} + + def test_empty_v2_results(self, store, repo_config, feature_view): + from unittest.mock import patch + + with patch.object(store, "retrieve_online_documents_v2", return_value=[]): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert results == [] diff --git a/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py new file mode 100644 index 00000000000..81c58666450 --- /dev/null +++ b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py @@ -0,0 +1,1163 @@ +import base64 +import json +import math +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest + +from feast import Entity, FeatureView, RepoConfig +from feast.field import Field +from feast.infra.online_stores._signal_scores import decode_signal_scores +from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStore, + ElasticSearchOnlineStoreConfig, + _encode_feature_value, + _to_value_proto, +) 
+from feast.protos.feast.types.Value_pb2 import ( + FloatList, + Int64List, +) +from feast.protos.feast.types.Value_pb2 import ( + Value as ValueProto, +) +from feast.types import Array, Float32, Int64, String +from feast.value_type import ValueType + + +class TestEncodeFeatureValue: + def test_vector_field_includes_vector_value(self): + """When is_vector=True and value is a float list, vector_value should be present.""" + value = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3])) + result = _encode_feature_value(value, is_vector=True) + + assert "vector_value" in result + assert result["vector_value"] == pytest.approx([0.1, 0.2, 0.3]) + + def test_non_vector_list_excludes_vector_value(self): + """When is_vector=False and value is a float list, vector_value should NOT be present.""" + value = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3])) + result = _encode_feature_value(value, is_vector=False) + + assert "vector_value" not in result + + def test_non_vector_int_list_excludes_vector_value(self): + """An int64 list with is_vector=False should not produce vector_value.""" + value = ValueProto(int64_list_val=Int64List(val=[1, 2, 3])) + result = _encode_feature_value(value, is_vector=False) + + assert "vector_value" not in result + + def test_string_value_has_value_text(self): + """A string ValueProto should produce value_text, not vector_value.""" + value = ValueProto(string_val="hello") + result = _encode_feature_value(value, is_vector=False) + + assert result["value_text"] == "hello" + assert "vector_value" not in result + + def test_feature_value_always_present(self): + """feature_value (base64 binary) should always be present regardless of is_vector.""" + vector_value = ValueProto(float_list_val=FloatList(val=[1.0, 2.0])) + string_value = ValueProto(string_val="test") + int_value = ValueProto(int64_val=42) + + for val in [vector_value, string_value, int_value]: + for is_vector in [True, False]: + result = _encode_feature_value(val, 
is_vector=is_vector) + assert "feature_value" in result + # Verify it's valid base64 that deserializes back + decoded = base64.b64decode(result["feature_value"]) + roundtrip = ValueProto() + roundtrip.ParseFromString(decoded) + + def test_default_is_vector_false(self): + """Calling without is_vector should default to False (no vector_value).""" + value = ValueProto(float_list_val=FloatList(val=[0.1, 0.2])) + result = _encode_feature_value(value) + + assert "vector_value" not in result + + +def _make_feature_view( + name="test_fv", + vector_fields=None, + extra_fields=None, +): + """Helper to build a FeatureView with optional vector fields.""" + from feast import FileSource + + schema = [Field(name="item_id", dtype=Int64)] + if vector_fields is None: + vector_fields = [("embedding", 4)] + for fname, dim in vector_fields: + schema.append( + Field( + name=fname, + dtype=Array(Float32), + vector_index=True, + vector_length=dim, + vector_search_metric="COSINE", + ) + ) + for fname, dtype in extra_fields or []: + schema.append(Field(name=fname, dtype=dtype)) + + return FeatureView( + name=name, + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=schema, + ) + + +_repo_config_counter = 0 + + +def _make_repo_config(vector_enabled=True, **overrides): + """Helper to build a RepoConfig with ES online store.""" + global _repo_config_counter + _repo_config_counter += 1 + es_config = ElasticSearchOnlineStoreConfig( + type="elasticsearch", + host="localhost", + port=9200, + vector_enabled=vector_enabled, + **overrides, + ) + return RepoConfig( + project="test_project", + provider="local", + registry=f"/tmp/test_registry_{_repo_config_counter}.db", + online_store=es_config, + entity_key_serialization_version=3, + ) + + +class TestRetrieveOnlineDocumentsV3Validation: + """Tests for retrieve_online_documents_v3 input 
validation.""" + + @pytest.fixture + def store(self): + return ElasticSearchOnlineStore() + + @pytest.fixture + def config(self): + return _make_repo_config() + + @pytest.fixture + def fv_single_vector(self): + return _make_feature_view( + vector_fields=[("embedding", 4)], + extra_fields=[("title", String)], + ) + + @pytest.fixture + def fv_multi_vector(self): + return _make_feature_view( + vector_fields=[("title_vec", 4), ("body_vec", 4)], + ) + + @pytest.fixture + def fv_no_vector(self): + return _make_feature_view(vector_fields=[]) + + def test_empty_embeddings_raises(self, store, config, fv_single_vector): + with pytest.raises(ValueError, match="at least one embedding"): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={}, + top_k=5, + ) + + def test_vector_not_enabled_raises(self, store, fv_single_vector): + config = _make_repo_config(vector_enabled=False) + with pytest.raises(ValueError, match="not enabled"): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + def test_unknown_fusion_strategy_raises(self, store, config, fv_single_vector): + with pytest.raises(ValueError, match="Unknown fusion_strategy"): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="INVALID", + ) + + def test_unknown_embedding_key_raises(self, store, config, fv_single_vector): + with pytest.raises(ValueError, match="does not match any vector-indexed"): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"nonexistent_field": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + def test_no_vector_fields_raises(self, store, config, fv_no_vector): + with pytest.raises(ValueError, 
match="no vector-indexed fields"): + store.retrieve_online_documents_v3( + config=config, + table=fv_no_vector, + requested_features=["item_id"], + embeddings={"some_field": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + def test_weighted_linear_missing_weights_raises( + self, store, config, fv_multi_vector + ): + with pytest.raises(ValueError, match="missing weights for signals"): + store.retrieve_online_documents_v3( + config=config, + table=fv_multi_vector, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + query_string="test", + fusion_strategy="WEIGHTED_LINEAR", + signal_weights={"title_vec": 0.5}, + ) + + def test_weighted_linear_partial_weights_raises( + self, store, config, fv_multi_vector + ): + """Missing bm25 weight when query_string is present.""" + with pytest.raises(ValueError, match=r"missing weights for signals.*\bbm25\b"): + store.retrieve_online_documents_v3( + config=config, + table=fv_multi_vector, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + query_string="test", + fusion_strategy="WEIGHTED_LINEAR", + signal_weights={"title_vec": 0.5, "body_vec": 0.3}, + ) + + def test_vector_only_nullifies_query_string(self, store, config, fv_single_vector): + """VECTOR_ONLY should drop query_string before building retrievers.""" + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string="should be dropped", + fusion_strategy="VECTOR_ONLY", + ) + + call_body = mock_client.search.call_args[1]["body"] + retriever = call_body["retriever"] + assert "knn" in retriever, "VECTOR_ONLY should produce a knn 
retriever" + assert "standard" not in json.dumps(retriever) + assert "rrf" not in retriever + + def test_empty_query_string_treated_as_none(self, store, config, fv_single_vector): + """Whitespace-only query_string should not create a BM25 retriever.""" + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string=" ", + ) + + call_body = mock_client.search.call_args[1]["body"] + retriever = call_body["retriever"] + assert "knn" in retriever + assert "standard" not in json.dumps(retriever) + + @pytest.mark.parametrize( + "strategy", ["auto", "Auto", "AUTO", "rrf", "Rrf", "vector_only"] + ) + def test_strategy_case_insensitive(self, store, config, fv_single_vector, strategy): + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy=strategy, + ) + mock_client.search.assert_called_once() + + @pytest.mark.parametrize("flag", [True, False]) + def test_include_signal_scores_accepted_but_ignored( + self, store, config, fv_single_vector, flag + ): + """include_signal_scores is a reserved param; should not raise for True or False.""" + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + include_signal_scores=flag, + ) + + +class 
TestRetrieveOnlineDocumentsV3QueryBuilding: + """Tests for the ES query body construction.""" + + @pytest.fixture + def store(self): + return ElasticSearchOnlineStore() + + @pytest.fixture + def config(self): + return _make_repo_config() + + @pytest.fixture + def fv_single(self): + return _make_feature_view( + vector_fields=[("embedding", 4)], + extra_fields=[("title", String)], + ) + + @pytest.fixture + def fv_multi(self): + return _make_feature_view( + vector_fields=[("title_vec", 4), ("body_vec", 4)], + ) + + def _call_and_capture_body(self, store, config, table, **kwargs): + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3(config=config, table=table, **kwargs) + return mock_client.search.call_args[1]["body"] + + def test_single_vector_uses_knn_retriever(self, store, config, fv_single): + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + retriever = body["retriever"] + assert "knn" in retriever + assert retriever["knn"]["field"] == "embedding.vector_value" + assert retriever["knn"]["query_vector"] == [0.1, 0.2, 0.3, 0.4] + assert retriever["knn"]["k"] == 5 + assert body["size"] == 5 + + def test_single_vector_knn_k_equals_top_k(self, store, config, fv_single): + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=10, + ) + assert body["retriever"]["knn"]["k"] == 10 + + def test_multi_vector_uses_rrf_by_default(self, store, config, fv_multi): + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + retriever = body["retriever"] + assert "rrf" in 
retriever + assert len(retriever["rrf"]["retrievers"]) == 2 + + def test_multi_vector_rrf_has_rank_constant(self, store, config, fv_multi): + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + rrf_k=42, + ) + assert body["retriever"]["rrf"]["rank_constant"] == 42 + + def test_query_string_adds_bm25_retriever(self, store, config, fv_single): + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string="search term", + ) + retriever = body["retriever"] + assert "rrf" in retriever + retrievers = retriever["rrf"]["retrievers"] + assert len(retrievers) == 2 + retriever_types = [list(r.keys())[0] for r in retrievers] + assert "knn" in retriever_types + assert "standard" in retriever_types + + def test_single_vector_plus_bm25_uses_rrf(self, store, config, fv_single): + """Single vector + query_string should produce RRF with knn + standard retrievers.""" + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string="search term", + fusion_strategy="RRF", + ) + retriever = body["retriever"] + assert "rrf" in retriever + retrievers = retriever["rrf"]["retrievers"] + assert len(retrievers) == 2 + types = {list(r.keys())[0] for r in retrievers} + assert types == {"knn", "standard"} + for r in retrievers: + if "knn" in r: + assert r["knn"]["field"] == "embedding.vector_value" + if "standard" in r: + assert r["standard"]["query"]["query_string"]["query"] == "search term" + + def test_weighted_linear_uses_linear_retriever(self, store, config, fv_multi): + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 
0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + fusion_strategy="WEIGHTED_LINEAR", + signal_weights={"title_vec": 0.7, "body_vec": 0.3}, + ) + retriever = body["retriever"] + assert "linear" in retriever + weighted = retriever["linear"]["retrievers"] + assert len(weighted) == 2 + weights = [w["weight"] for w in weighted] + assert 0.7 in weights + assert 0.3 in weights + + def test_multi_signal_inner_k_larger_than_top_k(self, store, config, fv_multi): + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + for r in body["retriever"]["rrf"]["retrievers"]: + if "knn" in r: + assert r["knn"]["k"] >= 100 + assert r["knn"]["k"] <= 1000 + + def test_num_candidates_uses_math_ceil(self, store, config, fv_single): + """Verify math.ceil is applied by using a multiplier that produces a fraction.""" + object.__setattr__(config.online_store, "knn_num_candidates_multiplier", 1.5) + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=3, + ) + k = body["retriever"]["knn"]["k"] + num_candidates = body["retriever"]["knn"]["num_candidates"] + # 3 * 1.5 = 4.5 → ceil → 5, proving ceil is used (floor would give 4) + assert num_candidates == math.ceil(k * 1.5) + assert num_candidates == 5 + assert num_candidates != int(k * 1.5) + + def test_rrf_single_signal_executes_as_single(self, store, config, fv_single): + """RRF with only one signal should still succeed (logged warning, not error).""" + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="RRF", + ) + retriever = body["retriever"] + assert "knn" in retriever, "Single signal RRF degrades to single retriever" + assert "rrf" not in 
retriever + + def test_auto_single_signal_uses_direct_knn(self, store, config, fv_single): + """AUTO with one vector and no query_string should produce a bare knn retriever.""" + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="AUTO", + ) + retriever = body["retriever"] + assert "knn" in retriever, "Single signal AUTO should use direct knn" + assert "rrf" not in retriever + assert "linear" not in retriever + + def test_source_fields_include_metadata(self, store, config, fv_single): + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + source = body["_source"] + assert "entity_key" in source + assert "timestamp" in source + assert "title" in source + + def test_rescore_oversample_applied_to_single_knn(self, store, fv_single): + """When rescore_oversample is configured on a quantized index, the V3 + kNN retriever should include a rescore_vector clause.""" + config = _make_repo_config( + vector_index_type="int8_hnsw", rescore_oversample=3.0 + ) + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + knn = body["retriever"]["knn"] + assert knn["rescore_vector"] == {"oversample": 3.0} + + def test_rescore_oversample_applied_to_all_multi_vector_knns(self, store, fv_multi): + """Multi-vector V3 queries should apply rescore_vector to every kNN + retriever, not just the first one.""" + config = _make_repo_config( + vector_index_type="int4_hnsw", rescore_oversample=2.5 + ) + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + retrievers = 
body["retriever"]["rrf"]["retrievers"] + assert len(retrievers) == 2 + for r in retrievers: + assert r["knn"]["rescore_vector"] == {"oversample": 2.5} + + def test_rescore_oversample_absent_when_not_configured( + self, store, config, fv_single + ): + """Default config has no rescore_oversample; the kNN clause should not + include rescore_vector.""" + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + assert "rescore_vector" not in body["retriever"]["knn"] + + +class TestRetrieveOnlineDocumentsV3ResponseParsing: + """Tests for parsing ES response into V3 result tuples.""" + + @pytest.fixture + def store(self): + return ElasticSearchOnlineStore() + + @pytest.fixture + def config(self): + return _make_repo_config() + + @pytest.fixture + def fv(self): + return _make_feature_view( + vector_fields=[("embedding", 4)], + extra_fields=[("title", String)], + ) + + def _mock_es_response(self, hits): + return {"hits": {"hits": hits}} + + def _make_hit(self, score, timestamp, features=None): + from feast.infra.key_encoding_utils import serialize_entity_key + from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto + + ek = EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ) + ek_bytes = serialize_entity_key(ek, entity_key_serialization_version=3) + ek_b64 = base64.b64encode(ek_bytes).decode("utf-8") + source = { + "entity_key": ek_b64, + "timestamp": timestamp, + } + if features: + source.update(features) + return {"_source": source, "_score": score} + + def test_single_result_has_final_score(self, store, config, fv): + hit = self._make_hit(0.95, "2024-01-01T00:00:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, 
+ requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert len(results) == 1 + ts, ek, feat_dict = results[0] + assert feat_dict["final_score"].float_val == pytest.approx(0.95) + + def test_single_result_has_signal_scores(self, store, config, fv): + hit = self._make_hit(0.95, "2024-01-01T00:00:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + feat_dict = results[0][2] + scores = decode_signal_scores(feat_dict["signal_scores"]) + assert "vec_embedding" in scores + assert scores["vec_embedding"] == pytest.approx(0.95) + + def test_signal_scores_is_compact_sorted_json(self, store, config, fv): + """signal_scores should be compact JSON with sorted keys.""" + hit = self._make_hit(0.95, "2024-01-01T00:00:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + raw = results[0][2]["signal_scores"].string_val + assert " " not in raw + parsed = json.loads(raw) + assert list(parsed.keys()) == sorted(parsed.keys()) + + def test_multi_signal_signal_scores_are_empty(self, store, config): + fv = _make_feature_view( + vector_fields=[("title_vec", 4), ("body_vec", 4)], + ) + hit = self._make_hit(0.8, "2024-01-01T00:00:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + 
table=fv, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + + feat_dict = results[0][2] + scores = decode_signal_scores(feat_dict["signal_scores"]) + assert scores == {} + + def test_empty_results(self, store, config, fv): + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert results == [] + + def test_timestamp_parsed(self, store, config, fv): + hit = self._make_hit(0.9, "2024-06-15T12:30:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + ts = results[0][0] + assert isinstance(ts, datetime) + assert ts.year == 2024 + assert ts.month == 6 + + def test_top_k_limits_results(self, store, config, fv): + """Verify that at most top_k results are returned even if ES returns more.""" + hits = [self._make_hit(0.9 - i * 0.1, "2024-01-01T00:00:00") for i in range(5)] + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response(hits) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=3, + ) + + assert len(results) <= 3 + body = mock_client.search.call_args[1]["body"] + assert body["size"] == 3 + + def test_feature_values_included(self, store, config, fv): + encoded_val = base64.b64encode( + 
ValueProto(string_val="hello world").SerializeToString() + ).decode("utf-8") + hit = self._make_hit( + 0.9, + "2024-01-01T00:00:00", + features={"title": {"feature_value": encoded_val}}, + ) + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + feat_dict = results[0][2] + assert "title" in feat_dict + assert feat_dict["title"].string_val == "hello world" + + +class TestElasticSearchOnlineStoreConfig: + def test_defaults(self): + """Test default config values.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + config = ElasticSearchOnlineStoreConfig() + assert config.vector_index_type is None + assert config.hnsw_m is None + assert config.hnsw_ef_construction is None + assert config.rescore_oversample is None + assert config.use_native_knn is False + assert config.knn_num_candidates_multiplier is None + + def test_valid_index_type(self): + """Test valid vector_index_type values.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + for index_type in [ + "int8_hnsw", + "int4_hnsw", + "bbq_hnsw", + "hnsw", + "flat", + "bbq_flat", + ]: + config = ElasticSearchOnlineStoreConfig(vector_index_type=index_type) + assert config.vector_index_type == index_type + + def test_invalid_index_type(self): + """Test invalid vector_index_type raises ValueError.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + with pytest.raises(ValueError, match="vector_index_type must be one of"): + ElasticSearchOnlineStoreConfig(vector_index_type="invalid_type") + + def 
test_rescore_range_validation(self): + """Test rescore_oversample range validation.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid values: (1.0, 10.0) exclusive + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=2.0 + ) + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=5.5 + ) + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=9.9 + ) + # None disables rescore + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=None + ) + + # Invalid: at or below 1.0 + with pytest.raises( + ValueError, match="must be in the range \\(1.0, 10.0\\) exclusive" + ): + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=1.0 + ) + with pytest.raises( + ValueError, match="must be in the range \\(1.0, 10.0\\) exclusive" + ): + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=0.5 + ) + + # Invalid: at or above 10.0 + with pytest.raises( + ValueError, match="must be in the range \\(1.0, 10.0\\) exclusive" + ): + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=10.0 + ) + with pytest.raises( + ValueError, match="must be in the range \\(1.0, 10.0\\) exclusive" + ): + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=50.0 + ) + + def test_rescore_requires_quantized_type(self): + """Test rescore_oversample only works with quantized types.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid: quantized type + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=2.0 + ) + + # Invalid: non-quantized type + with pytest.raises(ValueError, match="can only be used with quantized"): + ElasticSearchOnlineStoreConfig( + 
vector_index_type="hnsw", rescore_oversample=2.0 + ) + + # Invalid: vector_index_type is None + with pytest.raises(ValueError, match="can only be used with quantized"): + ElasticSearchOnlineStoreConfig( + vector_index_type=None, rescore_oversample=2.0 + ) + + def test_hnsw_params_require_hnsw_type(self): + """Test HNSW params only work with HNSW types.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid: HNSW type + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=32) + + # Invalid: flat type + with pytest.raises(ValueError, match="only apply to HNSW index types"): + ElasticSearchOnlineStoreConfig(vector_index_type="int8_flat", hnsw_m=32) + + def test_hnsw_m_range(self): + """Test hnsw_m range validation.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid: ES enforces its own upper limits, Feast only rejects < 1 + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=1) + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=100) + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=200) + + # Invalid: zero or negative + with pytest.raises(ValueError, match="must be >= 1"): + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=0) + + def test_knn_multiplier_validation(self): + """Test knn_num_candidates_multiplier validation.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid + ElasticSearchOnlineStoreConfig(knn_num_candidates_multiplier=1.0) + ElasticSearchOnlineStoreConfig(knn_num_candidates_multiplier=10.0) + + # Invalid: too low + with pytest.raises(ValueError, match="must be >= 1.0"): + ElasticSearchOnlineStoreConfig(knn_num_candidates_multiplier=0.5) + + +class TestCreateIndexWithQuantization: + def 
test_index_mapping_with_int8_quantization(self): + """Test index mapping includes quantization settings.""" + from unittest.mock import MagicMock + + from feast import FeatureView, Field, RepoConfig + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStore, + ElasticSearchOnlineStoreConfig, + ) + from feast.types import Array, Float32 + + config = RepoConfig( + project="test", + registry="registry.db", + provider="local", + online_store=ElasticSearchOnlineStoreConfig( + vector_enabled=True, + similarity="cosine", + vector_index_type="int8_hnsw", + hnsw_m=32, + hnsw_ef_construction=200, + ), + ) + + fv = MagicMock(spec=FeatureView) + fv.name = "test_fv" + fv.schema = [ + Field( + name="vector", + dtype=Array(Float32), + vector_index=True, + vector_length=128, + vector_search_metric="cosine", + ) + ] + + store = ElasticSearchOnlineStore() + mock_client = MagicMock() + mock_client.indices.exists.return_value = False + store._client = mock_client + + store.create_index(config, fv) + + # Verify create was called + assert mock_client.indices.create.called + call_args = mock_client.indices.create.call_args + mapping = call_args.kwargs["mappings"] + + # Check quantization settings in dynamic template + template = mapping["dynamic_templates"][0]["feature_objects"]["mapping"] + vector_props = template["properties"]["vector_value"] + + assert vector_props["type"] == "dense_vector" + assert vector_props["dims"] == 128 + assert "index_options" in vector_props + assert vector_props["index_options"]["type"] == "int8_hnsw" + assert vector_props["index_options"]["m"] == 32 + assert vector_props["index_options"]["ef_construction"] == 200 + + def test_int4_requires_even_dimensions(self): + """Test int4 quantization validates even dimensions.""" + from unittest.mock import MagicMock + + from feast import FeatureView, Field, RepoConfig + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + 
ElasticSearchOnlineStore, + ElasticSearchOnlineStoreConfig, + ) + from feast.types import Array, Float32 + + config = RepoConfig( + project="test", + registry="registry.db", + provider="local", + online_store=ElasticSearchOnlineStoreConfig( + vector_enabled=True, vector_index_type="int4_hnsw" + ), + ) + + fv = MagicMock(spec=FeatureView) + fv.name = "test_fv" + fv.schema = [ + Field( + name="vector", + dtype=Array(Float32), + vector_index=True, + vector_length=127, # Odd number + ) + ] + + store = ElasticSearchOnlineStore() + mock_client = MagicMock() + mock_client.indices.exists.return_value = False + store._client = mock_client + + with pytest.raises(ValueError, match="requires even number of dimensions"): + store.create_index(config, fv) + + def test_bbq_requires_min_dimensions(self): + """Test bbq quantization validates minimum dimensions.""" + from unittest.mock import MagicMock + + from feast import FeatureView, Field, RepoConfig + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStore, + ElasticSearchOnlineStoreConfig, + ) + from feast.types import Array, Float32 + + config = RepoConfig( + project="test", + registry="registry.db", + provider="local", + online_store=ElasticSearchOnlineStoreConfig( + vector_enabled=True, vector_index_type="bbq_hnsw" + ), + ) + + fv = MagicMock(spec=FeatureView) + fv.name = "test_fv" + fv.schema = [ + Field( + name="vector", + dtype=Array(Float32), + vector_index=True, + vector_length=32, # Less than 64 + ) + ] + + store = ElasticSearchOnlineStore() + mock_client = MagicMock() + mock_client.indices.exists.return_value = False + store._client = mock_client + + with pytest.raises(ValueError, match="requires >= 64 dimensions"): + store.create_index(config, fv) + + +class TestToValueProto: + def test_bool_not_treated_as_int(self): + """bool is a subclass of int in Python; ensure True -> bool_val, not int64_val.""" + result = _to_value_proto(True) + assert result.bool_val is True 
+ assert result.int64_val == 0 + + result = _to_value_proto(False) + assert result.bool_val is False + + def test_int(self): + result = _to_value_proto(42) + assert result.int64_val == 42 + assert result.bool_val is False + + def test_float(self): + result = _to_value_proto(3.14) + assert result.float_val == pytest.approx(3.14) + + def test_string(self): + result = _to_value_proto("hello") + assert result.string_val == "hello" + + def test_float_list(self): + result = _to_value_proto([1.0, 2.0, 3.0]) + assert list(result.float_list_val.val) == pytest.approx([1.0, 2.0, 3.0]) + + def test_int_list(self): + result = _to_value_proto([1, 2, 3]) + assert list(result.int64_list_val.val) == [1, 2, 3] + + def test_mixed_list_raises(self): + with pytest.raises(ValueError, match="mixed or unsupported"): + _to_value_proto([1, "two", 3.0]) + + def test_passthrough_value_proto(self): + original = ValueProto(string_val="already a proto") + result = _to_value_proto(original) + assert result is original + + def test_unsupported_type_raises(self): + with pytest.raises(ValueError, match="Unsupported type"): + _to_value_proto(object()) diff --git a/sdk/python/tests/unit/online_store/test_signal_scores.py b/sdk/python/tests/unit/online_store/test_signal_scores.py new file mode 100644 index 00000000000..4b426202d3c --- /dev/null +++ b/sdk/python/tests/unit/online_store/test_signal_scores.py @@ -0,0 +1,76 @@ +import json + +import pytest + +from feast.infra.online_stores._signal_scores import ( + decode_signal_scores, + encode_signal_scores, +) +from feast.protos.feast.types.Value_pb2 import Value as ValueProto + + +class TestEncodeSignalScores: + def test_single_score(self): + result = encode_signal_scores({"vec_embedding": 0.95}) + assert result.HasField("string_val") + parsed = json.loads(result.string_val) + assert parsed == {"vec_embedding": 0.95} + + def test_multiple_scores(self): + scores = {"vec_title": 0.8, "vec_body": 0.6, "bm25": 12.5} + result = encode_signal_scores(scores) 
+ parsed = json.loads(result.string_val) + assert parsed == scores + + def test_empty_dict(self): + result = encode_signal_scores({}) + assert result.string_val == "{}" + + def test_sort_keys_deterministic(self): + result_a = encode_signal_scores({"z_field": 1.0, "a_field": 2.0}) + result_b = encode_signal_scores({"a_field": 2.0, "z_field": 1.0}) + assert result_a.string_val == result_b.string_val + parsed = json.loads(result_a.string_val) + keys = list(parsed.keys()) + assert keys == sorted(keys) + + def test_compact_json_no_spaces(self): + result = encode_signal_scores({"a": 1.0, "b": 2.0}) + assert " " not in result.string_val + + def test_returns_value_proto(self): + result = encode_signal_scores({"x": 1.0}) + assert isinstance(result, ValueProto) + + +class TestDecodeSignalScores: + def test_roundtrip(self): + original = {"vec_embedding": 0.95, "bm25": 12.5} + encoded = encode_signal_scores(original) + decoded = decode_signal_scores(encoded) + assert decoded == original + + def test_roundtrip_empty(self): + encoded = encode_signal_scores({}) + decoded = decode_signal_scores(encoded) + assert decoded == {} + + def test_empty_string_val(self): + val = ValueProto() + val.string_val = "" + assert decode_signal_scores(val) == {} + + def test_no_string_field(self): + val = ValueProto() + val.int64_val = 42 + assert decode_signal_scores(val) == {} + + def test_default_value_proto(self): + val = ValueProto() + assert decode_signal_scores(val) == {} + + def test_malformed_json_raises(self): + val = ValueProto() + val.string_val = "not-json" + with pytest.raises(json.JSONDecodeError): + decode_signal_scores(val)