88
99from feast import Entity , FeatureView , RepoConfig
1010from feast .infra .key_encoding_utils import serialize_entity_key
11+ from feast .infra .online_stores .helpers import compute_table_id
1112from feast .infra .online_stores .online_store import OnlineStore
1213from feast .protos .feast .types .EntityKey_pb2 import EntityKey as EntityKeyProto
1314from feast .protos .feast .types .Value_pb2 import Value as ValueProto
@@ -43,16 +44,21 @@ def teardown(self):
4344 self .entity_keys = {}
4445
4546
def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str:
    """Return the key under which ``table``'s index and store are kept.

    This is a thin module-level wrapper: the three arguments are forwarded
    unchanged to ``compute_table_id`` and its result is returned as is, so
    the exact format of the key is decided by that helper.

    Args:
        project: Name of the project the feature view belongs to.
        table: Feature view whose key is wanted.
        enable_versioning: Passed straight through to ``compute_table_id``.

    Returns:
        The string produced by ``compute_table_id``.
    """
    table_key = compute_table_id(project, table, enable_versioning)
    return table_key
49+
50+
4651class FaissOnlineStore (OnlineStore ):
47- _index : Optional [faiss .IndexIVFFlat ] = None
48- _in_memory_store : InMemoryStore = InMemoryStore ()
49- _config : Optional [FaissOnlineStoreConfig ] = None
5052 _logger : logging .Logger = logging .getLogger (__name__ )
5153
52- def _get_index (self , config : RepoConfig ) -> faiss .IndexIVFFlat :
53- if self ._index is None or self ._config is None :
54- raise ValueError ("Index is not initialized" )
55- return self ._index
def __init__(self):
    """Set up an empty store: no config, no indices, no in-memory stores."""
    super().__init__()
    # Stays None until update() builds it from the repo's online store config.
    self._config: Optional[FaissOnlineStoreConfig] = None
    # Per-table state, both maps keyed by the string returned by _table_id().
    self._in_memory_stores: Dict[str, InMemoryStore] = dict()
    self._indices: Dict[str, faiss.IndexIVFFlat] = dict()
