Skip to content

Commit b36acb7

Browse files
YassinNouh21ntkathole
authored andcommitted
feat: Add feature view versioning support to FAISS online store
When enable_online_feature_view_versioning is enabled, FAISS indices are namespaced by versioned table keys (e.g. project_driver_stats_v2) so multiple feature view versions can coexist in memory. Reuses the shared compute_table_id() from helpers.py for consistency with PostgreSQL and MySQL stores. Signed-off-by: yassinnouh21 <yassinnouh21@gmail.com>
1 parent 0c469a7 commit b36acb7

4 files changed

Lines changed: 331 additions & 45 deletions

File tree

sdk/python/feast/errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class VersionedOnlineReadNotSupported(FeastError):
142142
def __init__(self, store_name: str, version: int):
143143
super().__init__(
144144
f"Versioned feature reads (@v{version}) are not yet supported by {store_name}. "
145-
f"Currently only SQLite, PostgreSQL, and MySQL support version-qualified feature references. "
145+
f"Currently only SQLite, PostgreSQL, MySQL, and FAISS support version-qualified feature references. "
146146
)
147147

148148

sdk/python/feast/infra/online_stores/faiss_online_store.py

Lines changed: 75 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from feast import Entity, FeatureView, RepoConfig
1010
from feast.infra.key_encoding_utils import serialize_entity_key
11+
from feast.infra.online_stores.helpers import compute_table_id
1112
from feast.infra.online_stores.online_store import OnlineStore
1213
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
1314
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -43,16 +44,21 @@ def teardown(self):
4344
self.entity_keys = {}
4445

4546

47+
def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str:
48+
return compute_table_id(project, table, enable_versioning)
49+
50+
4651
class FaissOnlineStore(OnlineStore):
47-
_index: Optional[faiss.IndexIVFFlat] = None
48-
_in_memory_store: InMemoryStore = InMemoryStore()
49-
_config: Optional[FaissOnlineStoreConfig] = None
5052
_logger: logging.Logger = logging.getLogger(__name__)
5153

52-
def _get_index(self, config: RepoConfig) -> faiss.IndexIVFFlat:
53-
if self._index is None or self._config is None:
54-
raise ValueError("Index is not initialized")
55-
return self._index
54+
def __init__(self):
55+
super().__init__()
56+
self._indices: Dict[str, faiss.IndexIVFFlat] = {}
57+
self._in_memory_stores: Dict[str, InMemoryStore] = {}
58+
self._config: Optional[FaissOnlineStoreConfig] = None
59+
60+
def _get_index(self, table_key: str) -> Optional[faiss.IndexIVFFlat]:
61+
return self._indices.get(table_key)
5662

