From 7c4e819cfe7fe6085df44140c5585684d698233b Mon Sep 17 00:00:00 2001 From: vanitabhagwat <92561664+vanitabhagwat@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:25:29 -0700 Subject: [PATCH 1/7] fix: Updated the elastic search to use vector_index defined in the feature view to identify vector fields (#348) * updated the elastic search to use vector_index defined in the feature view to identify vector fields * fix: formatting * Added logging and switched to use open source elastic search --------- Co-authored-by: vanitabhagwat --- .../elasticsearch.py | 41 ++++++++--- sdk/python/feast/repo_config.py | 4 +- .../test_elasticsearch_online_store.py | 68 +++++++++++++++++++ 3 files changed, 103 insertions(+), 10 deletions(-) create mode 100644 sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py diff --git a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py index 7e8e533281d..a7a69f3472a 100644 --- a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py @@ -26,6 +26,8 @@ to_naive_utc, ) +logger = logging.getLogger(__name__) + class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): """ @@ -93,6 +95,8 @@ def online_write_batch( ], progress: Optional[Callable[[int], Any]], ) -> None: + vector_field = _get_feature_view_vector_field_metadata(table) + vector_field_name = vector_field.name if vector_field else None insert_values = [] grouped_docs: dict[str, dict[str, Any]] = defaultdict( lambda: { @@ -115,7 +119,9 @@ def online_write_batch( doc_key = f"{encoded_entity_key}_{timestamp}" for feature_name, value in values.items(): - doc = _encode_feature_value(value) + doc = _encode_feature_value( + value, is_vector=(feature_name == vector_field_name) + ) grouped_docs[doc_key]["features"][feature_name] = doc grouped_docs[doc_key]["timestamp"] = timestamp grouped_docs[doc_key]["created_ts"] = created_ts @@ -299,8 +305,11 @@ def retrieve_online_documents( Optional[ValueProto], ] ] = [] + vector_field = _get_feature_view_vector_field_metadata(table) vector_field_path = ( - config.online_store.vector_field_path or "embedding.vector_value" + f"{vector_field.name}.vector_value" + if vector_field + else config.online_store.vector_field_path or "embedding.vector_value" ) query = { "script_score": { @@ -384,10 +393,21 @@ def retrieve_online_documents_v2( body["_source"] = source_fields if embedding: - similarity = (distance_metric or config.online_store.similarity).lower() + vector_field = _get_feature_view_vector_field_metadata(table) vector_field_path = ( - config.online_store.vector_field_path or "embedding.vector_value" + f"{vector_field.name}.vector_value" + if vector_field + else config.online_store.vector_field_path or "embedding.vector_value" ) + similarity = ( + distance_metric + or ( + vector_field.vector_search_metric + if vector_field and vector_field.vector_search_metric + else None + ) + or config.online_store.similarity + ).lower() if similarity == "cosine": script = f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0" elif similarity == "dot_product": @@ -489,16 +509,21 @@ def _to_value_proto(value: Any) -> ValueProto: return val_proto -def _encode_feature_value(value: ValueProto) -> Dict[str, Any]: +def _encode_feature_value(value: ValueProto, is_vector: bool = False) -> Dict[str, Any]: """ Encode a ValueProto into a dictionary for Elasticsearch storage. """ encoded_value = base64.b64encode(value.SerializeToString()).decode("utf-8") result = {"feature_value": encoded_value} - vector_val = get_list_val_str(value) - if vector_val: - result["vector_value"] = json.loads(vector_val) + if is_vector: + vector_val = get_list_val_str(value) + if vector_val: + result["vector_value"] = json.loads(vector_val) + else: + logger.warning( + "Feature is marked as vector but value does not contain a valid vector." + ) if value.HasField("string_val"): result["value_text"] = value.string_val elif value.HasField("bytes_val"): diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 19a5d8b3158..c5575ef2231 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -83,8 +83,8 @@ "hazelcast": "feast.infra.online_stores.hazelcast_online_store.hazelcast_online_store.HazelcastOnlineStore", "ikv": "feast.infra.online_stores.ikv_online_store.ikv.IKVOnlineStore", "eg-milvus": "feast.expediagroup.vectordb.eg_milvus_online_store.EGMilvusOnlineStore", - "elasticsearch": "feast.expediagroup.vectordb.elasticsearch_online_store.ElasticsearchOnlineStore", - # "elasticsearch": "feast.infra.online_stores.elasticsearch_online_store.elasticsearch.ElasticSearchOnlineStore", + # "elasticsearch": "feast.expediagroup.vectordb.elasticsearch_online_store.ElasticsearchOnlineStore", + "elasticsearch": "feast.infra.online_stores.elasticsearch_online_store.elasticsearch.ElasticSearchOnlineStore", "remote": "feast.infra.online_stores.remote.RemoteOnlineStore", "singlestore": "feast.infra.online_stores.singlestore_online_store.singlestore.SingleStoreOnlineStore", "qdrant": "feast.infra.online_stores.qdrant_online_store.qdrant.QdrantOnlineStore", diff --git a/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py new file mode 100644 index 00000000000..cb94205d1f3 --- /dev/null +++ b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py @@ -0,0 +1,68 @@ +import base64 + +import pytest + +from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + _encode_feature_value, +) +from feast.protos.feast.types.Value_pb2 import ( + FloatList, + Int64List, +) +from feast.protos.feast.types.Value_pb2 import ( + Value as ValueProto, +) + + +class TestEncodeFeatureValue: + def test_vector_field_includes_vector_value(self): + """When is_vector=True and value is a float list, vector_value should be present.""" + value = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3])) + result = _encode_feature_value(value, is_vector=True) + + assert "vector_value" in result + assert result["vector_value"] == pytest.approx([0.1, 0.2, 0.3]) + + def test_non_vector_list_excludes_vector_value(self): + """When is_vector=False and value is a float list, vector_value should NOT be present.""" + value = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3])) + result = _encode_feature_value(value, is_vector=False) + + assert "vector_value" not in result + + def test_non_vector_int_list_excludes_vector_value(self): + """An int64 list with is_vector=False should not produce vector_value.""" + value = ValueProto(int64_list_val=Int64List(val=[1, 2, 3])) + result = _encode_feature_value(value, is_vector=False) + + assert "vector_value" not in result + + def test_string_value_has_value_text(self): + """A string ValueProto should produce value_text, not vector_value.""" + value = ValueProto(string_val="hello") + result = _encode_feature_value(value, is_vector=False) + + assert result["value_text"] == "hello" + assert "vector_value" not in result + + def test_feature_value_always_present(self): + """feature_value (base64 binary) should always be present regardless of is_vector.""" + vector_value = ValueProto(float_list_val=FloatList(val=[1.0, 2.0])) + string_value = ValueProto(string_val="test") + int_value = ValueProto(int64_val=42) + + for val in [vector_value, string_value, int_value]: + for is_vector in [True, False]: + result = _encode_feature_value(val, is_vector=is_vector) + assert "feature_value" in result + # Verify it's valid base64 that deserializes back + decoded = base64.b64decode(result["feature_value"]) + roundtrip = ValueProto() + roundtrip.ParseFromString(decoded) + + def test_default_is_vector_false(self): + """Calling without is_vector should default to False (no vector_value).""" + value = ValueProto(float_list_val=FloatList(val=[0.1, 0.2])) + result = _encode_feature_value(value) + + assert "vector_value" not in result From c6443c1d94c48b538bdf26f0aebc447599ac5100 Mon Sep 17 00:00:00 2001 From: vanitabhagwat <92561664+vanitabhagwat@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:18:42 -0700 Subject: [PATCH 2/7] fix: ES integration tests (#350) * fix: ES integration tests * fix: Added fromisoformat() for converting timestamps --------- Co-authored-by: vanitabhagwat --- .../elasticsearch_online_store/elasticsearch.py | 11 ++++++----- .../integration/feature_repos/repo_configuration.py | 4 ++++ .../universal/online_store/elasticsearch.py | 1 - .../integration/online_store/test_universal_online.py | 7 ++++++- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py index a7a69f3472a..d134f22e054 100644 --- a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py @@ -178,7 +178,7 @@ def online_read( for hit in response["hits"]["hits"]: source = hit["_source"] timestamp = source.get("timestamp") - timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S") + timestamp = datetime.fromisoformat(timestamp) features: Dict[str, ValueProto] = {} @@ -211,8 +211,9 @@ def create_index(self, config: RepoConfig, table: FeatureView): config: Feast repo configuration object. table: FeatureView table for which the index needs to be created. """ - vector_field_length = getattr( - _get_feature_view_vector_field_metadata(table), "vector_length", 512 + vector_field_length = ( + getattr(_get_feature_view_vector_field_metadata(table), "vector_length", 0) + or 512 ) index_mapping = { @@ -331,7 +332,7 @@ def retrieve_online_documents( distance = row["_score"] timestamp_str = source.get("timestamp") - timestamp = datetime.strptime(timestamp_str, "%Y-%m-%dT%H:%M:%S.%f") + timestamp = datetime.fromisoformat(timestamp_str) for feature_name in requested_features: feature_data = source.get(feature_name, {}) @@ -461,7 +462,7 @@ def retrieve_online_documents_v2( entity_key_serialization_version=config.entity_key_serialization_version, ) timestamp = row["_source"]["timestamp"] - timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timestamp = datetime.fromisoformat(timestamp) # Create feature dict with all requested features feature_dict = {"distance": _to_value_proto(float(row["_score"]))} diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 02d3b593bc9..908f88de317 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -81,6 +81,9 @@ from tests.integration.feature_repos.universal.online_store.dynamodb import ( DynamoDBOnlineStoreCreator, ) +from tests.integration.feature_repos.universal.online_store.elasticsearch import ( + ElasticSearchOnlineStoreCreator, +) from tests.integration.feature_repos.universal.online_store.milvus import ( MilvusOnlineStoreCreator, ) @@ -155,6 +158,7 @@ str, Tuple[Union[str, Dict[Any, Any]], Optional[Type[OnlineStoreCreator]]] ] = { "sqlite": ({"type": "sqlite"}, None), + "elasticsearch": ({"type": "elasticsearch"}, ElasticSearchOnlineStoreCreator), # uncomment below once Milvus implementation is complete # "milvus": ({"type": "milvus"}, MilvusOnlineStoreCreator), } diff --git a/sdk/python/tests/integration/feature_repos/universal/online_store/elasticsearch.py b/sdk/python/tests/integration/feature_repos/universal/online_store/elasticsearch.py index 1e8088a997e..2fdc66cf6e1 100644 --- a/sdk/python/tests/integration/feature_repos/universal/online_store/elasticsearch.py +++ b/sdk/python/tests/integration/feature_repos/universal/online_store/elasticsearch.py @@ -20,7 +20,6 @@ def create_online_store(self) -> Dict[str, Any]: "host": "localhost", "type": "elasticsearch", "port": self.container.get_exposed_port(9200), - "vector_length": 2, "similarity": "cosine", } diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 3d9390eaa45..7a0756b1d6e 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -1167,7 +1167,12 @@ def test_retrieve_online_documents_v2(environment, fake_document_data): name="item_embeddings", entities=[item], schema=[ - Field(name="embedding", dtype=Array(Float32), vector_index=True), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=2, + ), Field(name="text_field", dtype=String), Field(name="category", dtype=String), Field(name="item_id", dtype=Int64), From 7af68fb24d92982d9768ef341dcb1f622ac4d7fd Mon Sep 17 00:00:00 2001 From: vanitabhagwat <92561664+vanitabhagwat@users.noreply.github.com> Date: Wed, 1 Apr 2026 10:49:10 -0700 Subject: [PATCH 3/7] =?UTF-8?q?fix:Elasticsearch=20online=20store=20?= =?UTF-8?q?=E2=80=94=20correctness,=20performance,=20and=20robustness=20fi?= =?UTF-8?q?xes=20(#353)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: vanitabhagwat --- .../elasticsearch.py | 58 ++++++++++++------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py index d134f22e054..58ff9b5f3b0 100644 --- a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py @@ -127,21 +127,23 @@ def online_write_batch( grouped_docs[doc_key]["created_ts"] = created_ts grouped_docs[doc_key]["entity_key"] = encoded_entity_key - insert_values = [ - { - "entity_key": document["entity_key"], - "timestamp": document["timestamp"], - "created_ts": document["created_ts"], - **(document["features"] or {}), - } - for document in grouped_docs.values() - ] + insert_values = [ + { + "entity_key": document["entity_key"], + "timestamp": document["timestamp"], + "created_ts": document["created_ts"], + **(document["features"] or {}), + } + for document in grouped_docs.values() + ] batch_size = config.online_store.write_batch_size for i in range(0, len(insert_values), batch_size): batch = insert_values[i : i + batch_size] actions = self._bulk_batch_actions(table, batch) helpers.bulk(self._get_client(config), actions, refresh="wait_for") + if progress: + progress(len(batch)) def online_read( self, @@ -165,6 +167,7 @@ def online_read( includes.append("*") body = { + "size": len(encoded_entity_keys), "_source": {"includes": includes, "excludes": ["*.vector_value"]}, "query": { "bool": {"filter": [{"terms": {"entity_key": encoded_entity_keys}}]} @@ -173,10 +176,14 @@ def online_read( response = self._get_client(config).search(index=table.name, body=body) - results = [] + # Build a lookup dict keyed by entity_key to preserve input order + entity_key_to_result: Dict[ + str, Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]] + ] = {} for hit in response["hits"]["hits"]: source = hit["_source"] + entity_key_val = source.get("entity_key") timestamp = source.get("timestamp") timestamp = datetime.fromisoformat(timestamp) @@ -199,7 +206,15 @@ def online_read( f"Failed to parse feature '{feature_name}' from hit: {e}" ) - results.append((timestamp, features if features else None)) + entity_key_to_result[entity_key_val] = ( + timestamp, + features if features else None, + ) + + # Return results in the same order as input entity_keys + results: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] + for encoded_key in encoded_entity_keys: + results.append(entity_key_to_result.get(encoded_key, (None, None))) return results @@ -261,9 +276,11 @@ def update( ): # implement the update method for table in tables_to_delete: - self._get_client(config).delete_by_query(index=table.name) + if self._get_client(config).indices.exists(index=table.name): + self._get_client(config).delete_by_query(index=table.name) for table in tables_to_keep: - self.create_index(config, table) + if not self._get_client(config).indices.exists(index=table.name): + self.create_index(config, table) def teardown( self, @@ -274,7 +291,8 @@ def teardown( project = config.project try: for table in tables: - self._get_client(config).indices.delete(index=table.name) + if self._get_client(config).indices.exists(index=table.name): + self._get_client(config).indices.delete(index=table.name) except Exception as e: logging.exception(f"Error deleting index in project {project}: {e}") raise @@ -376,12 +394,12 @@ def retrieve_online_documents_v2( Optional[Dict[str, ValueProto]], ] ] = [] - if not config.online_store.vector_enabled: - raise ValueError("Vector search is not enabled in the online store config") - if embedding is None and query_string is None: raise ValueError("Either embedding or query_string must be provided") + if embedding is not None and not config.online_store.vector_enabled: + raise ValueError("Vector search is not enabled in the online store config") + es_index = table.name body: Dict[str, Any] = { "size": top_k, @@ -489,14 +507,14 @@ def _to_value_proto(value: Any) -> ValueProto: val_proto = ValueProto() if isinstance(value, ValueProto): return value - if isinstance(value, float): + if isinstance(value, bool): + val_proto.bool_val = value + elif isinstance(value, float): val_proto.float_val = value elif isinstance(value, str): val_proto.string_val = value elif isinstance(value, int): val_proto.int64_val = value - elif isinstance(value, bool): - val_proto.bool_val = value elif isinstance(value, list) and all(isinstance(v, float) for v in value): val_proto.float_list_val.val.extend(value) elif isinstance(value, dict) and "feature_value" in value: From 7de6f9ab0cf5f9bd9e5c04b5af7ceafc079f06c4 Mon Sep 17 00:00:00 2001 From: Manisha Sudhir <30449541+Manisha4@users.noreply.github.com> Date: Thu, 9 Apr 2026 16:21:32 -0700 Subject: [PATCH 4/7] Feature/vector store (#357) * feat: Valkey Online Write Batch Vector Search Support (#351) * Adding support for Valkey Search, adding changes to the online_write_batch functionality * Addressing PR comments * addressing linting error * fix tests * addressing PR comments * addressing PR comments * fixing linting --------- Co-authored-by: Manisha4 * feat: Support Vector Search in Valkey (#354) * Adding support for Valkey Search, adding changes to the online_write_batch functionality * Addressing PR comments * addressing linting error * Adding changes to support search in valkey * fix tests * adding unit tests * reformatting files and adding checks and more tests * reformatting files and adding checks and more tests * reformatting files and adding checks and more tests * Fix linter errors: type annotations and code formatting - Add explicit type annotation for schema_fields to support both TagField and VectorField - Encode project string to bytes for consistency with other hash values - Decode doc_key bytes to string for hmget compatibility - Fix code formatting: break long lines and remove extra blank lines - Remove tests for multiple vector fields (Feast enforces one vector per feature view) - Fix config type: use 'eg-valkey' (hyphen) not 'eg_valkey' (underscore) Co-Authored-By: Claude Opus 4.6 * addressing PR comments * addressing PR comments * fixing linting * Fix missing feature_name argument in retrieve_online_documents_v2 Add the third argument (vector_field.name) to _get_vector_index_name call to match the updated function signature. Co-Authored-By: Claude Opus 4.5 * addressing comments, PR changes for some fixes and merge conflicts * fixing tests * fixing tests * fixing linting * fixing linting --------- Co-authored-by: Manisha4 Co-authored-by: Claude Opus 4.6 * fix: Valkey vector search - remove unsupported SORTBY (#356) * fix: Valkey vector search - remove unsupported SORTBY and fix tag filter syntax Valkey Search KNN queries return results pre-sorted by distance, so explicit SORTBY is not supported and causes a ResponseError. This removes the .sort_by() call from the query builder. Additionally, fixes the project tag filter to use unquoted syntax with backslash escaping for special characters (e.g. hyphens, dots) instead of the quoted syntax which was returning empty results. Updates unit tests to reflect both changes: replaces three metric-specific sort order tests with a single test asserting no SORTBY is set, and updates escaping assertions to match the new backslash-escape approach. Co-Authored-By: Claude Opus 4.6 * style: apply ruff format to eg_valkey.py and test_valkey.py Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Manisha4 Co-authored-by: Claude Opus 4.6 --------- Co-authored-by: Manisha4 Co-authored-by: Claude Opus 4.6 --- .../feast/infra/online_stores/eg_valkey.py | 664 ++++++++- .../unit/infra/online_store/test_valkey.py | 1281 ++++++++++++++++- 2 files changed, 1917 insertions(+), 28 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/eg_valkey.py b/sdk/python/feast/infra/online_stores/eg_valkey.py index 27d45b18935..3ef14fc89a0 100644 --- a/sdk/python/feast/infra/online_stores/eg_valkey.py +++ b/sdk/python/feast/infra/online_stores/eg_valkey.py @@ -31,24 +31,36 @@ Union, ) +import numpy as np from google.protobuf.timestamp_pb2 import Timestamp from pydantic import StrictStr -from valkey.exceptions import ValkeyError +from valkey.exceptions import ResponseError, ValkeyError from feast import Entity, FeatureView, RepoConfig, utils -from feast.infra.key_encoding_utils import serialize_entity_key +from feast.field import Field +from feast.infra.key_encoding_utils import ( + deserialize_entity_key, + serialize_entity_key, +) from feast.infra.online_stores.helpers import _mmh3, _redis_key, _redis_key_prefix from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.protos.feast.types.Value_pb2 import FloatList +from feast.protos.feast.types.Value_pb2 import ( + Value as ValueProto, +) from feast.repo_config import FeastConfigBaseModel from feast.sorted_feature_view import SortedFeatureView +from feast.types import Array, Float64 from feast.value_type import ValueType try: from valkey import Valkey from valkey import asyncio as valkey_asyncio from valkey.cluster import ClusterNode, ValkeyCluster + from valkey.commands.search.field import TagField, VectorField + from valkey.commands.search.indexDefinition import IndexDefinition, IndexType + from valkey.commands.search.query import Query from valkey.sentinel import Sentinel except ImportError as e: from feast.errors import FeastExtrasDependencyImportError @@ -58,6 +70,91 @@ logger = logging.getLogger(__name__) +def _get_vector_index_name( + project: str, feature_view_name: str, feature_name: str +) -> str: + """Generate Valkey Search index name for a vector field.""" + return f"{project}_{feature_view_name}_{feature_name}_vidx" + + +def _get_valkey_vector_type(feast_dtype) -> str: + """ + Map Feast dtype to Valkey vector TYPE parameter. + + Valkey Search only supports FLOAT32 vectors. Float64 arrays will be + converted to float32 during serialization. + + Args: + feast_dtype: Feast data type (e.g., Array(Float32)) + + Returns: + Valkey vector type string: always "FLOAT32" + """ + if feast_dtype == Array(Float64): + logger.warning( + "Valkey Search only supports FLOAT32 vectors. " + "Float64 data will be converted to float32 (possible precision loss)." + ) + return "FLOAT32" + + +def _serialize_vector_to_bytes(val: ValueProto, field: Field) -> bytes: + """ + Serialize a vector ValueProto to raw float32 bytes for Valkey storage. + + Vector fields must be stored as raw bytes (not protobuf serialized) to be + compatible with Valkey Search FT.SEARCH queries. Valkey only supports + FLOAT32, so float64 data is converted to float32. + + Args: + val: The ValueProto containing the vector data + field: The Field metadata for dtype and dimension information + + Returns: + Raw float32 bytes in the format expected by Valkey vector search + + Raises: + ValueError: If vector type is unsupported or dimension mismatches + """ + if val.HasField("float_list_val"): + vector = np.array(val.float_list_val.val, dtype=np.float32) + elif val.HasField("double_list_val"): + # Convert float64 to float32 (Valkey only supports float32) + vector = np.array(val.double_list_val.val, dtype=np.float32) + else: + raise ValueError( + f"Unsupported vector type for field {field.name}. " + f"Expected float_list_val or double_list_val." + ) + + # Validate dimension matches expected + if field.vector_length > 0 and len(vector) != field.vector_length: + raise ValueError( + f"Vector dimension mismatch for field {field.name}: " + f"expected {field.vector_length}, got {len(vector)}" + ) + + return vector.tobytes() + + +def _deserialize_vector_from_bytes(raw_bytes: bytes, field: Field) -> ValueProto: + """ + Deserialize raw vector bytes back to ValueProto. + + Valkey stores all vectors as float32, so we always deserialize as float32 + regardless of the original field dtype. + + Args: + raw_bytes: Raw float32 bytes from Valkey + field: Field metadata (unused, kept for API consistency) + + Returns: + ValueProto with float_list_val (always float32) + """ + vector = np.frombuffer(raw_bytes, dtype=np.float32) + return ValueProto(float_list_val=FloatList(val=vector.tolist())) + + class EGValkeyType(str, Enum): valkey = "valkey" valkey_cluster = "valkey_cluster" @@ -100,6 +197,19 @@ class EGValkeyOnlineStoreConfig(FeastConfigBaseModel): max_pipeline_commands: Optional[int] = 500 """(Optional) The maximum number of Valkey commands to queue in a pipeline before sending them to Valkey in a single batch.""" + # Vector search configuration + vector_index_algorithm: Literal["FLAT", "HNSW"] = "HNSW" + """Algorithm for vector indexing. FLAT for exact search (<100K vectors), HNSW for approximate search (large datasets).""" + + vector_index_hnsw_m: Optional[int] = 16 + """HNSW: Max number of outgoing edges per node.""" + + vector_index_hnsw_ef_construction: Optional[int] = 200 + """HNSW: Size of dynamic candidate list during index construction.""" + + vector_index_hnsw_ef_runtime: Optional[int] = 10 + """HNSW: Size of dynamic candidate list during search.""" + class EGValkeyOnlineStore(OnlineStore): """ @@ -144,7 +254,12 @@ def delete_table(self, config: RepoConfig, table: FeatureView): deleted_count = 0 prefix = _redis_key_prefix(table.join_keys) - valkey_hash_keys = [_mmh3(f"{table.name}:{f.name}") for f in table.features] + # Build list of hash keys to delete + # Vector fields use original name, non-vector fields use mmh3 hash + valkey_hash_keys = [ + f.name.encode("utf8") if f.vector_index else _mmh3(f"{table.name}:{f.name}") + for f in table.features + ] valkey_hash_keys.append(bytes(f"_ts:{table.name}", "utf8")) with client.pipeline(transaction=False) as pipe: @@ -165,6 +280,33 @@ def delete_table(self, config: RepoConfig, table: FeatureView): logger.debug(f"Deleted {deleted_count} rows for feature view {table.name}") + # Drop vector index if it exists + self._drop_vector_index_if_exists(client, config.project, table) + + def _drop_vector_index_if_exists( + self, + client: Union[Valkey, ValkeyCluster], + project: str, + table: FeatureView, + ) -> None: + """Drop Valkey Search vector indexes for all vector fields in a feature view.""" + vector_fields = [f for f in table.features if f.vector_index] + + # Drop index for each vector field + for field in vector_fields: + index_name = _get_vector_index_name(project, table.name, field.name) + try: + client.ft(index_name).dropindex(delete_documents=False) + logger.info(f"Dropped vector index {index_name}") + except ResponseError as e: + # Index doesn't exist - this is fine + if "unknown index" in str(e).lower(): + logger.debug( + f"Vector index {index_name} does not exist, skipping drop" + ) + else: + raise + def update( self, config: RepoConfig, @@ -202,8 +344,14 @@ def teardown( """ We delete the keys in valkey for tables/views being removed. """ - join_keys_to_delete = set(tuple(table.join_keys) for table in tables) + client = self._get_client(config.online_store) + # Drop vector indexes for each table + for table in tables: + self._drop_vector_index_if_exists(client, config.project, table) + + # Delete entity values + join_keys_to_delete = set(tuple(table.join_keys) for table in tables) for join_keys in join_keys_to_delete: self.delete_entity_values(config, list(join_keys)) @@ -289,6 +437,96 @@ async def _get_client_async(self, online_store_config: EGValkeyOnlineStoreConfig self._client_async = valkey_asyncio.Valkey(**kwargs) return self._client_async + def _create_vector_index_if_not_exists( + self, + client: Union[Valkey, ValkeyCluster], + config: RepoConfig, + table: FeatureView, + vector_fields: Dict[str, Field], + ) -> None: + """ + Create Valkey Search index for each vector field if not already exists. + + Uses FT.CREATE with VECTOR field type and appropriate algorithm parameters. + Creates one index per vector field for future multi-vector support. + + Args: + client: Valkey client + config: Feast repo configuration + table: Feature view with vector fields + vector_fields: Dictionary of vector field name to Field object + """ + online_store_config = config.online_store + assert isinstance(online_store_config, EGValkeyOnlineStoreConfig) + + # Define index on HASH keys with specific prefix (shared across all indexes) + key_prefix = _redis_key_prefix(table.join_keys) + definition = IndexDefinition( + prefix=[key_prefix], + index_type=IndexType.HASH, + ) + + # Create one index per vector field + for field_name, field in vector_fields.items(): + index_name = _get_vector_index_name(config.project, table.name, field_name) + + # Check if index exists + try: + client.ft(index_name).info() + logger.debug(f"Vector index {index_name} already exists") + continue + except ResponseError: + pass # Index doesn't exist, create it + + # Validate required properties + if field.vector_length <= 0: + raise ValueError( + f"Field {field_name} has vector_index=True but vector_length is not set. " + f"vector_length must be > 0 for vector indexing." + ) + + # Determine vector type from Feast dtype + vector_type = _get_valkey_vector_type(field.dtype) + + # Build algorithm attributes + attributes = { + "TYPE": vector_type, # Always FLOAT32 (Valkey limitation) + "DIM": field.vector_length, + "DISTANCE_METRIC": field.vector_search_metric or "COSINE", + } + + # Add algorithm-specific parameters + algorithm = online_store_config.vector_index_algorithm + if algorithm == "HNSW": + attributes["M"] = online_store_config.vector_index_hnsw_m + attributes["EF_CONSTRUCTION"] = ( + online_store_config.vector_index_hnsw_ef_construction + ) + attributes["EF_RUNTIME"] = ( + online_store_config.vector_index_hnsw_ef_runtime + ) + + # Create the index with vector field and project tag for filtering + # __project__ TAG field enables filtering by project in hybrid queries + try: + client.ft(index_name).create_index( + fields=[ + VectorField(field_name, algorithm, attributes), + TagField("__project__"), + ], + definition=definition, + ) + logger.info(f"Created vector index {index_name} for field {field_name}") + except ResponseError as e: + if "already exists" in str(e).lower(): + logger.debug(f"Vector index {index_name} already exists") + continue + logger.error( + f"Failed to create vector index {index_name}: {e}. " + f"Ensure Valkey Search module is loaded." + ) + raise + def online_write_batch( self, config: RepoConfig, @@ -307,6 +545,7 @@ def online_write_batch( feature_view = table.name ts_key = f"_ts:{feature_view}" keys = [] + # Track all ZSET keys touched in this batch for TTL cleanup & trimming zsets_to_cleanup: set[Tuple[bytes, bytes]] = ( set() @@ -448,6 +687,15 @@ def online_write_batch( ) raise else: + # Identify vector fields (only for regular FeatureViews, not SortedFeatureView) + vector_fields = {f.name: f for f in table.features if f.vector_index} + + # Create vector index if needed (only on first write with vector fields) + if vector_fields: + self._create_vector_index_if_not_exists( + client, config, table, vector_fields + ) + # check if a previous record under the key bin exists # TODO: investigate if check and set is a better approach rather than pulling all entity ts and then setting # it may be significantly slower but avoids potential (rare) race conditions @@ -463,9 +711,12 @@ def online_write_batch( # flattening the list of lists. `hmget` does the lookup assuming a list of keys in the key bin prev_event_timestamps = [i[0] for i in prev_event_timestamps] - for valkey_key_bin, prev_event_time, (_, values, timestamp, _) in zip( - keys, prev_event_timestamps, data - ): + for valkey_key_bin, prev_event_time, ( + entity_key, + values, + timestamp, + _, + ) in zip(keys, prev_event_timestamps, data): event_time_seconds = int(utils.make_tzaware(timestamp).timestamp()) # ignore if event_timestamp is before the event features that are currently in the feature store @@ -482,10 +733,24 @@ def online_write_batch( ts.seconds = event_time_seconds entity_hset = dict() entity_hset[ts_key] = ts.SerializeToString() + # Store project and entity key for vector search + entity_hset["__project__"] = project.encode() + entity_hset["__entity_key__"] = serialize_entity_key( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ) for feature_name, val in values.items(): - f_key = _mmh3(f"{feature_view}:{feature_name}") - entity_hset[f_key] = val.SerializeToString() + if feature_name in vector_fields: + # Vector field: store with ORIGINAL name and RAW bytes + vector_bytes = _serialize_vector_to_bytes( + val, vector_fields[feature_name] + ) + entity_hset[feature_name] = vector_bytes + else: + # Non-vector field: store with mmh3 hash and protobuf serialization + f_key = _mmh3(f"{feature_view}:{feature_name}") + entity_hset[f_key] = val.SerializeToString() pipe.hset(valkey_key_bin, mapping=entity_hset) @@ -580,28 +845,53 @@ def _generate_hset_keys_for_features( self, feature_view: FeatureView, requested_features: Optional[List[str]] = None, - ) -> Tuple[List[str], List[str]]: + ) -> Tuple[List[str], List[str], Dict[str, Field]]: + """ + Generate HSET keys for feature retrieval. + + Returns: + Tuple of (feature_names, hset_keys, vector_fields dict) + """ if not requested_features: requested_features = [f.name for f in feature_view.features] - hset_keys = [_mmh3(f"{feature_view.name}:{k}") for k in requested_features] + vector_fields = {f.name: f for f in feature_view.features if f.vector_index} + + hset_keys = [] + for feature_name in requested_features: + if feature_name in vector_fields: + # Vector field: use original name + hset_keys.append(feature_name) + else: + # Non-vector: use mmh3 hash + hset_keys.append(_mmh3(f"{feature_view.name}:{feature_name}")) ts_key = f"_ts:{feature_view.name}" hset_keys.append(ts_key) - requested_features.append(ts_key) + requested_features = list(requested_features) + [ts_key] - return requested_features, hset_keys + return requested_features, hset_keys, vector_fields def _convert_valkey_values_to_protobuf( self, valkey_values: List[List[ByteString]], - feature_view: str, + feature_view: FeatureView, requested_features: List[str], + vector_fields: Dict[str, Field], ): + """ + Convert Valkey values back to protobuf, handling vector fields. + + Args: + valkey_values: Raw values from Valkey + feature_view: Feature view object (not just name) + requested_features: List of feature names + vector_fields: Dict of field name to Field for vector fields + """ result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] for values in valkey_values: features = self._get_features_for_entity( - values, feature_view, requested_features + values, feature_view, requested_features, vector_fields ) result.append(features) return result @@ -619,8 +909,8 @@ def online_read( client = self._get_client(online_store_config) feature_view = table - requested_features, hset_keys = self._generate_hset_keys_for_features( - feature_view, requested_features + requested_features, hset_keys, vector_fields = ( + self._generate_hset_keys_for_features(feature_view, requested_features) ) keys = self._generate_valkey_keys_for_entities(config, entity_keys) @@ -631,7 +921,7 @@ def online_read( valkey_values = pipe.execute() return self._convert_valkey_values_to_protobuf( - valkey_values, feature_view.name, requested_features + valkey_values, feature_view, requested_features, vector_fields ) async def online_read_async( @@ -647,8 +937,8 @@ async def online_read_async( client = await self._get_client_async(online_store_config) feature_view = table - requested_features, hset_keys = self._generate_hset_keys_for_features( - feature_view, requested_features + requested_features, hset_keys, vector_fields = ( + self._generate_hset_keys_for_features(feature_view, requested_features) ) keys = self._generate_valkey_keys_for_entities(config, entity_keys) @@ -658,27 +948,47 @@ async def online_read_async( valkey_values = await pipe.execute() return self._convert_valkey_values_to_protobuf( - valkey_values, feature_view.name, requested_features + valkey_values, feature_view, requested_features, vector_fields ) def _get_features_for_entity( self, values: List[ByteString], - feature_view: str, + feature_view: FeatureView, requested_features: List[str], + vector_fields: Dict[str, Field], ) -> Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]: + """ + Parse features for a single entity, handling vector deserialization. + + Args: + values: Raw bytes from Valkey + feature_view: Feature view object + requested_features: List of feature names (includes _ts key) + vector_fields: Dict of field name to Field for vector fields (O(1) lookup) + """ res_val = dict(zip(requested_features, values)) res_ts = Timestamp() - ts_val = res_val.pop(f"_ts:{feature_view}") + ts_val = res_val.pop(f"_ts:{feature_view.name}") if ts_val: res_ts.ParseFromString(bytes(ts_val)) res = {} for feature_name, val_bin in res_val.items(): - val = ValueProto() - if val_bin: + if not val_bin: + res[feature_name] = ValueProto() + continue + + if feature_name in vector_fields: + # Vector field: deserialize from raw bytes + field = vector_fields[feature_name] + val = _deserialize_vector_from_bytes(bytes(val_bin), field) + else: + # Regular field: parse protobuf + val = ValueProto() val.ParseFromString(bytes(val_bin)) + res[feature_name] = val if not res: @@ -686,3 +996,305 @@ def _get_features_for_entity( else: timestamp = datetime.fromtimestamp(res_ts.seconds, tz=timezone.utc) return timestamp, res + + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embedding: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Retrieve documents using vector similarity search from Valkey. + + Args: + config: Feast configuration object + table: FeatureView to search + requested_features: List of feature names to return + embedding: Query embedding vector + top_k: Number of results to return + distance_metric: Optional override for distance metric (COSINE, L2, IP) + query_string: Not supported in V1 (reserved for future BM25 search) + + Returns: + List of tuples containing (timestamp, entity_key, features_dict) + """ + if embedding is None: + raise ValueError("embedding must be provided for vector search") + + if query_string is not None: + raise NotImplementedError( + "Keyword search (query_string) is not yet supported for Valkey. " + "Only vector similarity search is available." + ) + + online_store_config = config.online_store + assert isinstance(online_store_config, EGValkeyOnlineStoreConfig) + + client = self._get_client(online_store_config) + project = config.project + + # Find the vector field to search against + vector_field = self._get_vector_field_for_search(table, requested_features) + if vector_field is None: + raise ValueError( + f"No vector field found in FeatureView {table.name}. " + "Ensure the FeatureView has a field with vector_index=True." + ) + + # Determine distance metric + metric = distance_metric or vector_field.vector_search_metric or "COSINE" + + # Serialize query embedding to bytes + embedding_bytes = self._serialize_embedding_for_search(embedding, vector_field) + + # Build and execute FT.SEARCH query + index_name = _get_vector_index_name(project, table.name, vector_field.name) + search_results = self._execute_vector_search( + client=client, + index_name=index_name, + project=project, + vector_field_name=vector_field.name, + embedding_bytes=embedding_bytes, + top_k=top_k, + metric=metric, + ) + + if not search_results: + return [] + + # Fetch features for each result using pipeline HMGET + return self._fetch_features_for_search_results( + client=client, + config=config, + table=table, + requested_features=requested_features, + search_results=search_results, + ) + + def _get_vector_field_for_search( + self, + table: FeatureView, + requested_features: Optional[List[str]], + ) -> Optional[Field]: + """Find the vector field to use for search.""" + vector_fields = [f for f in table.features if f.vector_index] + + if not vector_fields: + return None + + # If requested_features specified, prefer a vector field from that list + if requested_features: + # Convert to set for O(1) lookup instead of O(n) list search + requested_set = set(requested_features) + for f in vector_fields: + if f.name in requested_set: + return f + + # Default to first vector field + return vector_fields[0] + + def _serialize_embedding_for_search( + self, + embedding: List[float], + vector_field: Field, + ) -> bytes: + """Serialize query embedding to bytes matching the field's dtype.""" + # Validate embedding dimension matches field configuration + if len(embedding) != vector_field.vector_length: + raise ValueError( + f"Embedding dimension {len(embedding)} does not match " + f"vector field '{vector_field.name}' dimension {vector_field.vector_length}" + ) + + if vector_field.dtype == Array(Float64): + return np.array(embedding, dtype=np.float64).tobytes() + else: + # Default to float32 + return np.array(embedding, dtype=np.float32).tobytes() + + def _execute_vector_search( + self, + client: Union[Valkey, ValkeyCluster], + index_name: str, + project: str, + vector_field_name: str, + embedding_bytes: bytes, + top_k: int, + metric: str, + ) -> List[Tuple[bytes, float]]: + """ + Execute FT.SEARCH with KNN query. + + Returns: + List of (doc_key, distance) tuples + """ + # Escape special characters in project name for tag filter. + # In Valkey Search tag queries, characters like - . @ need backslash escaping. + escaped_project = project + for ch in r'\-.@+~<>{}[]^":|!*()': + escaped_project = escaped_project.replace(ch, f"\\{ch}") + + query_str = ( + f"(@__project__:{{{escaped_project}}})" + f"=>[KNN {top_k} @{vector_field_name} $vec AS __distance__]" + ) + + # KNN results are already sorted by distance (ascending) by the engine. + # No explicit SORTBY is needed — Valkey Search does not support SORTBY + # with KNN queries. + query = ( + Query(query_str).return_fields("__distance__").paging(0, top_k).dialect(2) + ) + + try: + results = client.ft(index_name).search( + query, + query_params={"vec": embedding_bytes}, + ) + except ResponseError as e: + if "no such index" in str(e).lower(): + raise ValueError( + f"Vector index '{index_name}' does not exist. " + "Ensure data has been materialized with 'feast materialize'." + ) + raise + + # Parse results: extract doc keys and distances + search_results = [] + for doc in results.docs: + doc_key = doc.id.encode() if isinstance(doc.id, str) else doc.id + # Default to inf (worst distance) if __distance__ is missing + # 0.0 would incorrectly indicate a perfect match + distance = float(getattr(doc, "__distance__", float("inf"))) + search_results.append((doc_key, distance)) + + return search_results + + def _fetch_features_for_search_results( + self, + client: Union[Valkey, ValkeyCluster], + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + search_results: List[Tuple[bytes, float]], + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Fetch features for search results using pipeline HMGET. + + This is the second step of two-step retrieval: + 1. FT.SEARCH returns doc keys and distances + 2. HMGET fetches the actual feature values + """ + # Pre-compute mappings once (avoid repeated dict/hash operations in loops) + vector_fields_dict = {f.name: f for f in table.features if f.vector_index} + + # Build feature_name -> hset_key mapping and hset_keys list in single pass + feature_to_hset_key: Dict[str, Any] = {} + hset_keys = [] + for feature_name in requested_features: + if feature_name in vector_fields_dict: + hset_key = feature_name + else: + hset_key = _mmh3(f"{table.name}:{feature_name}") + feature_to_hset_key[feature_name] = hset_key + hset_keys.append(hset_key) + + # Add timestamp and entity key + ts_key = f"_ts:{table.name}" + hset_keys.append(ts_key) + hset_keys.append("__entity_key__") + + # Extract doc_keys and distances in single pass + doc_keys = [] + distances = {} + for doc_key, dist in search_results: + doc_keys.append(doc_key) + distances[doc_key] = dist + + # Pipeline HMGET for all results (single round-trip to Valkey) + with client.pipeline(transaction=False) as pipe: + for doc_key in doc_keys: + key_str = doc_key.decode() if isinstance(doc_key, bytes) else doc_key + pipe.hmget(key_str, hset_keys) + fetched_values = pipe.execute() + + # Pre-fetch serialization version once + entity_key_serialization_version = config.entity_key_serialization_version + + # Build result list + results: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + + for doc_key, values in zip(doc_keys, fetched_values): + # Parse values into dict + val_dict = dict(zip(hset_keys, values)) + + # Parse timestamp + timestamp = None + ts_val = val_dict.get(ts_key) + if ts_val: + ts_proto = Timestamp() + ts_proto.ParseFromString(bytes(ts_val)) + timestamp = datetime.fromtimestamp(ts_proto.seconds, tz=timezone.utc) + + # Parse entity key + entity_key_proto = None + entity_key_bytes = val_dict.get("__entity_key__") + if entity_key_bytes: + entity_key_proto = deserialize_entity_key( + bytes(entity_key_bytes), + entity_key_serialization_version=entity_key_serialization_version, + ) + + # Build feature dict with pre-allocated capacity hint + feature_dict: Dict[str, ValueProto] = {} + + # Add distance as a feature + distance_proto = ValueProto() + distance_proto.double_val = distances[doc_key] + feature_dict["distance"] = distance_proto + + # Parse requested features using pre-computed mappings + for feature_name in requested_features: + hset_key = feature_to_hset_key[feature_name] + val_bin = val_dict.get(hset_key) + + if not val_bin: + feature_dict[feature_name] = ValueProto() + continue + + if feature_name in vector_fields_dict: + # Vector field: deserialize from raw bytes + feature_dict[feature_name] = _deserialize_vector_from_bytes( + bytes(val_bin), vector_fields_dict[feature_name] + ) + else: + # Regular field: parse protobuf + val = ValueProto() + val.ParseFromString(bytes(val_bin)) + feature_dict[feature_name] = val + + results.append((timestamp, entity_key_proto, feature_dict)) + + return results diff --git a/sdk/python/tests/unit/infra/online_store/test_valkey.py b/sdk/python/tests/unit/infra/online_store/test_valkey.py index 02e9cb0cbdb..f172838a846 100644 --- a/sdk/python/tests/unit/infra/online_store/test_valkey.py +++ b/sdk/python/tests/unit/infra/online_store/test_valkey.py @@ -1,22 +1,36 @@ import time from datetime import datetime, timedelta, timezone +import numpy as np import pytest from valkey import Valkey -from feast import Entity, Field, FileSource, RepoConfig, ValueType +from feast import Entity, FeatureView, Field, FileSource, RepoConfig, ValueType from feast.infra.online_stores.eg_valkey import ( EGValkeyOnlineStore, EGValkeyOnlineStoreConfig, + _deserialize_vector_from_bytes, + _get_valkey_vector_type, + _get_vector_index_name, + _serialize_vector_to_bytes, ) from feast.infra.online_stores.helpers import _mmh3, _redis_key from feast.protos.feast.core.SortedFeatureView_pb2 import SortOrder from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.protos.feast.types.Value_pb2 import ( + DoubleList, + FloatList, +) +from feast.protos.feast.types.Value_pb2 import ( + Value as ValueProto, +) from feast.sorted_feature_view import SortedFeatureView, SortKey from feast.types import ( + Array, Float32, + Float64, Int32, + Int64, String, UnixTimestamp, ) @@ -455,3 +469,1266 @@ def test_ttl_cleanup_no_expired_members(repo_config): remaining = redis_client.zrange(zset_key, 0, -1) assert active_member in remaining + + +class TestVectorIndexName: + """Tests for _get_vector_index_name helper function.""" + + def test_get_vector_index_name(self): + """Test index name generation follows expected format.""" + assert ( + _get_vector_index_name("my_project", "item_embeddings", "embedding") + == "my_project_item_embeddings_embedding_vidx" + ) + + def test_get_vector_index_name_with_special_chars(self): + """Test index name with underscores in names.""" + assert ( + _get_vector_index_name("prod_project", "user_item_embeddings", "vec_field") + == "prod_project_user_item_embeddings_vec_field_vidx" + ) + + +class TestGetValkeyVectorType: + """Tests for _get_valkey_vector_type helper function.""" + + def test_get_valkey_vector_type_float32(self): + """Test Float32 array maps to FLOAT32.""" + assert _get_valkey_vector_type(Array(Float32)) == "FLOAT32" + + def test_get_valkey_vector_type_float64_converts_to_float32(self): + """Test Float64 array also maps to FLOAT32 (Valkey only supports float32).""" + assert _get_valkey_vector_type(Array(Float64)) == "FLOAT32" + + def test_get_valkey_vector_type_unsupported_defaults_to_float32(self): + """Test unsupported types default to FLOAT32.""" + # Int32 array is not a valid vector type, should default to FLOAT32 + assert _get_valkey_vector_type(Array(Int32)) == "FLOAT32" + + +class TestSerializeVectorToBytes: + """Tests for _serialize_vector_to_bytes helper function.""" + + def test_serialize_vector_float32(self): + """Test Float32 vector serialization to raw bytes.""" + val = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3, 0.4])) + field = Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ) + + result = _serialize_vector_to_bytes(val, field) + + expected = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32).tobytes() + assert result == expected + + def test_serialize_vector_float64_converts_to_float32(self): + """Test Float64 vector is converted to float32 bytes (Valkey limitation).""" + val = ValueProto(double_list_val=DoubleList(val=[0.1, 0.2, 0.3, 0.4])) + field = Field( + name="embedding", + dtype=Array(Float64), + vector_index=True, + vector_length=4, + ) + + result = _serialize_vector_to_bytes(val, field) + + # Should be float32 bytes, not float64 + expected = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32).tobytes() + assert result == expected + + def test_serialize_vector_dimension_mismatch(self): + """Test error when vector dimension doesn't match expected length.""" + val = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3])) + field = Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=128, # Expected 128, but vector has 3 elements + ) + + with pytest.raises(ValueError, match="dimension mismatch"): + _serialize_vector_to_bytes(val, field) + + def test_serialize_vector_unsupported_type(self): + """Test error when vector type is not float_list or double_list.""" + val = ValueProto(int32_val=123) # Not a list type + field = Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ) + + with pytest.raises(ValueError, match="Unsupported vector type"): + _serialize_vector_to_bytes(val, field) + + def test_serialize_vector_no_length_validation_when_zero(self): + """Test that vector_length=0 skips dimension validation.""" + val = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3])) + field = Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=0, # No validation + ) + + # Should not raise + result = _serialize_vector_to_bytes(val, field) + assert len(result) == 3 * 4 # 3 floats * 4 bytes each + + +class TestDeserializeVectorFromBytes: + """Tests for _deserialize_vector_from_bytes helper function.""" + + def test_deserialize_vector_float32(self): + """Test Float32 vector deserialization from raw bytes.""" + original = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) + raw_bytes = original.tobytes() + field = Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ) + + result = _deserialize_vector_from_bytes(raw_bytes, field) + + assert result.HasField("float_list_val") + np.testing.assert_array_almost_equal( + result.float_list_val.val, original, decimal=5 + ) + + def test_deserialize_always_returns_float32(self): + """Test deserialization always returns float32 (Valkey only supports float32).""" + original = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) + raw_bytes = original.tobytes() + # Even with Float64 field dtype, result should be float32 + field = Field( + name="embedding", + dtype=Array(Float64), + vector_index=True, + vector_length=4, + ) + + result = _deserialize_vector_from_bytes(raw_bytes, field) + + # Should always return float_list_val regardless of field dtype + assert result.HasField("float_list_val") + np.testing.assert_array_almost_equal( + result.float_list_val.val, original, decimal=5 + ) + + def test_roundtrip_float32(self): + """Test serialize then deserialize preserves Float32 vector values.""" + original_values = [0.123, 0.456, 0.789, 1.0] + val = ValueProto(float_list_val=FloatList(val=original_values)) + field = Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ) + + raw_bytes = _serialize_vector_to_bytes(val, field) + result = _deserialize_vector_from_bytes(raw_bytes, field) + + np.testing.assert_array_almost_equal( + result.float_list_val.val, original_values, decimal=5 + ) + + def test_roundtrip_float64_converts_to_float32(self): + """Test Float64 input is converted to float32 during roundtrip.""" + original_values = [0.123456789, 0.987654321, 0.111111111, 0.999999999] + val = ValueProto(double_list_val=DoubleList(val=original_values)) + field = Field( + name="embedding", + dtype=Array(Float64), + vector_index=True, + vector_length=4, + ) + + raw_bytes = _serialize_vector_to_bytes(val, field) + result = _deserialize_vector_from_bytes(raw_bytes, field) + + # Result is float32, so we get float_list_val with reduced precision + assert result.HasField("float_list_val") + np.testing.assert_array_almost_equal( + result.float_list_val.val, original_values, decimal=5 + ) + + +class TestVectorConfigOptions: + """Tests for vector-related configuration options.""" + + def test_default_vector_config_values(self): + """Test that vector config has sensible defaults.""" + config = EGValkeyOnlineStoreConfig() + + assert config.vector_index_algorithm == "HNSW" + assert config.vector_index_hnsw_m == 16 + assert config.vector_index_hnsw_ef_construction == 200 + assert config.vector_index_hnsw_ef_runtime == 10 + + def test_vector_config_custom_values(self): + """Test that vector config can be customized.""" + config = EGValkeyOnlineStoreConfig( + vector_index_algorithm="FLAT", + vector_index_hnsw_m=32, + vector_index_hnsw_ef_construction=400, + vector_index_hnsw_ef_runtime=20, + ) + + assert config.vector_index_algorithm == "FLAT" + assert config.vector_index_hnsw_m == 32 + assert config.vector_index_hnsw_ef_construction == 400 + assert config.vector_index_hnsw_ef_runtime == 20 + + +class TestGenerateHsetKeysForFeatures: + """Tests for _generate_hset_keys_for_features helper method.""" + + @pytest.fixture + def feature_view_with_vector(self): + """Create a FeatureView with mixed vector and non-vector fields.""" + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="entity_id")], + ttl=timedelta(days=1), + schema=[ + Field(name="entity_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + Field(name="string_feature", dtype=String), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ), + ], + ) + + @pytest.fixture + def feature_view_no_vector(self): + """Create a FeatureView with only non-vector fields.""" + return FeatureView( + name="test_fv_no_vec", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="entity_id")], + ttl=timedelta(days=1), + schema=[ + Field(name="entity_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + Field(name="string_feature", dtype=String), + ], + ) + + def test_vector_field_uses_original_name(self, feature_view_with_vector): + """Test that vector fields use original name as hset key.""" + store = EGValkeyOnlineStore() + + requested_features, hset_keys, vector_fields = ( + store._generate_hset_keys_for_features( + feature_view_with_vector, requested_features=["embedding"] + ) + ) + + # Vector field should use original name + assert "embedding" in hset_keys + assert "embedding" in vector_fields + + def test_non_vector_field_uses_mmh3_hash(self, feature_view_with_vector): + """Test that non-vector fields use mmh3 hash as hset key.""" + store = EGValkeyOnlineStore() + + requested_features, hset_keys, vector_fields = ( + store._generate_hset_keys_for_features( + feature_view_with_vector, requested_features=["scalar_feature"] + ) + ) + + # Non-vector field should use mmh3 hash + expected_hash = _mmh3(f"{feature_view_with_vector.name}:scalar_feature") + assert expected_hash in hset_keys + assert "scalar_feature" not in vector_fields + + def test_timestamp_key_appended(self, feature_view_with_vector): + """Test that timestamp key is always appended to hset keys.""" + store = EGValkeyOnlineStore() + + requested_features, hset_keys, vector_fields = ( + store._generate_hset_keys_for_features( + feature_view_with_vector, requested_features=["embedding"] + ) + ) + + ts_key = f"_ts:{feature_view_with_vector.name}" + assert ts_key in hset_keys + assert ts_key in requested_features + + def test_mixed_fields_correct_keys(self, feature_view_with_vector): + """Test that mixed vector and non-vector fields get correct keys.""" + store = EGValkeyOnlineStore() + + requested_features, hset_keys, vector_fields = ( + store._generate_hset_keys_for_features( + feature_view_with_vector, + requested_features=["embedding", "scalar_feature", "string_feature"], + ) + ) + + # Vector field uses original name + assert "embedding" in hset_keys + + # Non-vector fields use mmh3 hash + scalar_hash = _mmh3(f"{feature_view_with_vector.name}:scalar_feature") + string_hash = _mmh3(f"{feature_view_with_vector.name}:string_feature") + assert scalar_hash in hset_keys + assert string_hash in hset_keys + + # Only embedding should be in vector_fields (now a dict) + assert set(vector_fields.keys()) == {"embedding"} + + def test_no_requested_features_uses_all(self, feature_view_with_vector): + """Test that None requested_features returns all feature view features.""" + store = EGValkeyOnlineStore() + + requested_features, hset_keys, vector_fields = ( + store._generate_hset_keys_for_features( + feature_view_with_vector, requested_features=None + ) + ) + + # Should include all features from the feature view + # Features are: scalar_feature, string_feature, embedding (excluding entity_id which is join key) + assert len(requested_features) == 4 # 3 features + timestamp key + + def test_feature_view_without_vectors(self, feature_view_no_vector): + """Test feature view with no vector fields returns empty vector_fields dict.""" + store = EGValkeyOnlineStore() + + requested_features, hset_keys, vector_fields = ( + store._generate_hset_keys_for_features( + feature_view_no_vector, + requested_features=["scalar_feature", "string_feature"], + ) + ) + + # No vector fields (empty dict) + assert vector_fields == {} + + # All fields should use mmh3 hash + for key in hset_keys: + if not isinstance(key, str) or not key.startswith("_ts:"): + assert isinstance(key, bytes) # mmh3 returns bytes + + +class TestVectorFieldValidation: + """Tests for vector field validation during index creation.""" + + def test_vector_field_missing_vector_length_raises_error( + self, valkey_online_store, repo_config_without_docker_connection_string + ): + """Test that vector field without vector_length raises ValueError.""" + from unittest.mock import MagicMock + + from valkey.exceptions import ResponseError + + # Create a FeatureView with vector field but no vector_length + fv = FeatureView( + name="test_missing_length", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id")], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + # vector_length intentionally not set (defaults to 0) + ), + ], + ) + + # Get vector fields + vector_fields = {f.name: f for f in fv.features if f.vector_index} + + # Mock client to avoid actual connection + mock_client = MagicMock() + # Simulate index doesn't exist (ResponseError is raised by valkey-py) + mock_client.ft.return_value.info.side_effect = ResponseError("Unknown index") + + with pytest.raises(ValueError, match="vector_length"): + valkey_online_store._create_vector_index_if_not_exists( + mock_client, + repo_config_without_docker_connection_string, + fv, + vector_fields, + ) + + def test_vector_field_with_negative_vector_length_raises_error( + self, valkey_online_store, repo_config_without_docker_connection_string + ): + """Test that vector field with negative vector_length raises ValueError.""" + from unittest.mock import MagicMock + + from valkey.exceptions import ResponseError + + # Create a FeatureView with vector field with negative vector_length + fv = FeatureView( + name="test_negative_length", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id")], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=-1, + ), + ], + ) + + vector_fields = {f.name: f for f in fv.features if f.vector_index} + + mock_client = MagicMock() + mock_client.ft.return_value.info.side_effect = ResponseError("Unknown index") + + with pytest.raises(ValueError, match="vector_length"): + valkey_online_store._create_vector_index_if_not_exists( + mock_client, + repo_config_without_docker_connection_string, + fv, + vector_fields, + ) + + +class TestVectorIndexCreation: + """Tests for vector index creation with correct schema.""" + + def test_index_includes_project_tag_field( + self, valkey_online_store, repo_config_without_docker_connection_string + ): + """Test that index schema includes TagField for __project__ filtering.""" + from unittest.mock import MagicMock + + from valkey.exceptions import ResponseError + + fv = FeatureView( + name="test_with_project_tag", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id")], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ), + ], + ) + + vector_fields = {f.name: f for f in fv.features if f.vector_index} + + mock_client = MagicMock() + # Simulate index doesn't exist + mock_client.ft.return_value.info.side_effect = ResponseError("Unknown index") + + valkey_online_store._create_vector_index_if_not_exists( + mock_client, + repo_config_without_docker_connection_string, + fv, + vector_fields, + ) + + # Verify create_index was called + mock_client.ft.return_value.create_index.assert_called_once() + + # Get the fields argument + call_kwargs = mock_client.ft.return_value.create_index.call_args + fields = call_kwargs.kwargs.get("fields") or call_kwargs.args[0] + + # Verify we have both VectorField and TagField + field_types = [type(f).__name__ for f in fields] + assert "VectorField" in field_types, "Index should include VectorField" + assert "TagField" in field_types, ( + "Index should include TagField for __project__" + ) + + # Verify TagField is for __project__ + tag_fields = [f for f in fields if type(f).__name__ == "TagField"] + assert len(tag_fields) == 1 + assert tag_fields[0].name == "__project__" + + +# ============================================================================ +# Vector Support Integration Tests (Docker Required) +# ============================================================================ + + +def _create_feature_view_with_vector_field(): + """Create a FeatureView with a vector embedding field.""" + fv = FeatureView( + name="item_embeddings", + source=FileSource( + name="item_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id")], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="item_name", dtype=String), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + ], + ) + return fv + + +def _make_vector_rows(): + """Generate rows with vector embeddings.""" + now = datetime.now(tz=timezone.utc) + return [ + ( + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + { + "item_name": ValueProto(string_val="item_1"), + "embedding": ValueProto( + float_list_val=FloatList(val=[0.1, 0.2, 0.3, 0.4]) + ), + }, + now, + None, + ), + ( + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=2)], + ), + { + "item_name": ValueProto(string_val="item_2"), + "embedding": ValueProto( + float_list_val=FloatList(val=[0.5, 0.6, 0.7, 0.8]) + ), + }, + now, + None, + ), + ] + + +@pytest.mark.docker +def test_valkey_online_write_batch_with_vector_field( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test writing a FeatureView with vector field stores data correctly.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # Write data - note: index creation will fail without Search module, + # but the write itself should work for storage verification + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + # If Search module is not available, index creation will fail + # This is expected with basic Valkey container + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + # Verify data was stored + redis_client = _make_redis_client(repo_config) + + entity_key = EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ) + valkey_key_bin = _redis_key( + repo_config.project, + entity_key, + entity_key_serialization_version=repo_config.entity_key_serialization_version, + ) + + stored_data = redis_client.hgetall(valkey_key_bin) + + # Verify vector field is stored with original name (not hashed) + assert b"embedding" in stored_data + + # Verify non-vector field is stored with mmh3 hash + item_name_key = _mmh3(f"{feature_view.name}:item_name") + assert item_name_key in stored_data + + # Verify vector bytes can be deserialized + embedding_bytes = stored_data[b"embedding"] + vector = np.frombuffer(embedding_bytes, dtype=np.float32) + np.testing.assert_array_almost_equal(vector, [0.1, 0.2, 0.3, 0.4], decimal=5) + + # Verify __project__ is stored for vector search filtering + assert b"__project__" in stored_data + # Should be stored as string (valkey-py encodes to bytes, but value should match project) + assert stored_data[b"__project__"] == repo_config.project.encode() + + # Verify __entity_key__ is stored for entity key retrieval + assert b"__entity_key__" in stored_data + + +@pytest.mark.docker +def test_valkey_online_read_with_vector_field( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test reading a FeatureView with vector field deserializes correctly.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # Write data first + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + # Read data back + entity_keys = [ + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=2)], + ), + ] + + results = valkey_online_store.online_read( + config=repo_config, + table=feature_view, + entity_keys=entity_keys, + ) + + # Verify results + assert len(results) == 2 + + # Check first entity + ts1, features1 = results[0] + assert ts1 is not None + assert "embedding" in features1 + assert "item_name" in features1 + + # Verify vector values + embedding1 = features1["embedding"] + assert embedding1.HasField("float_list_val") + np.testing.assert_array_almost_equal( + embedding1.float_list_val.val, [0.1, 0.2, 0.3, 0.4], decimal=5 + ) + + # Check second entity + ts2, features2 = results[1] + embedding2 = features2["embedding"] + np.testing.assert_array_almost_equal( + embedding2.float_list_val.val, [0.5, 0.6, 0.7, 0.8], decimal=5 + ) + + +@pytest.mark.docker +def test_valkey_online_read_with_requested_features_vector_only( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test reading only the vector field using requested_features parameter.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # Write data first + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + entity_keys = [ + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + ] + + # Request only the vector field + results = valkey_online_store.online_read( + config=repo_config, + table=feature_view, + entity_keys=entity_keys, + requested_features=["embedding"], + ) + + assert len(results) == 1 + ts, features = results[0] + + # Should only have the embedding feature + assert "embedding" in features + assert "item_name" not in features + + # Verify vector values + embedding = features["embedding"] + assert embedding.HasField("float_list_val") + np.testing.assert_array_almost_equal( + embedding.float_list_val.val, [0.1, 0.2, 0.3, 0.4], decimal=5 + ) + + +@pytest.mark.docker +def test_valkey_online_read_with_requested_features_non_vector_only( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test reading only non-vector fields using requested_features parameter.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # Write data first + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + entity_keys = [ + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + ] + + # Request only the non-vector field + results = valkey_online_store.online_read( + config=repo_config, + table=feature_view, + entity_keys=entity_keys, + requested_features=["item_name"], + ) + + assert len(results) == 1 + ts, features = results[0] + + # Should only have the item_name feature + assert "item_name" in features + assert "embedding" not in features + + # Verify string value + assert features["item_name"].string_val == "item_1" + + +@pytest.mark.docker +def test_valkey_online_read_with_requested_features_mixed( + repo_config: RepoConfig, + valkey_online_store: EGValkeyOnlineStore, +): + """Test reading a mix of vector and non-vector fields using requested_features.""" + feature_view = _create_feature_view_with_vector_field() + data = _make_vector_rows() + + # Write data first + try: + valkey_online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=data, + progress=None, + ) + except Exception as e: + if "Search" in str(e) or "unknown command" in str(e).lower(): + pytest.skip("Valkey Search module not available in test container") + raise + + entity_keys = [ + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=2)], + ), + ] + + # Request both vector and non-vector fields + results = valkey_online_store.online_read( + config=repo_config, + table=feature_view, + entity_keys=entity_keys, + requested_features=["embedding", "item_name"], + ) + + assert len(results) == 1 + ts, features = results[0] + + # Should have both features + assert "embedding" in features + assert "item_name" in features + + # Verify vector values + embedding = features["embedding"] + np.testing.assert_array_almost_equal( + embedding.float_list_val.val, [0.5, 0.6, 0.7, 0.8], decimal=5 + ) + + # Verify string value + assert features["item_name"].string_val == "item_2" + + +class TestGetVectorFieldForSearch: + """Tests for _get_vector_field_for_search helper method.""" + + @pytest.fixture + def feature_view_with_vector(self): + """Create a FeatureView with vector field for testing.""" + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + ], + ) + + @pytest.fixture + def feature_view_no_vector(self): + """Create a FeatureView without vector fields.""" + return FeatureView( + name="test_fv_no_vector", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + ], + ) + + def test_returns_vector_field_from_requested_features( + self, feature_view_with_vector + ): + """Test that vector field is returned when in requested_features.""" + store = EGValkeyOnlineStore() + result = store._get_vector_field_for_search( + feature_view_with_vector, + requested_features=["embedding", "scalar_feature"], + ) + assert result is not None + assert result.name == "embedding" + + def test_returns_first_vector_field_when_not_in_requested( + self, feature_view_with_vector + ): + """Test that first vector field is returned when not in requested_features.""" + store = EGValkeyOnlineStore() + result = store._get_vector_field_for_search( + feature_view_with_vector, requested_features=["scalar_feature"] + ) + assert result is not None + assert result.name == "embedding" + + def test_returns_none_for_no_vector_fields(self, feature_view_no_vector): + """Test that None is returned when no vector fields exist.""" + store = EGValkeyOnlineStore() + result = store._get_vector_field_for_search( + feature_view_no_vector, requested_features=["scalar_feature"] + ) + assert result is None + + +class TestSerializeEmbeddingForSearch: + """Tests for _serialize_embedding_for_search helper method.""" + + @pytest.fixture + def float32_vector_field(self): + """Create a Float32 vector field.""" + return Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ) + + @pytest.fixture + def float64_vector_field(self): + """Create a Float64 vector field.""" + return Field( + name="embedding", + dtype=Array(Float64), + vector_index=True, + vector_length=4, + ) + + def test_serializes_to_float32_bytes(self, float32_vector_field): + """Test that embedding is serialized to float32 bytes.""" + store = EGValkeyOnlineStore() + embedding = [0.1, 0.2, 0.3, 0.4] + result = store._serialize_embedding_for_search(embedding, float32_vector_field) + + # Verify it's bytes + assert isinstance(result, bytes) + + # Verify length (4 floats * 4 bytes each = 16 bytes) + assert len(result) == 16 + + # Verify values can be deserialized back + arr = np.frombuffer(result, dtype=np.float32) + np.testing.assert_array_almost_equal(arr, embedding, decimal=5) + + def test_serializes_to_float64_bytes(self, float64_vector_field): + """Test that embedding is serialized to float64 bytes for Float64 fields.""" + store = EGValkeyOnlineStore() + embedding = [0.1, 0.2, 0.3, 0.4] + result = store._serialize_embedding_for_search(embedding, float64_vector_field) + + # Verify it's bytes + assert isinstance(result, bytes) + + # Verify length (4 doubles * 8 bytes each = 32 bytes) + assert len(result) == 32 + + # Verify values can be deserialized back + arr = np.frombuffer(result, dtype=np.float64) + np.testing.assert_array_almost_equal(arr, embedding, decimal=10) + + def test_raises_error_on_dimension_mismatch(self, float32_vector_field): + """Test that ValueError is raised when embedding dimension doesn't match field.""" + store = EGValkeyOnlineStore() + # Field expects 4 dimensions, but we provide 3 + embedding = [0.1, 0.2, 0.3] + with pytest.raises(ValueError, match="dimension .* does not match"): + store._serialize_embedding_for_search(embedding, float32_vector_field) + + +class TestRetrieveOnlineDocumentsV2Validation: + """Tests for retrieve_online_documents_v2 input validation.""" + + @pytest.fixture + def feature_view_with_vector(self): + """Create a FeatureView with vector field for testing.""" + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + ], + ) + + @pytest.fixture + def feature_view_no_vector(self): + """Create a FeatureView without vector fields.""" + return FeatureView( + name="test_fv_no_vector", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + ], + ) + + @pytest.fixture + def repo_config(self): + """Create a minimal RepoConfig for testing.""" + return RepoConfig( + project="test_project", + provider="local", + registry="test_registry.db", + online_store=EGValkeyOnlineStoreConfig( + type="eg-valkey", + connection_string="localhost:6379", + ), + entity_key_serialization_version=3, + ) + + def test_raises_error_when_embedding_is_none( + self, repo_config, feature_view_with_vector + ): + """Test that ValueError is raised when embedding is None.""" + store = EGValkeyOnlineStore() + with pytest.raises(ValueError, match="embedding must be provided"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=None, + top_k=10, + ) + + def test_raises_error_when_query_string_provided( + self, repo_config, feature_view_with_vector + ): + """Test that NotImplementedError is raised when query_string is provided.""" + store = EGValkeyOnlineStore() + with pytest.raises(NotImplementedError, match="Keyword search"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=[0.1, 0.2, 0.3, 0.4], + top_k=10, + query_string="test query", + ) + + def test_raises_error_when_no_vector_field( + self, repo_config, feature_view_no_vector + ): + """Test that ValueError is raised when FeatureView has no vector fields.""" + store = EGValkeyOnlineStore() + with pytest.raises(ValueError, match="No vector field found"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_no_vector, + requested_features=["scalar_feature"], + embedding=[0.1, 0.2, 0.3, 0.4], + top_k=10, + ) + + def test_raises_error_when_dimension_mismatch( + self, repo_config, feature_view_with_vector + ): + """Test that ValueError is raised when embedding dimension doesn't match field.""" + store = EGValkeyOnlineStore() + # feature_view_with_vector has vector_length=4, so 3-dim embedding should fail + with pytest.raises(ValueError, match="Embedding dimension .* does not match"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=[0.1, 0.2, 0.3], # Wrong dimension (3 instead of 4) + top_k=10, + ) + + def test_raises_error_when_index_does_not_exist( + self, repo_config, feature_view_with_vector + ): + """Test that ValueError is raised when vector index doesn't exist.""" + from unittest.mock import MagicMock, patch + + from valkey.exceptions import ResponseError + + store = EGValkeyOnlineStore() + + # Mock the client to simulate "no such index" error + mock_client = MagicMock() + mock_client.ft.return_value.search.side_effect = ResponseError("no such index") + + with patch.object(store, "_get_client", return_value=mock_client): + with pytest.raises(ValueError, match="does not exist.*materialize"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=[0.1, 0.2, 0.3, 0.4], + top_k=10, + ) + + +class TestExecuteVectorSearch: + """Tests for _execute_vector_search helper method.""" + + @pytest.fixture + def store(self): + return EGValkeyOnlineStore() + + def test_project_name_with_hyphen_is_escaped(self, store): + """Test that project names with hyphens are backslash-escaped in queries.""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="my-project", # Hyphen in project name + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + mock_client.ft.return_value.search.assert_called_once() + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # Hyphen should be backslash-escaped to prevent interpretation as negation + assert r"my\-project" in query.query_string() + + def test_project_name_with_double_quote_is_escaped(self, store): + """Test that double quotes in project names are backslash-escaped.""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project='my"project', # Double quote in project name + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + mock_client.ft.return_value.search.assert_called_once() + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # Double quote should be backslash-escaped + assert r"\"" in query.query_string() + + def test_no_sortby_in_knn_query(self, store): + """Test that KNN queries do not use SORTBY (engine sorts by distance automatically).""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="test_project", + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # KNN results are sorted by the engine; no explicit SORTBY should be set + assert query._sortby is None + + def test_default_distance_is_infinity_not_zero(self, store): + """Test that missing __distance__ defaults to infinity, not 0.0.""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_doc = MagicMock() + mock_doc.id = "test_key" + # Simulate missing __distance__ attribute + del mock_doc.__distance__ + + mock_result = MagicMock() + mock_result.docs = [mock_doc] + mock_client.ft.return_value.search.return_value = mock_result + + results = store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="test_project", + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + # Distance should default to infinity, not 0.0 + # 0.0 would incorrectly indicate a perfect match + assert len(results) == 1 + doc_key, distance = results[0] + assert distance == float("inf") From 23bfd01e81ab922126d71101d12434714fd79ce3 Mon Sep 17 00:00:00 2001 From: vanitabhagwat <92561664+vanitabhagwat@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:01:04 -0700 Subject: [PATCH 5/7] feat: Quantization support for elastic search (#355) --- .../elasticsearch.py | 407 +++++++++++++++--- .../test_elasticsearch_online_store.py | 341 +++++++++++++++ 2 files changed, 677 insertions(+), 71 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py index 58ff9b5f3b0..79751cea5b7 100644 --- a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py @@ -3,11 +3,13 @@ import base64 import json import logging +import math from collections import defaultdict from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from elasticsearch import Elasticsearch, helpers +from pydantic import model_validator from feast import Entity, FeatureView, RepoConfig from feast.infra.key_encoding_utils import ( @@ -48,6 +50,110 @@ class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): # The number of rows to write in a single batch write_batch_size: Optional[int] = 40 + # Quantization / index_options configuration + vector_index_type: Optional[str] = None + # One of: "hnsw", "int8_hnsw", "int4_hnsw", "bbq_hnsw", + # "flat", "int8_flat", "int4_flat", "bbq_flat" + # None = use ES default (hnsw for <8.x, int8_hnsw for 9.0+) + + # HNSW tuning parameters (only apply to HNSW index types) + hnsw_m: Optional[int] = None # Neighbor connections (ES default: 16) + hnsw_ef_construction: Optional[int] = ( + None # Build-time candidates (ES default: 100) + ) + + # Rescore configuration for quantized indices only (int4/int8/bbq) + rescore_oversample: Optional[float] = ( + None # Must be (1.0, 10.0) exclusive; None to disable + ) + + # Query method toggle + use_native_knn: bool = False # False = script_score (backward compatible) + # True = native knn query (faster, approximate) + + # KNN query tuning + knn_num_candidates_multiplier: Optional[float] = ( + None # Default: 2.0; num_candidates = top_k * multiplier (must be >= 1.0) + ) + + @model_validator(mode="after") + def validate_quantization_config(self): + """Validate quantization configuration constraints.""" + # Validate vector_index_type is a known value + valid_index_types = { + "hnsw", + "int8_hnsw", + "int4_hnsw", + "bbq_hnsw", + "flat", + "int8_flat", + "int4_flat", + "bbq_flat", + } + if ( + self.vector_index_type is not None + and self.vector_index_type not in valid_index_types + ): + raise ValueError( + f"vector_index_type must be one of {valid_index_types}, got {self.vector_index_type}" + ) + + # Validate rescore_oversample range and constraints + # ES requires: (1.0, 10.0) exclusive, per https://www.elastic.co/docs/reference/elasticsearch/mapping-reference/dense-vector + if self.rescore_oversample is not None: + if self.rescore_oversample <= 1.0 or self.rescore_oversample >= 10.0: + raise ValueError( + f"rescore_oversample must be in the range (1.0, 10.0) exclusive, " + f"got {self.rescore_oversample}" + ) + + # Validate rescore_oversample only applies to quantized indices + quantized_types = { + "int8_hnsw", + "int4_hnsw", + "bbq_hnsw", + "int8_flat", + "int4_flat", + "bbq_flat", + } + if ( + self.vector_index_type is None + or self.vector_index_type not in quantized_types + ): + raise ValueError( + f"rescore_oversample can only be used with quantized index types {quantized_types}, " + f"got vector_index_type={self.vector_index_type}" + ) + + # Validate HNSW parameters only apply to HNSW index types + hnsw_types = {"hnsw", "int8_hnsw", "int4_hnsw", "bbq_hnsw"} + if (self.hnsw_m is not None or self.hnsw_ef_construction is not None) and ( + self.vector_index_type is not None + and self.vector_index_type not in hnsw_types + ): + raise ValueError( + f"hnsw_m and hnsw_ef_construction only apply to HNSW index types {hnsw_types}, " + f"got vector_index_type='{self.vector_index_type}'" + ) + + # Validate HNSW parameter ranges (basic sanity only; ES enforces its own limits) + if self.hnsw_m is not None and self.hnsw_m < 1: + raise ValueError(f"hnsw_m must be >= 1, got {self.hnsw_m}") + + if self.hnsw_ef_construction is not None and self.hnsw_ef_construction < 1: + raise ValueError( + f"hnsw_ef_construction must be >= 1, got {self.hnsw_ef_construction}" + ) + + # Validate knn_num_candidates_multiplier range (must be >= 1.0) + if self.knn_num_candidates_multiplier is not None: + if self.knn_num_candidates_multiplier < 1.0: + raise ValueError( + f"knn_num_candidates_multiplier must be >= 1.0, got {self.knn_num_candidates_multiplier}" + ) + + return self + class ElasticSearchOnlineStore(OnlineStore): _client: Optional[Elasticsearch] = None @@ -231,6 +337,14 @@ def create_index(self, config: RepoConfig, table: FeatureView): or 512 ) + # Validate vector_field_length is positive + if vector_field_length <= 0: + raise ValueError( + f"vector_field_length must be > 0, got {vector_field_length} for table '{table.name}'" + ) + + vector_mapping = _build_vector_mapping(config, vector_field_length, table.name) + index_mapping = { "dynamic_templates": [ { @@ -242,12 +356,7 @@ def create_index(self, config: RepoConfig, table: FeatureView): "properties": { "feature_value": {"type": "binary"}, "value_text": {"type": "text"}, - "vector_value": { - "type": "dense_vector", - "dims": vector_field_length, - "index": True, - "similarity": config.online_store.similarity, - }, + "vector_value": vector_mapping, }, }, } @@ -260,10 +369,11 @@ def create_index(self, config: RepoConfig, table: FeatureView): }, } - self._get_client(config).indices.create( - index=table.name, - mappings=index_mapping, - ) + client = self._get_client(config) + if not client.indices.exists(index=table.name): + client.indices.create(index=table.name, mappings=index_mapping) + else: + logger.info(f"Index '{table.name}' already exists; skipping creation. ") def update( self, @@ -274,12 +384,27 @@ def update( entities_to_keep: Sequence[Entity], partial: bool, ): - # implement the update method + client = self._get_client(config) + + # Cache existing indices to reduce API calls + all_table_names = [t.name for t in tables_to_delete] + [ + t.name for t in tables_to_keep + ] + existing_indices: Set[str] = set() + for table_name in all_table_names: + if client.indices.exists(index=table_name): + existing_indices.add(table_name) + + # Delete data from indices that should be removed for table in tables_to_delete: - if self._get_client(config).indices.exists(index=table.name): - self._get_client(config).delete_by_query(index=table.name) + if table.name in existing_indices: + client.delete_by_query( + index=table.name, body={"query": {"match_all": {}}} + ) + + # Create indices for tables that should be kept for table in tables_to_keep: - if not self._get_client(config).indices.exists(index=table.name): + if table.name not in existing_indices: self.create_index(config, table) def teardown( @@ -289,10 +414,17 @@ def teardown( entities: Sequence[Entity], ): project = config.project + client = self._get_client(config) try: + # Cache existing indices to reduce API calls + existing_indices: Set[str] = set() for table in tables: - if self._get_client(config).indices.exists(index=table.name): - self._get_client(config).indices.delete(index=table.name) + if client.indices.exists(index=table.name): + existing_indices.add(table.name) + + # Delete all existing indices + for table_name in existing_indices: + client.indices.delete(index=table_name) except Exception as e: logging.exception(f"Error deleting index in project {project}: {e}") raise @@ -330,18 +462,42 @@ def retrieve_online_documents( if vector_field else config.online_store.vector_field_path or "embedding.vector_value" ) - query = { - "script_score": { - "query": { - "bool": {"filter": [{"exists": {"field": vector_field_path}}]} - }, - "script": { - "source": f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0", - "params": {"query_vector": embedding}, - }, + + # Build query based on use_native_knn config + body: Dict[str, Any] = {"size": top_k, "_source": True} + + if config.online_store.use_native_knn: + # Native knn query (fast, approximate) + # Uses the similarity metric configured in the index mapping + multiplier = config.online_store.knn_num_candidates_multiplier or 2.0 + num_candidates: int = max(top_k, math.ceil(top_k * multiplier)) + + knn_query: Dict[str, Any] = { + "field": vector_field_path, + "query_vector": embedding, + "k": top_k, + "num_candidates": num_candidates, + } + + if config.online_store.rescore_oversample is not None: + knn_query["rescore_vector"] = { + "oversample": config.online_store.rescore_oversample + } + + body["knn"] = knn_query + else: + # Legacy script_score query (slow, exact, backward compatible) + body["query"] = { + "script_score": { + "query": { + "bool": {"filter": [{"exists": {"field": vector_field_path}}]} + }, + "script": { + "source": f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0", + "params": {"query_vector": embedding}, + }, + } } - } - body = {"size": top_k, "_source": True, "query": query} response = self._get_client(config).search(index=table.name, body=body) rows = response["hits"]["hits"][0:top_k] for row in rows: @@ -427,45 +583,100 @@ def retrieve_online_documents_v2( ) or config.online_store.similarity ).lower() - if similarity == "cosine": - script = f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0" - elif similarity == "dot_product": - script = f"dotProduct(params.query_vector, '{vector_field_path}')" - elif similarity in ("l2", "l2_norm", "euclidean"): - script = f"1 / (1 + l2norm(params.query_vector, '{vector_field_path}'))" + + # Determine query method: native knn or script_score + use_native_knn = config.online_store.use_native_knn + + if use_native_knn: + # Native knn query (fast, approximate) + # Uses the similarity metric configured in the index mapping + # Validate that the requested similarity is supported + if similarity not in ( + "cosine", + "dot_product", + "l2", + "l2_norm", + "euclidean", + ): + raise ValueError( + f"Unsupported similarity for native knn: {similarity}" + ) + + # Calculate num_candidates for approximate nearest neighbor search + multiplier = config.online_store.knn_num_candidates_multiplier or 2.0 + num_candidates: int = max(top_k, math.ceil(top_k * multiplier)) + + knn_clause: Dict[str, Any] = { + "field": vector_field_path, + "query_vector": embedding, + "k": top_k, + "num_candidates": num_candidates, + } + + if config.online_store.rescore_oversample is not None: + knn_clause["rescore_vector"] = { + "oversample": config.online_store.rescore_oversample + } else: - raise ValueError( - f"Unsupported similarity/distance_metric: {similarity}" - ) + # Legacy script_score query (slow, exact, backward compatible) + if similarity == "cosine": + script = f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0" + elif similarity == "dot_product": + script = f"dotProduct(params.query_vector, '{vector_field_path}')" + elif similarity in ("l2", "l2_norm", "euclidean"): + script = ( + f"1 / (1 + l2norm(params.query_vector, '{vector_field_path}'))" + ) + else: + raise ValueError( + f"Unsupported similarity/distance_metric: {similarity}" + ) - # Hybrid search + # Build query based on search type and query method + # Hybrid search (embedding + keyword) if embedding and query_string: - body["query"] = { - "script_score": { - "query": { - "bool": { - "must": [ - {"query_string": {"query": f'"{query_string}"'}}, - {"exists": {"field": vector_field_path}}, - ] - } - }, - "script": { - "source": script, - "params": {"query_vector": embedding}, - }, + if use_native_knn: + # Native knn with query filter + body["knn"] = knn_clause + body["query"] = {"query_string": {"query": f'"{query_string}"'}} + else: + # Legacy script_score with keyword filter + body["query"] = { + "script_score": { + "query": { + "bool": { + "must": [ + {"query_string": {"query": f'"{query_string}"'}}, + {"exists": {"field": vector_field_path}}, + ] + } + }, + "script": { + "source": script, + "params": {"query_vector": embedding}, + }, + } } - } # Vector search only elif embedding: - body["query"] = { - "script_score": { - "query": { - "bool": {"filter": [{"exists": {"field": vector_field_path}}]} - }, - "script": {"source": script, "params": {"query_vector": embedding}}, + if use_native_knn: + # Native knn query + body["knn"] = knn_clause + else: + # Legacy script_score + body["query"] = { + "script_score": { + "query": { + "bool": { + "filter": [{"exists": {"field": vector_field_path}}] + } + }, + "script": { + "source": script, + "params": {"query_vector": embedding}, + }, + } } - } # Keyword search only elif query_string: body["query"] = {"query_string": {"query": f'"{query_string}"'}} @@ -500,6 +711,47 @@ def retrieve_online_documents_v2( return result +def _build_vector_mapping( + config: RepoConfig, vector_field_length: int, table_name: str +) -> Dict[str, Any]: + """ + Build the dense_vector mapping for an Elasticsearch index, including + quantization index_options when configured. + """ + # Validate dimension-based quantization constraints + if config.online_store.vector_index_type: + index_type = config.online_store.vector_index_type + if "int4" in index_type and vector_field_length % 2 != 0: + raise ValueError( + f"int4 quantization ('{index_type}') requires even number of dimensions, " + f"got {vector_field_length} for table '{table_name}'. " + f"See https://www.elastic.co/docs/reference/elasticsearch/mapping-reference/dense-vector" + ) + if "bbq" in index_type and vector_field_length < 64: + raise ValueError( + f"bbq quantization ('{index_type}') requires >= 64 dimensions, " + f"got {vector_field_length} for table '{table_name}'. " + f"See https://www.elastic.co/docs/reference/elasticsearch/mapping-reference/dense-vector" + ) + + vector_mapping: Dict[str, Any] = { + "type": "dense_vector", + "dims": vector_field_length, + "index": True, + "similarity": config.online_store.similarity, + } + + if config.online_store.vector_index_type: + index_options: Dict[str, Any] = {"type": config.online_store.vector_index_type} + if config.online_store.hnsw_m is not None: + index_options["m"] = config.online_store.hnsw_m + if config.online_store.hnsw_ef_construction is not None: + index_options["ef_construction"] = config.online_store.hnsw_ef_construction + vector_mapping["index_options"] = index_options + + return vector_mapping + + def _to_value_proto(value: Any) -> ValueProto: """ Convert a value to a ValueProto object. @@ -507,24 +759,37 @@ def _to_value_proto(value: Any) -> ValueProto: val_proto = ValueProto() if isinstance(value, ValueProto): return value + # Check bool before int/float since bool is a subclass of int in Python if isinstance(value, bool): val_proto.bool_val = value elif isinstance(value, float): val_proto.float_val = value - elif isinstance(value, str): - val_proto.string_val = value elif isinstance(value, int): val_proto.int64_val = value - elif isinstance(value, list) and all(isinstance(v, float) for v in value): - val_proto.float_list_val.val.extend(value) - elif isinstance(value, dict) and "feature_value" in value: - try: - raw_bytes = base64.b64decode(value["feature_value"]) - val_proto.ParseFromString(raw_bytes) - except Exception as e: - raise ValueError(f"Failed to decode feature_value from dict: {e}") + elif isinstance(value, str): + val_proto.string_val = value + elif isinstance(value, list): + if not value: + val_proto.float_list_val.val.extend(value) + elif all(isinstance(v, float) for v in value): + val_proto.float_list_val.val.extend(value) + elif all(isinstance(v, int) for v in value): + val_proto.int64_list_val.val.extend(value) + else: + raise ValueError(f"List contains mixed or unsupported types: {value}") + elif isinstance(value, dict): + if "feature_value" in value: + try: + raw_bytes = base64.b64decode(value["feature_value"]) + val_proto.ParseFromString(raw_bytes) + except Exception as e: + raise ValueError(f"Failed to decode feature_value from dict: {e}") + else: + raise ValueError(f"Dict missing 'feature_value' key: {value}") else: - raise ValueError(f"Unsupported type for ValueProto: {type(value)}") + raise ValueError( + f"Unsupported type for ValueProto: {type(value).__name__} (value: {value})" + ) return val_proto diff --git a/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py index cb94205d1f3..12f77eab9b0 100644 --- a/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py +++ b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py @@ -4,6 +4,7 @@ from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( _encode_feature_value, + _to_value_proto, ) from feast.protos.feast.types.Value_pb2 import ( FloatList, @@ -66,3 +67,343 @@ def test_default_is_vector_false(self): result = _encode_feature_value(value) assert "vector_value" not in result + + +class TestElasticSearchOnlineStoreConfig: + def test_defaults(self): + """Test default config values.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + config = ElasticSearchOnlineStoreConfig() + assert config.vector_index_type is None + assert config.hnsw_m is None + assert config.hnsw_ef_construction is None + assert config.rescore_oversample is None + assert config.use_native_knn is False + assert config.knn_num_candidates_multiplier is None + + def test_valid_index_type(self): + """Test valid vector_index_type values.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + for index_type in [ + "int8_hnsw", + "int4_hnsw", + "bbq_hnsw", + "hnsw", + "flat", + "bbq_flat", + ]: + config = ElasticSearchOnlineStoreConfig(vector_index_type=index_type) + assert config.vector_index_type == index_type + + def test_invalid_index_type(self): + """Test invalid vector_index_type raises ValueError.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + with pytest.raises(ValueError, match="vector_index_type must be one of"): + ElasticSearchOnlineStoreConfig(vector_index_type="invalid_type") + + def test_rescore_range_validation(self): + """Test rescore_oversample range validation.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid values: (1.0, 10.0) exclusive + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=2.0 + ) + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=5.5 + ) + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=9.9 + ) + # None disables rescore + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=None + ) + + # Invalid: at or below 1.0 + with pytest.raises( + ValueError, match="must be in the range \\(1.0, 10.0\\) exclusive" + ): + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=1.0 + ) + with pytest.raises( + ValueError, match="must be in the range \\(1.0, 10.0\\) exclusive" + ): + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=0.5 + ) + + # Invalid: at or above 10.0 + with pytest.raises( + ValueError, match="must be in the range \\(1.0, 10.0\\) exclusive" + ): + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=10.0 + ) + with pytest.raises( + ValueError, match="must be in the range \\(1.0, 10.0\\) exclusive" + ): + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=50.0 + ) + + def test_rescore_requires_quantized_type(self): + """Test rescore_oversample only works with quantized types.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid: quantized type + ElasticSearchOnlineStoreConfig( + vector_index_type="int8_hnsw", rescore_oversample=2.0 + ) + + # Invalid: non-quantized type + with pytest.raises(ValueError, match="can only be used with quantized"): + ElasticSearchOnlineStoreConfig( + vector_index_type="hnsw", rescore_oversample=2.0 + ) + + # Invalid: vector_index_type is None + with pytest.raises(ValueError, match="can only be used with quantized"): + ElasticSearchOnlineStoreConfig( + vector_index_type=None, rescore_oversample=2.0 + ) + + def test_hnsw_params_require_hnsw_type(self): + """Test HNSW params only work with HNSW types.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid: HNSW type + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=32) + + # Invalid: flat type + with pytest.raises(ValueError, match="only apply to HNSW index types"): + ElasticSearchOnlineStoreConfig(vector_index_type="int8_flat", hnsw_m=32) + + def test_hnsw_m_range(self): + """Test hnsw_m range validation.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid: ES enforces its own upper limits, Feast only rejects < 1 + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=1) + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=100) + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=200) + + # Invalid: zero or negative + with pytest.raises(ValueError, match="must be >= 1"): + ElasticSearchOnlineStoreConfig(vector_index_type="int8_hnsw", hnsw_m=0) + + def test_knn_multiplier_validation(self): + """Test knn_num_candidates_multiplier validation.""" + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStoreConfig, + ) + + # Valid + ElasticSearchOnlineStoreConfig(knn_num_candidates_multiplier=1.0) + ElasticSearchOnlineStoreConfig(knn_num_candidates_multiplier=10.0) + + # Invalid: too low + with pytest.raises(ValueError, match="must be >= 1.0"): + ElasticSearchOnlineStoreConfig(knn_num_candidates_multiplier=0.5) + + +class TestCreateIndexWithQuantization: + def test_index_mapping_with_int8_quantization(self): + """Test index mapping includes quantization settings.""" + from unittest.mock import MagicMock + + from feast import FeatureView, Field, RepoConfig + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStore, + ElasticSearchOnlineStoreConfig, + ) + from feast.types import Array, Float32 + + config = RepoConfig( + project="test", + registry="registry.db", + provider="local", + online_store=ElasticSearchOnlineStoreConfig( + vector_enabled=True, + similarity="cosine", + vector_index_type="int8_hnsw", + hnsw_m=32, + hnsw_ef_construction=200, + ), + ) + + fv = MagicMock(spec=FeatureView) + fv.name = "test_fv" + fv.schema = [ + Field( + name="vector", + dtype=Array(Float32), + vector_index=True, + vector_length=128, + vector_search_metric="cosine", + ) + ] + + store = ElasticSearchOnlineStore() + mock_client = MagicMock() + mock_client.indices.exists.return_value = False + store._client = mock_client + + store.create_index(config, fv) + + # Verify create was called + assert mock_client.indices.create.called + call_args = mock_client.indices.create.call_args + mapping = call_args.kwargs["mappings"] + + # Check quantization settings in dynamic template + template = mapping["dynamic_templates"][0]["feature_objects"]["mapping"] + vector_props = template["properties"]["vector_value"] + + assert vector_props["type"] == "dense_vector" + assert vector_props["dims"] == 128 + assert "index_options" in vector_props + assert vector_props["index_options"]["type"] == "int8_hnsw" + assert vector_props["index_options"]["m"] == 32 + assert vector_props["index_options"]["ef_construction"] == 200 + + def test_int4_requires_even_dimensions(self): + """Test int4 quantization validates even dimensions.""" + from unittest.mock import MagicMock + + from feast import FeatureView, Field, RepoConfig + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStore, + ElasticSearchOnlineStoreConfig, + ) + from feast.types import Array, Float32 + + config = RepoConfig( + project="test", + registry="registry.db", + provider="local", + online_store=ElasticSearchOnlineStoreConfig( + vector_enabled=True, vector_index_type="int4_hnsw" + ), + ) + + fv = MagicMock(spec=FeatureView) + fv.name = "test_fv" + fv.schema = [ + Field( + name="vector", + dtype=Array(Float32), + vector_index=True, + vector_length=127, # Odd number + ) + ] + + store = ElasticSearchOnlineStore() + mock_client = MagicMock() + mock_client.indices.exists.return_value = False + store._client = mock_client + + with pytest.raises(ValueError, match="requires even number of dimensions"): + store.create_index(config, fv) + + def test_bbq_requires_min_dimensions(self): + """Test bbq quantization validates minimum dimensions.""" + from unittest.mock import MagicMock + + from feast import FeatureView, Field, RepoConfig + from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStore, + ElasticSearchOnlineStoreConfig, + ) + from feast.types import Array, Float32 + + config = RepoConfig( + project="test", + registry="registry.db", + provider="local", + online_store=ElasticSearchOnlineStoreConfig( + vector_enabled=True, vector_index_type="bbq_hnsw" + ), + ) + + fv = MagicMock(spec=FeatureView) + fv.name = "test_fv" + fv.schema = [ + Field( + name="vector", + dtype=Array(Float32), + vector_index=True, + vector_length=32, # Less than 64 + ) + ] + + store = ElasticSearchOnlineStore() + mock_client = MagicMock() + mock_client.indices.exists.return_value = False + store._client = mock_client + + with pytest.raises(ValueError, match="requires >= 64 dimensions"): + store.create_index(config, fv) + + +class TestToValueProto: + def test_bool_not_treated_as_int(self): + """bool is a subclass of int in Python; ensure True -> bool_val, not int64_val.""" + result = _to_value_proto(True) + assert result.bool_val is True + assert result.int64_val == 0 + + result = _to_value_proto(False) + assert result.bool_val is False + + def test_int(self): + result = _to_value_proto(42) + assert result.int64_val == 42 + assert result.bool_val is False + + def test_float(self): + result = _to_value_proto(3.14) + assert result.float_val == pytest.approx(3.14) + + def test_string(self): + result = _to_value_proto("hello") + assert result.string_val == "hello" + + def test_float_list(self): + result = _to_value_proto([1.0, 2.0, 3.0]) + assert list(result.float_list_val.val) == pytest.approx([1.0, 2.0, 3.0]) + + def test_int_list(self): + result = _to_value_proto([1, 2, 3]) + assert list(result.int64_list_val.val) == [1, 2, 3] + + def test_mixed_list_raises(self): + with pytest.raises(ValueError, match="mixed or unsupported"): + _to_value_proto([1, "two", 3.0]) + + def test_passthrough_value_proto(self): + original = ValueProto(string_val="already a proto") + result = _to_value_proto(original) + assert result is original + + def test_unsupported_type_raises(self): + with pytest.raises(ValueError, match="Unsupported type"): + _to_value_proto(object()) From b338233f46822ed956acca7e86607c1912918da7 Mon Sep 17 00:00:00 2001 From: Manisha Sudhir <30449541+Manisha4@users.noreply.github.com> Date: Wed, 6 May 2026 10:11:52 -0700 Subject: [PATCH 6/7] feat: Add retrieve_online_documents_v3 SDK with multi-vector and hybrid fusion (#358) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: implement retrieve_online_documents_v3 SDK method - Multi-vector search with configurable fusion (RRF, WEIGHTED_LINEAR, VECTOR_ONLY) via the ES retriever API. Valkey gracefully degrades to single-vector KNN with warnings. - "embedding" magic key for V2→V3 migration convenience - Reserved output fields: final_score, signal_scores - include_signal_scores and distance_metric accepted as reserved params - ODFV and reserved-name collision validation - Shared signal_scores encoding via _signal_scores helper * update tests * update tests * fixing linting * docs: clarify final_score semantics in Valkey V3 docstring Correct the Valkey final_score description — Valkey's __distance__ is lower-is-better across all metrics (COSINE, L2, IP), not higher-is-better for IP. Call out the ordering inversion vs Elasticsearch so callers don't assume cross-backend score portability. Co-Authored-By: Claude Opus 4.7 * fix: plumb include_signal_scores through V3 and align defaults to False Valkey/ES/provider/passthrough/online-store defaults were True, mismatching the SDK's False default. Align all layers on False and thread the parameter from retrieve_online_documents_v3 through the internal dispatcher, provider, and online stores so callers can opt in today and transparently pick up the explain-based per-signal path when it lands — no API change required. Tighten docstrings to describe the current best-effort behavior instead of hinting at latency tradeoffs that aren't wired yet. Co-Authored-By: Claude Opus 4.7 * updating doc string * fix: preserve ranked row order in V3 retrieve_online_documents _retrieve_from_online_store_v3 was passing the driver's ranked rows through _get_unique_entities_from_values, which sorts and dedupes by entity-key bytes. That helper is correct for batch entity lookups but wrong here — ES/Valkey have already ordered rows by relevance, and the sort was scrambling them in the final DataFrame (e.g. doc_10 jumping ahead of doc_3 because "10" < "3" lexicographically). Replace the helper call with an identity mapping so the driver's rank order flows through untouched. No change to V1, V2, batch reads, or the helper itself; V3 output now matches the order returned by the online store. Co-Authored-By: Claude Opus 4.7 --------- Co-authored-by: Manisha4 Co-authored-by: Claude Opus 4.7 --- sdk/python/feast/feature_store.py | 214 ++++++ sdk/python/feast/feature_view.py | 4 - .../infra/online_stores/_signal_scores.py | 18 + .../feast/infra/online_stores/eg_valkey.py | 116 +++ .../elasticsearch.py | 218 ++++++ .../feast/infra/online_stores/online_store.py | 25 + .../feast/infra/online_stores/remote.py | 25 + .../feast/infra/passthrough_provider.py | 31 + sdk/python/feast/infra/provider.py | 23 + sdk/python/feast/utils.py | 4 - sdk/python/tests/foo_provider.py | 22 + .../unit/infra/online_store/test_valkey.py | 421 +++++++++++ .../test_elasticsearch_online_store.py | 700 ++++++++++++++++++ .../unit/online_store/test_signal_scores.py | 76 ++ 14 files changed, 1889 insertions(+), 8 deletions(-) create mode 100644 sdk/python/feast/infra/online_stores/_signal_scores.py create mode 100644 sdk/python/tests/unit/online_store/test_signal_scores.py diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index b4ec250ec00..b1b5660f35e 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -2547,6 +2547,137 @@ def retrieve_online_documents_v2( query_string, ) + def retrieve_online_documents_v3( + self, + features: List[str], + top_k: int, + embeddings: Optional[Dict[str, List[float]]] = None, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> OnlineResponse: + """ + Retrieve documents using multi-vector search with configurable fusion. + + Args: + features: Feature references (e.g., ["doc_fv:title", "doc_fv:body"]). + top_k: Number of results to return. + embeddings: Map of vector field name to query vector. Required. + Single entry is equivalent to V2's embedding param. + Special case: the key "embedding" auto-resolves to the + FeatureView's single vector field, easing V2 to V3 migration. + FeatureViews with multiple vector fields require explicit names. + query_string: Text query. Ranking signal on Elasticsearch; logged + dropped + on Valkey (Valkey does not support text-as-ranking). + fusion_strategy: AUTO | RRF | WEIGHTED_LINEAR | VECTOR_ONLY. + signal_weights: Per-signal weights for WEIGHTED_LINEAR. + Keys are embedding field names and/or "bm25". + rrf_k: RRF rank constant (default 60). + distance_metric: Override distance metric. + include_signal_scores: When True, requests per-signal score + breakdowns for fusion strategies (RRF, WEIGHTED_LINEAR) at + additional latency cost. Currently a no-op — signal_scores + always follows the best-effort behavior documented in the V3 + design, and the parameter is plumbed through so callers can + opt in today and automatically pick up the explain-based + path when it lands. Default False. + """ + if not embeddings: + raise ValueError( + "V3 requires at least one embedding. " + "For text-only search, use retrieve_online_documents_v2." + ) + + effective_strategy = fusion_strategy.upper() + valid_strategies = {"AUTO", "RRF", "WEIGHTED_LINEAR", "VECTOR_ONLY"} + if effective_strategy not in valid_strategies: + raise ValueError( + f"Unknown fusion_strategy '{fusion_strategy}'. " + f"Must be one of: {', '.join(sorted(valid_strategies))}" + ) + + if effective_strategy == "VECTOR_ONLY": + query_string = None + + ( + available_feature_views, + available_odfv_views, + ) = utils._get_feature_views_to_use( + registry=self._registry, + project=self.project, + features=features, + allow_cache=True, + hide_dummy_entity=False, + ) + + feature_view_set = {f.split(":")[0] for f in features} + if len(feature_view_set) > 1: + raise ValueError("Document retrieval only supports a single feature view.") + + requested_features = [ + f.split(":")[1] for f in features if isinstance(f, str) and ":" in f + ] + + if not available_feature_views and not available_odfv_views: + raise ValueError(f"No feature view found for features {features}.") + + if not available_feature_views: + available_feature_views.extend(available_odfv_views) # type: ignore[arg-type] + + requested_feature_view = available_feature_views[0] + + if isinstance(requested_feature_view, OnDemandFeatureView): + raise ValueError( + "V3 vector search is not supported on OnDemandFeatureViews. " + "Use a regular FeatureView with vector-indexed fields." + ) + + RESERVED_NAMES = {"final_score", "signal_scores"} + collisions = set(requested_features) & RESERVED_NAMES + if collisions: + raise ValueError( + f"Feature names {sorted(collisions)} are reserved by V3 and cannot be " + f"requested directly. Rename your feature view fields or use V2." + ) + + # "embedding" magic key: auto-resolve to the FeatureView's single vector field + if len(embeddings) == 1 and list(embeddings.keys()) == ["embedding"]: + vector_fields = { + f.name: f for f in requested_feature_view.features if f.vector_index + } + if len(vector_fields) == 0: + raise ValueError( + f"FeatureView '{requested_feature_view.name}' has no vector-indexed " + f"fields. Cannot perform vector search." + ) + elif len(vector_fields) == 1: + actual_name = next(iter(vector_fields.keys())) + embeddings = {actual_name: embeddings["embedding"]} + else: + raise ValueError( + f"FeatureView '{requested_feature_view.name}' has multiple vector " + f"fields {sorted(vector_fields.keys())}. " + f"Specify the field name explicitly in embeddings." + ) + + provider = self._get_provider() + return self._retrieve_from_online_store_v3( + provider, + requested_feature_view, + requested_features, + embeddings, + top_k, + query_string, + effective_strategy, + signal_weights, + rrf_k, + distance_metric, + include_signal_scores, + ) + def _retrieve_from_online_store( self, provider: Provider, @@ -2691,6 +2822,89 @@ def _retrieve_from_online_store_v2( return OnlineResponse(online_features_response) + def _retrieve_from_online_store_v3( + self, + provider: Provider, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str], + fusion_strategy: str, + signal_weights: Optional[Dict[str, float]], + rrf_k: int, + distance_metric: Optional[str], + include_signal_scores: bool, + ) -> OnlineResponse: + documents = provider.retrieve_online_documents_v3( + config=self.config, + table=table, + requested_features=requested_features, + embeddings=embeddings, + top_k=top_k, + query_string=query_string, + fusion_strategy=fusion_strategy, + signal_weights=signal_weights, + rrf_k=rrf_k, + distance_metric=distance_metric, + include_signal_scores=include_signal_scores, + ) + + entity_key_dict: Dict[str, List[ValueProto]] = {} + datevals, list_of_feature_dicts = [], [] + for row_ts, entity_key, feature_dict in documents: # type: ignore[misc] + datevals.append(row_ts) + list_of_feature_dicts.append(feature_dict) + if entity_key: + for key, value in zip(entity_key.join_keys, entity_key.entity_values): + python_value = value + if key not in entity_key_dict: + entity_key_dict[key] = [] + entity_key_dict[key].append(python_value) + + features_to_request: List[str] = requested_features + [ + "final_score", + "signal_scores", + ] + + if not datevals: + online_features_response = GetOnlineFeaturesResponse(results=[]) + for _ in features_to_request: + field = online_features_response.results.add() + field.values.extend([]) + field.statuses.extend([]) + field.event_timestamps.extend([]) + online_features_response.metadata.feature_names.val.extend( + features_to_request + ) + return OnlineResponse(online_features_response) + + output_len = len(datevals) + idxs = tuple([i] for i in range(output_len)) + + feature_data = utils._convert_rows_to_protobuf( + requested_features=features_to_request, + read_rows=list(zip(datevals, list_of_feature_dicts)), + ) + + online_features_response = GetOnlineFeaturesResponse(results=[]) + utils._populate_response_from_feature_data( + feature_data=feature_data, + indexes=idxs, + online_features_response=online_features_response, + full_feature_names=False, + requested_features=features_to_request, + table=table, + output_len=output_len, + ) + + utils._populate_result_rows_from_columnar( + online_features_response=online_features_response, + data=entity_key_dict, + ) + + return OnlineResponse(online_features_response) + def _lazy_init_go_server(self): """Lazily initialize self._go_server if it hasn't been initialized before.""" from feast.embedded_go.online_features_service import ( diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 341c09c461a..c556e482743 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -237,10 +237,6 @@ def __init__( else: features.append(field) - assert len([f for f in features if f.vector_index]) < 2, ( - f"Only one vector feature is allowed per feature view. Please update {self.name}." - ) - # TODO(felixwang9817): Add more robust validation of features. if self.batch_source is not None: cols = [field.name for field in schema] diff --git a/sdk/python/feast/infra/online_stores/_signal_scores.py b/sdk/python/feast/infra/online_stores/_signal_scores.py new file mode 100644 index 00000000000..706c4cec87c --- /dev/null +++ b/sdk/python/feast/infra/online_stores/_signal_scores.py @@ -0,0 +1,18 @@ +import json +from typing import Dict + +from feast.protos.feast.types.Value_pb2 import Value as ValueProto + + +def encode_signal_scores(scores: Dict[str, float]) -> ValueProto: + """Encode a signal_scores dict as a JSON string in ValueProto.""" + val = ValueProto() + val.string_val = json.dumps(scores, separators=(",", ":"), sort_keys=True) + return val + + +def decode_signal_scores(value: ValueProto) -> Dict[str, float]: + """Decode a signal_scores ValueProto back to a dict.""" + if not value.HasField("string_val") or not value.string_val: + return {} + return json.loads(value.string_val) diff --git a/sdk/python/feast/infra/online_stores/eg_valkey.py b/sdk/python/feast/infra/online_stores/eg_valkey.py index 3ef14fc89a0..b3cbd66090e 100644 --- a/sdk/python/feast/infra/online_stores/eg_valkey.py +++ b/sdk/python/feast/infra/online_stores/eg_valkey.py @@ -42,6 +42,7 @@ deserialize_entity_key, serialize_entity_key, ) +from feast.infra.online_stores._signal_scores import encode_signal_scores from feast.infra.online_stores.helpers import _mmh3, _redis_key, _redis_key_prefix from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto @@ -1081,6 +1082,121 @@ def retrieve_online_documents_v2( search_results=search_results, ) + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + V3 document retrieval on Valkey backend. + + Valkey supports a subset of V3 features: + - Single embedding only (multi-embedding raises ValueError) + - AUTO and VECTOR_ONLY fusion strategies (others raise ValueError) + - query_string is silently dropped with a warning (Valkey cannot + use text as a ranking signal) + + Returns the same tuple shape as V2 with final_score and signal_scores + added to the feature dict. final_score is the raw Valkey + ``__distance__`` from ``FT.SEARCH KNN`` — lower = better for all + supported metrics (COSINE returns ``1 - cos``, L2 returns squared + distance, IP returns ``1 - inner_product``). Note this ordering is + the opposite of Elasticsearch's ``final_score``, which is a + relevance score where higher = better. + + Reserved parameters (accepted but currently unused): + - ``include_signal_scores``: No-op today. ``signal_scores`` is + populated best-effort (single-entry dict for the one vector + signal). Reserved so callers can opt in now and automatically + pick up a richer breakdown when the explain-based path lands. + """ + del include_signal_scores + valid_strategies = {"AUTO", "RRF", "WEIGHTED_LINEAR", "VECTOR_ONLY"} + effective_strategy = fusion_strategy.upper() + if effective_strategy not in valid_strategies: + raise ValueError( + f"Unknown fusion_strategy '{fusion_strategy}'. " + f"Valid options: {sorted(valid_strategies)}" + ) + + if not embeddings: + raise ValueError( + "V3 requires at least one embedding. " + "Pass embeddings={field_name: vector}." + ) + + if len(embeddings) > 1: + raise ValueError( + "Multi-vector fusion requires the Elasticsearch backend. " + "Valkey supports single-vector search only. " + "Use a single embedding or switch to the Elasticsearch online store." + ) + + if effective_strategy in ("RRF", "WEIGHTED_LINEAR"): + raise ValueError( + f"Fusion strategy '{effective_strategy}' is not supported on Valkey. " + "Use fusion_strategy='AUTO' or 'VECTOR_ONLY', " + "or switch to the Elasticsearch backend for fusion support." + ) + + if query_string is not None and effective_strategy != "VECTOR_ONLY": + logger.warning( + "query_string is being dropped — Valkey backend does not support " + "text search as a ranking signal. To use text as a ranking signal, " + "switch to the Elasticsearch backend." + ) + + embed_key, embed_vector = next(iter(embeddings.items())) + + v2_results = self.retrieve_online_documents_v2( + config=config, + table=table, + requested_features=requested_features, + embedding=embed_vector, + top_k=top_k, + distance_metric=distance_metric, + query_string=None, # Valkey does not support query_string + ) + + v3_results: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + for timestamp, entity_key_proto, feature_dict in v2_results: + if feature_dict is None: + v3_results.append((timestamp, entity_key_proto, None)) + continue + + distance_val = feature_dict.pop("distance", None) + if distance_val is not None and distance_val.HasField("double_val"): + feature_dict["final_score"] = distance_val + signal_scores = {f"vec_{embed_key}": distance_val.double_val} + else: + signal_scores = {} + + feature_dict["signal_scores"] = encode_signal_scores(signal_scores) + v3_results.append((timestamp, entity_key_proto, feature_dict)) + + return v3_results + def _get_vector_field_for_search( self, table: FeatureView, diff --git a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py index 79751cea5b7..a1f2fcb73b3 100644 --- a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py @@ -17,6 +17,7 @@ get_list_val_str, serialize_entity_key, ) +from feast.infra.online_stores._signal_scores import encode_signal_scores from feast.infra.online_stores.online_store import OnlineStore from feast.infra.online_stores.vector_store import VectorStoreConfig from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto @@ -710,6 +711,223 @@ def retrieve_online_documents_v2( result.append((timestamp, entity_key_proto, feature_dict)) return result + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + V3 document retrieval on Elasticsearch backend. + + Uses the ES retriever API (ES 8.14+) for all query types: single-signal + kNN, multi-signal RRF, and weighted linear fusion. + + Reserved output fields (always present in each result's feature_dict): + - ``final_score``: ES _score (higher = better). For single-signal this + is the raw kNN score; for fusion it is the rank-based composite score. + - ``signal_scores``: JSON-encoded Dict[str, float] with per-signal + scores when available, empty dict for fused results (ES does not + expose per-retriever scores after fusion). + + Reserved parameters (accepted but currently unused): + - ``distance_metric``: V3-ES always uses the metric configured in the + index mapping; this param is reserved for future per-query override. + - ``include_signal_scores``: No-op today. ``signal_scores`` follows + best-effort behavior — populated for single-signal queries, empty + for RRF/WEIGHTED_LINEAR fusion (ES does not expose per-retriever + scores after fusion). Reserved for a future ES-explain path that + will populate the breakdown for fusion strategies at extra latency + cost. + """ + del distance_metric + del include_signal_scores + + valid_strategies = {"AUTO", "RRF", "WEIGHTED_LINEAR", "VECTOR_ONLY"} + effective_strategy = fusion_strategy.upper() + if effective_strategy not in valid_strategies: + raise ValueError( + f"Unknown fusion_strategy '{fusion_strategy}'. " + f"Valid options: {sorted(valid_strategies)}" + ) + + if not embeddings: + raise ValueError( + "V3 requires at least one embedding. " + "Pass embeddings={field_name: vector}." + ) + + if not config.online_store.vector_enabled: + raise ValueError("Vector search is not enabled in the online store config.") + + if effective_strategy == "VECTOR_ONLY": + query_string = None + + # Normalize empty/whitespace query_string to None + if query_string is not None and not query_string.strip(): + query_string = None + + # Validate embedding keys against FeatureView schema + vector_fields = {f.name: f for f in table.features if f.vector_index} + for key in embeddings: + if key not in vector_fields: + available = sorted(vector_fields.keys()) + if not available: + raise ValueError( + f"FeatureView '{table.name}' has no vector-indexed fields. " + f"Cannot perform vector search." + ) + raise ValueError( + f"Embedding key '{key}' does not match any vector-indexed field " + f"in FeatureView '{table.name}'. " + f"Available vector fields: {available}" + ) + + # Build retrievers: one kNN per embedding, optional BM25 for query_string + retrievers_with_names: List[Tuple[str, Dict[str, Any]]] = [] + for field_name, vec in embeddings.items(): + knn_retriever: Dict[str, Any] = { + "knn": { + "field": f"{field_name}.vector_value", + "query_vector": vec, + } + } + retrievers_with_names.append((field_name, knn_retriever)) + + has_text_signal = query_string is not None + if has_text_signal: + text_retriever: Dict[str, Any] = { + "standard": {"query": {"query_string": {"query": query_string}}} + } + retrievers_with_names.append(("bm25", text_retriever)) + + is_single_signal = len(retrievers_with_names) == 1 + + if is_single_signal and effective_strategy in ("RRF", "WEIGHTED_LINEAR"): + logger.warning( + "Only one signal present — fusion_strategy '%s' has no effect. " + "The query will execute as a single-signal retrieval.", + effective_strategy, + ) + + # Set inner k based on signal count + multiplier = ( + getattr(config.online_store, "knn_num_candidates_multiplier", 2.0) or 2.0 + ) + if is_single_signal: + inner_k = top_k + else: + inner_k = min(max(top_k * 10, 100), 1000) + num_candidates = max(inner_k, math.ceil(inner_k * multiplier)) + + for _, retriever in retrievers_with_names: + if "knn" in retriever: + retriever["knn"]["k"] = inner_k + retriever["knn"]["num_candidates"] = num_candidates + + # Resolve execution mode + if is_single_signal: + execution_mode = "single" + elif effective_strategy == "WEIGHTED_LINEAR": + execution_mode = "linear" + else: + execution_mode = "rrf" + + # Validate WEIGHTED_LINEAR signal coverage + if execution_mode == "linear": + expected_signals = {name for name, _ in retrievers_with_names} + provided = set(signal_weights.keys()) if signal_weights else set() + missing = expected_signals - provided + if missing: + raise ValueError( + f"WEIGHTED_LINEAR fusion missing weights for signals: " + f"{sorted(missing)}. Provide a weight for each signal: " + f"embedding field names and/or 'bm25'." + ) + + # Compose query body + retrievers = [r for _, r in retrievers_with_names] + composite_key_name = _get_composite_key_name(table) + source_fields = requested_features.copy() + source_fields += ["entity_key", "timestamp"] + source_fields += composite_key_name + + if execution_mode == "single": + top_retriever = retrievers[0] + elif execution_mode == "rrf": + top_retriever = {"rrf": {"retrievers": retrievers, "rank_constant": rrf_k}} + else: + assert signal_weights is not None + weighted = [] + for signal_name, retriever in retrievers_with_names: + weight = signal_weights[signal_name] + weighted.append({"retriever": retriever, "weight": weight}) + top_retriever = {"linear": {"retrievers": weighted}} + + body: Dict[str, Any] = { + "retriever": top_retriever, + "size": top_k, + "_source": source_fields, + } + + response = self._get_client(config).search(index=table.name, body=body) + + # Parse results + result: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + + rows = response["hits"]["hits"][:top_k] + for row in rows: + entity_key = row["_source"]["entity_key"] + entity_key_proto = deserialize_entity_key( + base64.b64decode(entity_key), + entity_key_serialization_version=config.entity_key_serialization_version, + ) + timestamp = datetime.fromisoformat(row["_source"]["timestamp"]) + + feature_dict: Dict[str, ValueProto] = {} + feature_dict["final_score"] = _to_value_proto(float(row["_score"])) + + signal_scores: Dict[str, float] = {} + if is_single_signal: + embed_key = next(iter(embeddings.keys())) + signal_scores[f"vec_{embed_key}"] = float(row["_score"]) + + feature_dict["signal_scores"] = encode_signal_scores(signal_scores) + + join_key_values = _extract_join_keys(entity_key_proto) + feature_dict.update(join_key_values) + + for feature in requested_features: + if feature in ("final_score", "signal_scores"): + continue + value = row["_source"].get(feature, None) + if value is not None: + feature_dict[feature] = _to_value_proto(value) + + result.append((timestamp, entity_key_proto, feature_dict)) + + return result + def _build_vector_mapping( config: RepoConfig, vector_field_length: int, table_name: str diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index b77185229d5..500297309b5 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -467,6 +467,31 @@ def retrieve_online_documents_v2( f"Online store {self.__class__.__name__} does not support online retrieval" ) + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + raise NotImplementedError( + f"Online store {self.__class__.__name__} does not support " + f"V3 document retrieval" + ) + async def initialize(self, config: RepoConfig) -> None: pass diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index ec2b05759ba..79cd2a1073e 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -329,6 +329,31 @@ def retrieve_online_documents_v2( logger.error(error_msg) raise RuntimeError(error_msg) + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + raise NotImplementedError( + "V3 document retrieval is not yet supported via the remote online store. " + "Use the SDK directly against a local online store." + ) + def _extract_requested_feature_value( self, response_json: dict, diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 26d3ca3d6bf..f0475e511cb 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -435,6 +435,37 @@ def retrieve_online_documents_v2( ) return result + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: Optional[List[str]], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List: + result = [] + if self.online_store: + result = self.online_store.retrieve_online_documents_v3( + config, + table, + requested_features, + embeddings, + top_k, + query_string, + fusion_strategy, + signal_weights, + rrf_k, + distance_metric, + include_signal_scores, + ) + return result + @staticmethod def _prep_table_and_join_keys_for_ingestion( feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 3255e34de4c..b88af17efea 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -496,6 +496,29 @@ def retrieve_online_documents_v2( """ pass + @abstractmethod + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + pass + @abstractmethod def validate_data_source( self, diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 8b076af92d4..af55266c682 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -1421,10 +1421,6 @@ def _get_feature_view_vector_field_metadata( feature_view, ) -> Optional[Field]: vector_fields = [field for field in feature_view.schema if field.vector_index] - if len(vector_fields) > 1: - raise ValueError( - f"Feature view {feature_view.name} has multiple vector fields. Only one vector field per feature view is supported." - ) if not vector_fields: return None return vector_fields[0] diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index f8396acc2df..4706ec9f49f 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -184,6 +184,28 @@ def retrieve_online_documents_v2( ]: return [] + def retrieve_online_documents_v3( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embeddings: Dict[str, List[float]], + top_k: int, + query_string: Optional[str] = None, + fusion_strategy: str = "AUTO", + signal_weights: Optional[Dict[str, float]] = None, + rrf_k: int = 60, + distance_metric: Optional[str] = None, + include_signal_scores: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + return [] + def validate_data_source( self, config: RepoConfig, diff --git a/sdk/python/tests/unit/infra/online_store/test_valkey.py b/sdk/python/tests/unit/infra/online_store/test_valkey.py index f172838a846..ed911bcb92a 100644 --- a/sdk/python/tests/unit/infra/online_store/test_valkey.py +++ b/sdk/python/tests/unit/infra/online_store/test_valkey.py @@ -1732,3 +1732,424 @@ def test_default_distance_is_infinity_not_zero(self, store): assert len(results) == 1 doc_key, distance = results[0] assert distance == float("inf") + + +class TestRetrieveOnlineDocumentsV3Validation: + """Tests for retrieve_online_documents_v3 input validation on Valkey.""" + + @pytest.fixture + def store(self): + return EGValkeyOnlineStore() + + @pytest.fixture + def feature_view_with_vector(self): + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + Field(name="title", dtype=String), + ], + ) + + @pytest.fixture + def feature_view_multi_vector(self): + return FeatureView( + name="test_fv_multi", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="title_vec", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + Field( + name="body_vec", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + ], + ) + + @pytest.fixture + def repo_config(self): + return RepoConfig( + project="test_project", + provider="local", + registry="test_registry.db", + online_store=EGValkeyOnlineStoreConfig( + type="eg-valkey", + connection_string="localhost:6379", + ), + entity_key_serialization_version=3, + ) + + def test_empty_embeddings_raises( + self, store, repo_config, feature_view_with_vector + ): + with pytest.raises(ValueError, match="at least one embedding"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={}, + top_k=5, + ) + + def test_multi_embedding_raises( + self, store, repo_config, feature_view_multi_vector + ): + with pytest.raises(ValueError, match="single-vector search only"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_multi_vector, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + + def test_rrf_strategy_raises(self, store, repo_config, feature_view_with_vector): + with pytest.raises(ValueError, match="not supported on Valkey"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="RRF", + ) + + def test_weighted_linear_strategy_raises( + self, store, repo_config, feature_view_with_vector + ): + with pytest.raises(ValueError, match="not supported on Valkey"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="WEIGHTED_LINEAR", + signal_weights={"embedding": 1.0}, + ) + + def test_unknown_fusion_strategy_raises( + self, store, repo_config, feature_view_with_vector + ): + with pytest.raises(ValueError, match="Unknown fusion_strategy"): + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="BOGUS", + ) + + def test_auto_strategy_accepted(self, store, repo_config, feature_view_with_vector): + """AUTO should not raise — it delegates to V2.""" + from unittest.mock import patch + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"distance": ValueProto(double_val=0.5)}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="AUTO", + ) + assert len(results) == 1 + + def test_vector_only_strategy_accepted( + self, store, repo_config, feature_view_with_vector + ): + from unittest.mock import patch + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"distance": ValueProto(double_val=0.3)}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="VECTOR_ONLY", + ) + assert len(results) == 1 + + def test_query_string_warns_and_dropped( + self, store, repo_config, feature_view_with_vector + ): + """query_string should trigger a logger.warning and be passed as None to V2.""" + from unittest.mock import patch + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"distance": ValueProto(double_val=0.5)}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ) as mock_v2: + store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string="test query", + fusion_strategy="AUTO", + ) + # Verify query_string=None was passed to V2 + call_kwargs = mock_v2.call_args[1] + assert call_kwargs.get("query_string") is None + + def test_include_signal_scores_accepted( + self, store, repo_config, feature_view_with_vector + ): + from unittest.mock import patch + + mock_results = [] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view_with_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + include_signal_scores=False, + ) + assert results == [] + + +class TestRetrieveOnlineDocumentsV3ResponseTransform: + """Tests for V3 response transformation on Valkey (V2→V3 wrapper).""" + + @pytest.fixture + def store(self): + return EGValkeyOnlineStore() + + @pytest.fixture + def feature_view(self): + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + Field(name="title", dtype=String), + ], + ) + + @pytest.fixture + def repo_config(self): + return RepoConfig( + project="test_project", + provider="local", + registry="test_registry.db", + online_store=EGValkeyOnlineStoreConfig( + type="eg-valkey", + connection_string="localhost:6379", + ), + entity_key_serialization_version=3, + ) + + def test_distance_renamed_to_final_score(self, store, repo_config, feature_view): + from unittest.mock import patch + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + { + "distance": ValueProto(double_val=0.25), + "title": ValueProto(string_val="hello"), + }, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert len(results) == 1 + ts, ek, feat_dict = results[0] + assert "final_score" in feat_dict + assert "distance" not in feat_dict + assert feat_dict["final_score"].double_val == pytest.approx(0.25) + + def test_signal_scores_populated(self, store, repo_config, feature_view): + from unittest.mock import patch + + from feast.infra.online_stores._signal_scores import decode_signal_scores + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"distance": ValueProto(double_val=0.5)}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + feat_dict = results[0][2] + assert "signal_scores" in feat_dict + scores = decode_signal_scores(feat_dict["signal_scores"]) + assert "vec_embedding" in scores + assert scores["vec_embedding"] == pytest.approx(0.5) + + def test_none_feature_dict_passthrough(self, store, repo_config, feature_view): + from unittest.mock import patch + + mock_results = [ + (datetime(2024, 1, 1), None, None), + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert len(results) == 1 + assert results[0][2] is None + + def test_missing_distance_no_final_score(self, store, repo_config, feature_view): + """If V2 returns no distance, signal_scores should be empty.""" + from unittest.mock import patch + + from feast.infra.online_stores._signal_scores import decode_signal_scores + + mock_results = [ + ( + datetime(2024, 1, 1), + EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ), + {"title": ValueProto(string_val="test")}, + ) + ] + with patch.object( + store, "retrieve_online_documents_v2", return_value=mock_results + ): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + feat_dict = results[0][2] + scores = decode_signal_scores(feat_dict["signal_scores"]) + assert scores == {} + + def test_empty_v2_results(self, store, repo_config, feature_view): + from unittest.mock import patch + + with patch.object(store, "retrieve_online_documents_v2", return_value=[]): + results = store.retrieve_online_documents_v3( + config=repo_config, + table=feature_view, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert results == [] diff --git a/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py index 12f77eab9b0..68404bb3835 100644 --- a/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py +++ b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py @@ -1,8 +1,17 @@ import base64 +import json +import math +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch import pytest +from feast import Entity, FeatureView, RepoConfig +from feast.field import Field +from feast.infra.online_stores._signal_scores import decode_signal_scores from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import ( + ElasticSearchOnlineStore, + ElasticSearchOnlineStoreConfig, _encode_feature_value, _to_value_proto, ) @@ -13,6 +22,8 @@ from feast.protos.feast.types.Value_pb2 import ( Value as ValueProto, ) +from feast.types import Array, Float32, Int64, String +from feast.value_type import ValueType class TestEncodeFeatureValue: @@ -69,6 +80,695 @@ def test_default_is_vector_false(self): assert "vector_value" not in result +def _make_feature_view( + name="test_fv", + vector_fields=None, + extra_fields=None, +): + """Helper to build a FeatureView with optional vector fields.""" + from feast import FileSource + + schema = [Field(name="item_id", dtype=Int64)] + if vector_fields is None: + vector_fields = [("embedding", 4)] + for fname, dim in vector_fields: + schema.append( + Field( + name=fname, + dtype=Array(Float32), + vector_index=True, + vector_length=dim, + vector_search_metric="COSINE", + ) + ) + for fname, dtype in extra_fields or []: + schema.append(Field(name=fname, dtype=dtype)) + + return FeatureView( + name=name, + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=schema, + ) + + +_repo_config_counter = 0 + + +def _make_repo_config(vector_enabled=True, **overrides): + """Helper to build a RepoConfig with ES online store.""" + global _repo_config_counter + _repo_config_counter += 1 + es_config = ElasticSearchOnlineStoreConfig( + type="elasticsearch", + host="localhost", + port=9200, + vector_enabled=vector_enabled, + **overrides, + ) + return RepoConfig( + project="test_project", + provider="local", + registry=f"/tmp/test_registry_{_repo_config_counter}.db", + online_store=es_config, + entity_key_serialization_version=3, + ) + + +class TestRetrieveOnlineDocumentsV3Validation: + """Tests for retrieve_online_documents_v3 input validation.""" + + @pytest.fixture + def store(self): + return ElasticSearchOnlineStore() + + @pytest.fixture + def config(self): + return _make_repo_config() + + @pytest.fixture + def fv_single_vector(self): + return _make_feature_view( + vector_fields=[("embedding", 4)], + extra_fields=[("title", String)], + ) + + @pytest.fixture + def fv_multi_vector(self): + return _make_feature_view( + vector_fields=[("title_vec", 4), ("body_vec", 4)], + ) + + @pytest.fixture + def fv_no_vector(self): + return _make_feature_view(vector_fields=[]) + + def test_empty_embeddings_raises(self, store, config, fv_single_vector): + with pytest.raises(ValueError, match="at least one embedding"): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={}, + top_k=5, + ) + + def test_vector_not_enabled_raises(self, store, fv_single_vector): + config = _make_repo_config(vector_enabled=False) + with pytest.raises(ValueError, match="not enabled"): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + def test_unknown_fusion_strategy_raises(self, store, config, fv_single_vector): + with pytest.raises(ValueError, match="Unknown fusion_strategy"): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="INVALID", + ) + + def test_unknown_embedding_key_raises(self, store, config, fv_single_vector): + with pytest.raises(ValueError, match="does not match any vector-indexed"): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"nonexistent_field": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + def test_no_vector_fields_raises(self, store, config, fv_no_vector): + with pytest.raises(ValueError, match="no vector-indexed fields"): + store.retrieve_online_documents_v3( + config=config, + table=fv_no_vector, + requested_features=["item_id"], + embeddings={"some_field": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + def test_weighted_linear_missing_weights_raises( + self, store, config, fv_multi_vector + ): + with pytest.raises(ValueError, match="missing weights for signals"): + store.retrieve_online_documents_v3( + config=config, + table=fv_multi_vector, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + query_string="test", + fusion_strategy="WEIGHTED_LINEAR", + signal_weights={"title_vec": 0.5}, + ) + + def test_weighted_linear_partial_weights_raises( + self, store, config, fv_multi_vector + ): + """Missing bm25 weight when query_string is present.""" + with pytest.raises(ValueError, match=r"missing weights for signals.*\bbm25\b"): + store.retrieve_online_documents_v3( + config=config, + table=fv_multi_vector, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + query_string="test", + fusion_strategy="WEIGHTED_LINEAR", + signal_weights={"title_vec": 0.5, "body_vec": 0.3}, + ) + + def test_vector_only_nullifies_query_string(self, store, config, fv_single_vector): + """VECTOR_ONLY should drop query_string before building retrievers.""" + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string="should be dropped", + fusion_strategy="VECTOR_ONLY", + ) + + call_body = mock_client.search.call_args[1]["body"] + retriever = call_body["retriever"] + assert "knn" in retriever, "VECTOR_ONLY should produce a knn retriever" + assert "standard" not in json.dumps(retriever) + assert "rrf" not in retriever + + def test_empty_query_string_treated_as_none(self, store, config, fv_single_vector): + """Whitespace-only query_string should not create a BM25 retriever.""" + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string=" ", + ) + + call_body = mock_client.search.call_args[1]["body"] + retriever = call_body["retriever"] + assert "knn" in retriever + assert "standard" not in json.dumps(retriever) + + @pytest.mark.parametrize( + "strategy", ["auto", "Auto", "AUTO", "rrf", "Rrf", "vector_only"] + ) + def test_strategy_case_insensitive(self, store, config, fv_single_vector, strategy): + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy=strategy, + ) + mock_client.search.assert_called_once() + + @pytest.mark.parametrize("flag", [True, False]) + def test_include_signal_scores_accepted_but_ignored( + self, store, config, fv_single_vector, flag + ): + """include_signal_scores is a reserved param; should not raise for True or False.""" + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3( + config=config, + table=fv_single_vector, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + include_signal_scores=flag, + ) + + +class TestRetrieveOnlineDocumentsV3QueryBuilding: + """Tests for the ES query body construction.""" + + @pytest.fixture + def store(self): + return ElasticSearchOnlineStore() + + @pytest.fixture + def config(self): + return _make_repo_config() + + @pytest.fixture + def fv_single(self): + return _make_feature_view( + vector_fields=[("embedding", 4)], + extra_fields=[("title", String)], + ) + + @pytest.fixture + def fv_multi(self): + return _make_feature_view( + vector_fields=[("title_vec", 4), ("body_vec", 4)], + ) + + def _call_and_capture_body(self, store, config, table, **kwargs): + mock_client = MagicMock() + mock_client.search.return_value = {"hits": {"hits": []}} + with patch.object(store, "_get_client", return_value=mock_client): + store.retrieve_online_documents_v3(config=config, table=table, **kwargs) + return mock_client.search.call_args[1]["body"] + + def test_single_vector_uses_knn_retriever(self, store, config, fv_single): + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + retriever = body["retriever"] + assert "knn" in retriever + assert retriever["knn"]["field"] == "embedding.vector_value" + assert retriever["knn"]["query_vector"] == [0.1, 0.2, 0.3, 0.4] + assert retriever["knn"]["k"] == 5 + assert body["size"] == 5 + + def test_single_vector_knn_k_equals_top_k(self, store, config, fv_single): + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=10, + ) + assert body["retriever"]["knn"]["k"] == 10 + + def test_multi_vector_uses_rrf_by_default(self, store, config, fv_multi): + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + retriever = body["retriever"] + assert "rrf" in retriever + assert len(retriever["rrf"]["retrievers"]) == 2 + + def test_multi_vector_rrf_has_rank_constant(self, store, config, fv_multi): + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + rrf_k=42, + ) + assert body["retriever"]["rrf"]["rank_constant"] == 42 + + def test_query_string_adds_bm25_retriever(self, store, config, fv_single): + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string="search term", + ) + retriever = body["retriever"] + assert "rrf" in retriever + retrievers = retriever["rrf"]["retrievers"] + assert len(retrievers) == 2 + retriever_types = [list(r.keys())[0] for r in retrievers] + assert "knn" in retriever_types + assert "standard" in retriever_types + + def test_single_vector_plus_bm25_uses_rrf(self, store, config, fv_single): + """Single vector + query_string should produce RRF with knn + standard retrievers.""" + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + query_string="search term", + fusion_strategy="RRF", + ) + retriever = body["retriever"] + assert "rrf" in retriever + retrievers = retriever["rrf"]["retrievers"] + assert len(retrievers) == 2 + types = {list(r.keys())[0] for r in retrievers} + assert types == {"knn", "standard"} + for r in retrievers: + if "knn" in r: + assert r["knn"]["field"] == "embedding.vector_value" + if "standard" in r: + assert r["standard"]["query"]["query_string"]["query"] == "search term" + + def test_weighted_linear_uses_linear_retriever(self, store, config, fv_multi): + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + fusion_strategy="WEIGHTED_LINEAR", + signal_weights={"title_vec": 0.7, "body_vec": 0.3}, + ) + retriever = body["retriever"] + assert "linear" in retriever + weighted = retriever["linear"]["retrievers"] + assert len(weighted) == 2 + weights = [w["weight"] for w in weighted] + assert 0.7 in weights + assert 0.3 in weights + + def test_multi_signal_inner_k_larger_than_top_k(self, store, config, fv_multi): + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + for r in body["retriever"]["rrf"]["retrievers"]: + if "knn" in r: + assert r["knn"]["k"] >= 100 + assert r["knn"]["k"] <= 1000 + + def test_num_candidates_uses_math_ceil(self, store, config, fv_single): + """Verify math.ceil is applied by using a multiplier that produces a fraction.""" + object.__setattr__(config.online_store, "knn_num_candidates_multiplier", 1.5) + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=3, + ) + k = body["retriever"]["knn"]["k"] + num_candidates = body["retriever"]["knn"]["num_candidates"] + # 3 * 1.5 = 4.5 → ceil → 5, proving ceil is used (floor would give 4) + assert num_candidates == math.ceil(k * 1.5) + assert num_candidates == 5 + assert num_candidates != int(k * 1.5) + + def test_rrf_single_signal_executes_as_single(self, store, config, fv_single): + """RRF with only one signal should still succeed (logged warning, not error).""" + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="RRF", + ) + retriever = body["retriever"] + assert "knn" in retriever, "Single signal RRF degrades to single retriever" + assert "rrf" not in retriever + + def test_auto_single_signal_uses_direct_knn(self, store, config, fv_single): + """AUTO with one vector and no query_string should produce a bare knn retriever.""" + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + fusion_strategy="AUTO", + ) + retriever = body["retriever"] + assert "knn" in retriever, "Single signal AUTO should use direct knn" + assert "rrf" not in retriever + assert "linear" not in retriever + + def test_source_fields_include_metadata(self, store, config, fv_single): + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + source = body["_source"] + assert "entity_key" in source + assert "timestamp" in source + assert "title" in source + + +class TestRetrieveOnlineDocumentsV3ResponseParsing: + """Tests for parsing ES response into V3 result tuples.""" + + @pytest.fixture + def store(self): + return ElasticSearchOnlineStore() + + @pytest.fixture + def config(self): + return _make_repo_config() + + @pytest.fixture + def fv(self): + return _make_feature_view( + vector_fields=[("embedding", 4)], + extra_fields=[("title", String)], + ) + + def _mock_es_response(self, hits): + return {"hits": {"hits": hits}} + + def _make_hit(self, score, timestamp, features=None): + from feast.infra.key_encoding_utils import serialize_entity_key + from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto + + ek = EntityKeyProto( + join_keys=["item_id"], + entity_values=[ValueProto(int64_val=1)], + ) + ek_bytes = serialize_entity_key(ek, entity_key_serialization_version=3) + ek_b64 = base64.b64encode(ek_bytes).decode("utf-8") + source = { + "entity_key": ek_b64, + "timestamp": timestamp, + } + if features: + source.update(features) + return {"_source": source, "_score": score} + + def test_single_result_has_final_score(self, store, config, fv): + hit = self._make_hit(0.95, "2024-01-01T00:00:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert len(results) == 1 + ts, ek, feat_dict = results[0] + assert feat_dict["final_score"].float_val == pytest.approx(0.95) + + def test_single_result_has_signal_scores(self, store, config, fv): + hit = self._make_hit(0.95, "2024-01-01T00:00:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + feat_dict = results[0][2] + scores = decode_signal_scores(feat_dict["signal_scores"]) + assert "vec_embedding" in scores + assert scores["vec_embedding"] == pytest.approx(0.95) + + def test_signal_scores_is_compact_sorted_json(self, store, config, fv): + """signal_scores should be compact JSON with sorted keys.""" + hit = self._make_hit(0.95, "2024-01-01T00:00:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + raw = results[0][2]["signal_scores"].string_val + assert " " not in raw + parsed = json.loads(raw) + assert list(parsed.keys()) == sorted(parsed.keys()) + + def test_multi_signal_signal_scores_are_empty(self, store, config): + fv = _make_feature_view( + vector_fields=[("title_vec", 4), ("body_vec", 4)], + ) + hit = self._make_hit(0.8, "2024-01-01T00:00:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + + feat_dict = results[0][2] + scores = decode_signal_scores(feat_dict["signal_scores"]) + assert scores == {} + + def test_empty_results(self, store, config, fv): + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + assert results == [] + + def test_timestamp_parsed(self, store, config, fv): + hit = self._make_hit(0.9, "2024-06-15T12:30:00") + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + ts = results[0][0] + assert isinstance(ts, datetime) + assert ts.year == 2024 + assert ts.month == 6 + + def test_top_k_limits_results(self, store, config, fv): + """Verify that at most top_k results are returned even if ES returns more.""" + hits = [self._make_hit(0.9 - i * 0.1, "2024-01-01T00:00:00") for i in range(5)] + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response(hits) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=3, + ) + + assert len(results) <= 3 + body = mock_client.search.call_args[1]["body"] + assert body["size"] == 3 + + def test_feature_values_included(self, store, config, fv): + encoded_val = base64.b64encode( + ValueProto(string_val="hello world").SerializeToString() + ).decode("utf-8") + hit = self._make_hit( + 0.9, + "2024-01-01T00:00:00", + features={"title": {"feature_value": encoded_val}}, + ) + mock_client = MagicMock() + mock_client.search.return_value = self._mock_es_response([hit]) + + with patch.object(store, "_get_client", return_value=mock_client): + results = store.retrieve_online_documents_v3( + config=config, + table=fv, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + + feat_dict = results[0][2] + assert "title" in feat_dict + assert feat_dict["title"].string_val == "hello world" + + class TestElasticSearchOnlineStoreConfig: def test_defaults(self): """Test default config values.""" diff --git a/sdk/python/tests/unit/online_store/test_signal_scores.py b/sdk/python/tests/unit/online_store/test_signal_scores.py new file mode 100644 index 00000000000..4b426202d3c --- /dev/null +++ b/sdk/python/tests/unit/online_store/test_signal_scores.py @@ -0,0 +1,76 @@ +import json + +import pytest + +from feast.infra.online_stores._signal_scores import ( + decode_signal_scores, + encode_signal_scores, +) +from feast.protos.feast.types.Value_pb2 import Value as ValueProto + + +class TestEncodeSignalScores: + def test_single_score(self): + result = encode_signal_scores({"vec_embedding": 0.95}) + assert result.HasField("string_val") + parsed = json.loads(result.string_val) + assert parsed == {"vec_embedding": 0.95} + + def test_multiple_scores(self): + scores = {"vec_title": 0.8, "vec_body": 0.6, "bm25": 12.5} + result = encode_signal_scores(scores) + parsed = json.loads(result.string_val) + assert parsed == scores + + def test_empty_dict(self): + result = encode_signal_scores({}) + assert result.string_val == "{}" + + def test_sort_keys_deterministic(self): + result_a = encode_signal_scores({"z_field": 1.0, "a_field": 2.0}) + result_b = encode_signal_scores({"a_field": 2.0, "z_field": 1.0}) + assert result_a.string_val == result_b.string_val + parsed = json.loads(result_a.string_val) + keys = list(parsed.keys()) + assert keys == sorted(keys) + + def test_compact_json_no_spaces(self): + result = encode_signal_scores({"a": 1.0, "b": 2.0}) + assert " " not in result.string_val + + def test_returns_value_proto(self): + result = encode_signal_scores({"x": 1.0}) + assert isinstance(result, ValueProto) + + +class TestDecodeSignalScores: + def test_roundtrip(self): + original = {"vec_embedding": 0.95, "bm25": 12.5} + encoded = encode_signal_scores(original) + decoded = decode_signal_scores(encoded) + assert decoded == original + + def test_roundtrip_empty(self): + encoded = encode_signal_scores({}) + decoded = decode_signal_scores(encoded) + assert decoded == {} + + def test_empty_string_val(self): + val = ValueProto() + val.string_val = "" + assert decode_signal_scores(val) == {} + + def test_no_string_field(self): + val = ValueProto() + val.int64_val = 42 + assert decode_signal_scores(val) == {} + + def test_default_value_proto(self): + val = ValueProto() + assert decode_signal_scores(val) == {} + + def test_malformed_json_raises(self): + val = ValueProto() + val.string_val = "not-json" + with pytest.raises(json.JSONDecodeError): + decode_signal_scores(val) From f147877e063893e19a06a331c6b66ea5d394963c Mon Sep 17 00:00:00 2001 From: Manisha Sudhir <30449541+Manisha4@users.noreply.github.com> Date: Thu, 7 May 2026 12:03:48 -0700 Subject: [PATCH 7/7] fix: apply rescore_oversample to V3 ES kNN retrievers (#363) V3's retriever construction was silently ignoring rescore_oversample. V2 honored it (lines 483-486, 617-620) but V3 never added rescore_vector to its kNN clauses. On quantized indices (int8_hnsw / int4_hnsw / bbq_hnsw), this meant V3 queries returned lower recall than the config promised, with no error or warning. Wire rescore_oversample into each kNN retriever the same way V2 does. Covers single-vector and multi-vector V3 queries; BM25 retrievers skip the branch since they lack a "knn" key. Existing config validation (lines 102-127) already prevents rescore on non-quantized indices, so no new validation needed. Added three unit tests in TestRetrieveOnlineDocumentsV3QueryBuilding: - rescore_vector appears in single-vector query body when configured - rescore_vector appears on every kNN retriever in multi-vector query - rescore_vector absent when rescore_oversample is None Co-authored-by: Manisha4 Co-authored-by: Claude Opus 4.7 --- .../elasticsearch.py | 5 ++ .../test_elasticsearch_online_store.py | 54 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py index a1f2fcb73b3..0895a144625 100644 --- a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py @@ -834,10 +834,15 @@ def retrieve_online_documents_v3( inner_k = min(max(top_k * 10, 100), 1000) num_candidates = max(inner_k, math.ceil(inner_k * multiplier)) + rescore_oversample = config.online_store.rescore_oversample for _, retriever in retrievers_with_names: if "knn" in retriever: retriever["knn"]["k"] = inner_k retriever["knn"]["num_candidates"] = num_candidates + if rescore_oversample is not None: + retriever["knn"]["rescore_vector"] = { + "oversample": rescore_oversample + } # Resolve execution mode if is_single_signal: diff --git a/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py index 68404bb3835..81c58666450 100644 --- a/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py +++ b/sdk/python/tests/unit/online_store/test_elasticsearch_online_store.py @@ -569,6 +569,60 @@ def test_source_fields_include_metadata(self, store, config, fv_single): assert "timestamp" in source assert "title" in source + def test_rescore_oversample_applied_to_single_knn(self, store, fv_single): + """When rescore_oversample is configured on a quantized index, the V3 + kNN retriever should include a rescore_vector clause.""" + config = _make_repo_config( + vector_index_type="int8_hnsw", rescore_oversample=3.0 + ) + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + knn = body["retriever"]["knn"] + assert knn["rescore_vector"] == {"oversample": 3.0} + + def test_rescore_oversample_applied_to_all_multi_vector_knns(self, store, fv_multi): + """Multi-vector V3 queries should apply rescore_vector to every kNN + retriever, not just the first one.""" + config = _make_repo_config( + vector_index_type="int4_hnsw", rescore_oversample=2.5 + ) + body = self._call_and_capture_body( + store, + config, + fv_multi, + requested_features=["item_id"], + embeddings={ + "title_vec": [0.1, 0.2, 0.3, 0.4], + "body_vec": [0.5, 0.6, 0.7, 0.8], + }, + top_k=5, + ) + retrievers = body["retriever"]["rrf"]["retrievers"] + assert len(retrievers) == 2 + for r in retrievers: + assert r["knn"]["rescore_vector"] == {"oversample": 2.5} + + def test_rescore_oversample_absent_when_not_configured( + self, store, config, fv_single + ): + """Default config has no rescore_oversample; the kNN clause should not + include rescore_vector.""" + body = self._call_and_capture_body( + store, + config, + fv_single, + requested_features=["title"], + embeddings={"embedding": [0.1, 0.2, 0.3, 0.4]}, + top_k=5, + ) + assert "rescore_vector" not in body["retriever"]["knn"] + class TestRetrieveOnlineDocumentsV3ResponseParsing: """Tests for parsing ES response into V3 result tuples."""