59+
60+ def _get_index (self , table_key : str ) -> Optional [faiss .IndexIVFFlat ]:
61+ return self ._indices .get (table_key )
5662
5763 def update (
5864 self ,
@@ -63,32 +69,45 @@ def update(
6369 entities_to_keep : Sequence [Entity ],
6470 partial : bool ,
6571 ):
66- feature_views = tables_to_keep
67- if not feature_views :
68- return
69-
70- feature_names = [f .name for f in feature_views [0 ].features ]
71- dimension = len (feature_names )
72-
7372 self ._config = FaissOnlineStoreConfig (** config .online_store .dict ())
74- if self ._index is None or not partial :
75- quantizer = faiss .IndexFlatL2 (dimension )
76- self ._index = faiss .IndexIVFFlat (quantizer , dimension , self ._config .nlist )
77- self ._index .train (
78- np .random .rand (self ._config .nlist * 100 , dimension ).astype (np .float32 )
79- )
80- self ._in_memory_store = InMemoryStore ()
73+ versioning = config .registry .enable_online_feature_view_versioning
74+
75+ for table in tables_to_delete :
76+ table_key = _table_id (config .project , table , versioning )
77+ self ._indices .pop (table_key , None )
78+ self ._in_memory_stores .pop (table_key , None )
79+
80+ for table in tables_to_keep :
81+ table_key = _table_id (config .project , table , versioning )
82+ feature_names = [f .name for f in table .features ]
83+ dimension = len (feature_names )
84+
85+ if table_key not in self ._indices or not partial :
86+ quantizer = faiss .IndexFlatL2 (dimension )
87+ index = faiss .IndexIVFFlat (quantizer , dimension , self ._config .nlist )
88+ index .train (
89+ np .random .rand (self ._config .nlist * 100 , dimension ).astype (
90+ np .float32
91+ )
92+ )
93+ self ._indices [table_key ] = index
94+ self ._in_memory_stores [table_key ] = InMemoryStore ()
8195
82- self ._in_memory_store .update (feature_names , {})
96+ self ._in_memory_stores [ table_key ] .update (feature_names , {})
8397
def teardown(
    self,
    config: RepoConfig,
    tables: Sequence[FeatureView],
    entities: Sequence[Entity],
):
    """Discard the index and in-memory store of every table in ``tables``.

    Tables that have no registered state are skipped silently. State held
    for tables not listed in ``tables`` is left untouched.

    Args:
        config: Repo config; supplies the project name and the registry's
            ``enable_online_feature_view_versioning`` flag used to derive
            each table key.
        tables: Feature views whose state should be removed.
        entities: Not used by this implementation.
    """
    use_versioning = config.registry.enable_online_feature_view_versioning
    for feature_view in tables:
        key = _table_id(config.project, feature_view, use_versioning)
        # Drop the index first; a missing key is not an error.
        self._indices.pop(key, None)
        removed_store = self._in_memory_stores.pop(key, None)
        if removed_store is None:
            continue
        # Let the in-memory store release what it holds.
        removed_store.teardown()
92111
93112 def online_read (
94113 self ,
@@ -97,23 +116,28 @@ def online_read(
97116 entity_keys : List [EntityKeyProto ],
98117 requested_features : Optional [List [str ]] = None ,
99118 ) -> List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]]:
100- if self ._index is None :
119+ versioning = config .registry .enable_online_feature_view_versioning
120+ table_key = _table_id (config .project , table , versioning )
121+ index = self ._get_index (table_key )
122+ in_memory_store = self ._in_memory_stores .get (table_key )
123+
124+ if index is None or in_memory_store is None :
101125 return [(None , None )] * len (entity_keys )
102126
103127 results : List [Tuple [Optional [datetime ], Optional [Dict [str , Any ]]]] = []
104128 for entity_key in entity_keys :
105129 serialized_key = serialize_entity_key (
106130 entity_key , config .entity_key_serialization_version
107131 ).hex ()
108- idx = self . _in_memory_store .entity_keys .get (serialized_key , - 1 )
132+ idx = in_memory_store .entity_keys .get (serialized_key , - 1 )
109133 if idx == - 1 :
110134 results .append ((None , None ))
111135 else :
112- feature_vector = self . _index .reconstruct (int (idx ))
136+ feature_vector = index .reconstruct (int (idx ))
113137 feature_dict = {
114138 name : ValueProto (double_val = value )
115139 for name , value in zip (
116- self . _in_memory_store .feature_names , feature_vector
140+ in_memory_store .feature_names , feature_vector
117141 )
118142 }
119143 results .append ((None , feature_dict ))
@@ -128,8 +152,16 @@ def online_write_batch(
128152 ],
129153 progress : Optional [Callable [[int ], Any ]],
130154 ) -> None :
131- if self ._index is None :
132- self ._logger .warning ("Index is not initialized. Skipping write operation." )
155+ versioning = config .registry .enable_online_feature_view_versioning
156+ table_key = _table_id (config .project , table , versioning )
157+ index = self ._get_index (table_key )
158+ in_memory_store = self ._in_memory_stores .get (table_key )
159+
160+ if index is None or in_memory_store is None :
161+ self ._logger .warning (
162+ "Index for table '%s' is not initialized. Skipping write operation." ,
163+ table_key ,
164+ )
133165 return
134166
135167 feature_vectors = []
@@ -142,7 +174,7 @@ def online_write_batch(
142174 feature_vector = np .array (
143175 [
144176 feature_dict [name ].double_val
145- for name in self . _in_memory_store .feature_names
177+ for name in in_memory_store .feature_names
146178 ],
147179 dtype = np .float32 ,
148180 )
@@ -153,21 +185,17 @@ def online_write_batch(
153185 feature_vectors_array = np .array (feature_vectors )
154186
155187 existing_indices = [
156- self . _in_memory_store .entity_keys .get (sk , - 1 ) for sk in serialized_keys
188+ in_memory_store .entity_keys .get (sk , - 1 ) for sk in serialized_keys
157189 ]
158190 mask = np .array (existing_indices ) != - 1
159191 if np .any (mask ):
160- self ._index .remove_ids (
161- np .array ([idx for idx in existing_indices if idx != - 1 ])
162- )
192+ index .remove_ids (np .array ([idx for idx in existing_indices if idx != - 1 ]))
163193
164- new_indices = np .arange (
165- self ._index .ntotal , self ._index .ntotal + len (feature_vectors_array )
166- )
167- self ._index .add (feature_vectors_array )
194+ new_indices = np .arange (index .ntotal , index .ntotal + len (feature_vectors_array ))
195+ index .add (feature_vectors_array )
168196
169197 for sk , idx in zip (serialized_keys , new_indices ):
170- self . _in_memory_store .entity_keys [sk ] = idx
198+ in_memory_store .entity_keys [sk ] = idx
171199
172200 if progress :
173201 progress (len (data ))
@@ -189,12 +217,16 @@ def retrieve_online_documents(
189217 Optional [ValueProto ],
190218 ]
191219 ]:
192- if self ._index is None :
220+ versioning = config .registry .enable_online_feature_view_versioning
221+ table_key = _table_id (config .project , table , versioning )
222+ index = self ._get_index (table_key )
223+
224+ if index is None :
193225 self ._logger .warning ("Index is not initialized. Returning empty result." )
194226 return []
195227
196228 query_vector = np .array (embedding , dtype = np .float32 ).reshape (1 , - 1 )
197- distances , indices = self . _index .search (query_vector , top_k )
229+ distances , indices = index .search (query_vector , top_k )
198230
199231 results : List [
200232 Tuple [
@@ -209,7 +241,7 @@ def retrieve_online_documents(
209241 if idx == - 1 :
210242 continue
211243
212- feature_vector = self . _index .reconstruct (int (idx ))
244+ feature_vector = index .reconstruct (int (idx ))
213245
214246 timestamp = Timestamp ()
215247 timestamp .GetCurrentTime ()
@@ -237,5 +269,4 @@ async def online_read_async(
237269 entity_keys : List [EntityKeyProto ],
238270 requested_features : Optional [List [str ]] = None ,
239271 ) -> List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]]:
240- # Implement async read if needed
241272 raise NotImplementedError ("Async read is not implemented for FaissOnlineStore" )
0 commit comments