5763
def update(
5864
self,
@@ -63,32 +69,45 @@ def update(
6369
entities_to_keep: Sequence[Entity],
6470
partial: bool,
6571
):
66-
feature_views = tables_to_keep
67-
if not feature_views:
68-
return
69-
70-
feature_names = [f.name for f in feature_views[0].features]
71-
dimension = len(feature_names)
72-
7372
self._config = FaissOnlineStoreConfig(**config.online_store.dict())
74-
if self._index is None or not partial:
75-
quantizer = faiss.IndexFlatL2(dimension)
76-
self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
77-
self._index.train(
78-
np.random.rand(self._config.nlist * 100, dimension).astype(np.float32)
79-
)
80-
self._in_memory_store = InMemoryStore()
73+
versioning = config.registry.enable_online_feature_view_versioning
74+
75+
for table in tables_to_delete:
76+
table_key = _table_id(config.project, table, versioning)
77+
self._indices.pop(table_key, None)
78+
self._in_memory_stores.pop(table_key, None)
79+
80+
for table in tables_to_keep:
81+
table_key = _table_id(config.project, table, versioning)
82+
feature_names = [f.name for f in table.features]
83+
dimension = len(feature_names)
84+
85+
if table_key not in self._indices or not partial:
86+
quantizer = faiss.IndexFlatL2(dimension)
87+
index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
88+
index.train(
89+
np.random.rand(self._config.nlist * 100, dimension).astype(
90+
np.float32
91+
)
92+
)
93+
self._indices[table_key] = index
94+
self._in_memory_stores[table_key] = InMemoryStore()
8195

82-
self._in_memory_store.update(feature_names, {})
96+
self._in_memory_stores[table_key].update(feature_names, {})
8397

8498
def teardown(
8599
self,
86100
config: RepoConfig,
87101
tables: Sequence[FeatureView],
88102
entities: Sequence[Entity],
89103
):
90-
self._index = None
91-
self._in_memory_store.teardown()
104+
versioning = config.registry.enable_online_feature_view_versioning
105+
for table in tables:
106+
table_key = _table_id(config.project, table, versioning)
107+
self._indices.pop(table_key, None)
108+
store = self._in_memory_stores.pop(table_key, None)
109+
if store is not None:
110+
store.teardown()
92111

93112
def online_read(
94113
self,
@@ -97,23 +116,28 @@ def online_read(
97116
entity_keys: List[EntityKeyProto],
98117
requested_features: Optional[List[str]] = None,
99118
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
100-
if self._index is None:
119+
versioning = config.registry.enable_online_feature_view_versioning
120+
table_key = _table_id(config.project, table, versioning)
121+
index = self._get_index(table_key)
122+
in_memory_store = self._in_memory_stores.get(table_key)
123+
124+
if index is None or in_memory_store is None:
101125
return [(None, None)] * len(entity_keys)
102126

103127
results: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = []
104128
for entity_key in entity_keys:
105129
serialized_key = serialize_entity_key(
106130
entity_key, config.entity_key_serialization_version
107131
).hex()
108-
idx = self._in_memory_store.entity_keys.get(serialized_key, -1)
132+
idx = in_memory_store.entity_keys.get(serialized_key, -1)
109133
if idx == -1:
110134
results.append((None, None))
111135
else:
112-
feature_vector = self._index.reconstruct(int(idx))
136+
feature_vector = index.reconstruct(int(idx))
113137
feature_dict = {
114138
name: ValueProto(double_val=value)
115139
for name, value in zip(
116-
self._in_memory_store.feature_names, feature_vector
140+
in_memory_store.feature_names, feature_vector
117141
)
118142
}
119143
results.append((None, feature_dict))
@@ -128,8 +152,16 @@ def online_write_batch(
128152
],
129153
progress: Optional[Callable[[int], Any]],
130154
) -> None:
131-
if self._index is None:
132-
self._logger.warning("Index is not initialized. Skipping write operation.")
155+
versioning = config.registry.enable_online_feature_view_versioning
156+
table_key = _table_id(config.project, table, versioning)
157+
index = self._get_index(table_key)
158+
in_memory_store = self._in_memory_stores.get(table_key)
159+
160+
if index is None or in_memory_store is None:
161+
self._logger.warning(
162+
"Index for table '%s' is not initialized. Skipping write operation.",
163+
table_key,
164+
)
133165
return
134166

135167
feature_vectors = []
@@ -142,7 +174,7 @@ def online_write_batch(
142174
feature_vector = np.array(
143175
[
144176
feature_dict[name].double_val
145-
for name in self._in_memory_store.feature_names
177+
for name in in_memory_store.feature_names
146178
],
147179
dtype=np.float32,
148180
)
@@ -153,21 +185,17 @@ def online_write_batch(
153185
feature_vectors_array = np.array(feature_vectors)
154186

155187
existing_indices = [
156-
self._in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys
188+
in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys
157189
]
158190
mask = np.array(existing_indices) != -1
159191
if np.any(mask):
160-
self._index.remove_ids(
161-
np.array([idx for idx in existing_indices if idx != -1])
162-
)
192+
index.remove_ids(np.array([idx for idx in existing_indices if idx != -1]))
163193

164-
new_indices = np.arange(
165-
self._index.ntotal, self._index.ntotal + len(feature_vectors_array)
166-
)
167-
self._index.add(feature_vectors_array)
194+
new_indices = np.arange(index.ntotal, index.ntotal + len(feature_vectors_array))
195+
index.add(feature_vectors_array)
168196

169197
for sk, idx in zip(serialized_keys, new_indices):
170-
self._in_memory_store.entity_keys[sk] = idx
198+
in_memory_store.entity_keys[sk] = idx
171199

172200
if progress:
173201
progress(len(data))
@@ -189,12 +217,16 @@ def retrieve_online_documents(
189217
Optional[ValueProto],
190218
]
191219
]:
192-
if self._index is None:
220+
versioning = config.registry.enable_online_feature_view_versioning
221+
table_key = _table_id(config.project, table, versioning)
222+
index = self._get_index(table_key)
223+
224+
if index is None:
193225
self._logger.warning("Index is not initialized. Returning empty result.")
194226
return []
195227

196228
query_vector = np.array(embedding, dtype=np.float32).reshape(1, -1)
197-
distances, indices = self._index.search(query_vector, top_k)
229+
distances, indices = index.search(query_vector, top_k)
198230

199231
results: List[
200232
Tuple[
@@ -209,7 +241,7 @@ def retrieve_online_documents(
209241
if idx == -1:
210242
continue
211243

212-
feature_vector = self._index.reconstruct(int(idx))
244+
feature_vector = index.reconstruct(int(idx))
213245

214246
timestamp = Timestamp()
215247
timestamp.GetCurrentTime()
@@ -237,5 +269,4 @@ async def online_read_async(
237269
entity_keys: List[EntityKeyProto],
238270
requested_features: Optional[List[str]] = None,
239271
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
240-
# Implement async read if needed
241272
raise NotImplementedError("Async read is not implemented for FaissOnlineStore")

sdk/python/feast/infra/online_stores/online_store.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,12 @@ def _check_versioned_read_support(self, grouped_refs):
274274
supported_types.append(PostgreSQLOnlineStore)
275275
except ImportError:
276276
pass
277+
try:
278+
from feast.infra.online_stores.faiss_online_store import FaissOnlineStore
279+
280+
supported_types.append(FaissOnlineStore)
281+
except ImportError:
282+
pass
277283

278284
if isinstance(self, tuple(supported_types)):
279285
return

0 commit comments

Comments
 (0)