2222import re
2323import pickle
2424import shutil
25+ import warnings
2526
27+ from bluegraph .exceptions import BlueGraphException , BlueGraphWarning
2628from .similarity import SimilarityProcessor
2729
2830
@@ -87,6 +89,7 @@ def predict(self, pgframe, predict_elements=None):
8789 return self .model .predict (data )
8890
8991
92+
9093class Preprocessor (ABC ):
9194 """Preprocessor interface for EmbeddingPipeline."""
9295
@@ -118,11 +121,9 @@ def fit_model(self):
118121class EmbeddingPipeline (object ):
119122
120123 def __init__ (self , preprocessor = None , embedder = None ,
121- embedding_table = None ,
122124 similarity_processor = None ):
123125 self .preprocessor = preprocessor
124126 self .embedder = embedder
125- self .embedding_table = embedding_table
126127 self .similarity_processor = similarity_processor
127128
128129 def is_transductive (self ):
@@ -138,21 +139,38 @@ def run_fitting(self, data):
138139 train_data = self .preprocessor .transform (data )
139140 else :
140141 train_data = data
141- # Train the embedder
142- self .embedding_table = self .embedder .fit_model (train_data )
142+ if not self .embedder :
143+ raise EmbeddingPipelineException (
144+ "Embedder object is not specified: cannot run fitting" )
145+ else :
146+ # Train the embedder
147+ embedding_table = self .embedder .fit_model (train_data )
143148 # Create a similarity processor
144149 vectors = \
145- self . embedding_table ["embedding" ].tolist ()
150+ embedding_table ["embedding" ].tolist ()
146151 self .similarity_processor ._initialize_model (vectors )
147- self .similarity_processor .add (vectors , self . embedding_table .index )
148- self .similarity_processor .index = self . embedding_table .index
152+ self .similarity_processor .add (vectors , embedding_table .index )
153+ self .similarity_processor .index = embedding_table .index
149154
def run_prediction(self, data):
    """Placeholder for running prediction on ``data``.

    Not implemented in this part of the file: the argument is ignored
    and ``None`` is returned.
    """
    return None
152157
def generate_embedding_table(self):
    """Generate embedding table from similarity index.

    For every entry of ``self.similarity_processor.index`` the stored
    vector is read back with ``_model.reconstruct(position)``, where
    ``position`` is the entry's ordinal position in the index.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by ``"@id"`` with a single ``"embedding"`` column
        holding whatever ``reconstruct`` returns for each position.

    Raises
    ------
    EmbeddingPipelineException
        If the pipeline has no similarity processor.
    """
    processor = self.similarity_processor
    if processor is None:
        # Same failure mode as `retrieve_embeddings`: report it with the
        # pipeline's own exception instead of a bare AttributeError.
        raise EmbeddingPipelineException(
            "Similarity processor object is None, cannot "
            "generate embedding table")

    # Assumes the i-th vector stored in the model belongs to the i-th
    # index entry (vectors and index are added together in `run_fitting`).
    model = processor._model
    pairs = [
        (element_id, model.reconstruct(position))
        for position, element_id in enumerate(processor.index)
    ]
    return pd.DataFrame(
        pairs, columns=["@id", "embedding"]).set_index("@id")
167+
168+
153169 def retrieve_embeddings (self , indices ):
154- if self .embedding_table is not None :
155- return self .embedding_table .loc [indices ]["embedding" ].tolist ()
170+ if self .similarity_processor is None :
171+ raise EmbeddingPipelineException (
172+ "Similarity processor object is None, cannot "
173+ "retrieve embedding vectors" )
156174 else :
157175 return [
158176 el .tolist ()
@@ -191,12 +209,6 @@ def load(cls, path, embedder_interface=None, embedder_ext="pkl"):
191209 embedder = embedder_interface .load (
192210 os .path .join (path , "embedder.zip" ))
193211
194- # Load the embedding table
195- embedding_table = None
196- if os .path .isfile (os .path .join (path , "vectors.pkl" )):
197- embedding_table = pd .read_pickle (
198- os .path .join (path , "vectors.pkl" ))
199-
200212 # Load the similarity processor
201213 similarity_processor = SimilarityProcessor .load (
202214 os .path .join (path , "similarity.pkl" ),
@@ -205,7 +217,6 @@ def load(cls, path, embedder_interface=None, embedder_ext="pkl"):
205217 pipeline = cls (
206218 preprocessor = encoder ,
207219 embedder = embedder ,
208- embedding_table = embedding_table ,
209220 similarity_processor = similarity_processor )
210221
211222 if decompressed :
@@ -214,7 +225,6 @@ def load(cls, path, embedder_interface=None, embedder_ext="pkl"):
214225 return pipeline
215226
216227 def save (self , path , compress = False ):
217-
218228 if not os .path .isdir (path ):
219229 os .mkdir (path )
220230
@@ -223,12 +233,12 @@ def save(self, path, compress=False):
223233 pickle .dump (self .preprocessor , f )
224234
225235 # Save the embedding model
226- self .embedder . save (
227- os . path . join ( path , "embedder" ), compress = True )
228-
229- # Save the embedding table
230- self . embedding_table . to_pickle (
231- os . path . join ( path , "vectors.pkl" ) )
236+ if self .embedder :
237+ self . embedder . save (
238+ os . path . join ( path , "embedder" ), compress = True )
239+ else :
240+ with open ( os . path . join ( path , "embedder.pkl" ), "wb" ) as f :
241+ pickle . dump ( self . preprocessor , f )
232242
233243 # Save the similarity processor
234244 if self .similarity_processor is not None :
@@ -239,3 +249,9 @@ def save(self, path, compress=False):
239249 if compress :
240250 shutil .make_archive (path , 'zip' , path )
241251 shutil .rmtree (path )
252+
class EmbeddingPipelineException(BlueGraphException):
    """Error raised when an embedding pipeline operation cannot proceed."""
255+
class EmbeddingPipelineWarning(BlueGraphWarning):
    """Warning category for the embedding pipeline."""
0 commit comments