Skip to content

Commit 6c13e23

Browse files
eugeniashurko (Ievgeniia Oshurko)
and co-authored
Embedding pipeline optimization (#70)
* Fixes to EmbeddingPipeline * Removed embedding table from Embedding Pipeline * Added error message when element is not in index Co-authored-by: Ievgeniia Oshurko <eugenia.oshurko@epfl.ch>
1 parent 25b9299 commit 6c13e23

2 files changed

Lines changed: 54 additions & 29 deletions

File tree

bluegraph/downstream/data_structures.py

Lines changed: 39 additions & 23 deletions
Original file line number · Diff line number · Diff line change
@@ -22,7 +22,9 @@
2222
import re
2323
import pickle
2424
import shutil
25+
import warnings
2526

27+
from bluegraph.exceptions import BlueGraphException, BlueGraphWarning
2628
from .similarity import SimilarityProcessor
2729

2830

@@ -87,6 +89,7 @@ def predict(self, pgframe, predict_elements=None):
8789
return self.model.predict(data)
8890

8991

92+
9093
class Preprocessor(ABC):
9194
"""Preprocessor inferface for EmbeddingPipeline."""
9295

@@ -118,11 +121,9 @@ def fit_model(self):
118121
class EmbeddingPipeline(object):
119122

120123
def __init__(self, preprocessor=None, embedder=None,
121-
embedding_table=None,
122124
similarity_processor=None):
123125
self.preprocessor = preprocessor
124126
self.embedder = embedder
125-
self.embedding_table = embedding_table
126127
self.similarity_processor = similarity_processor
127128

128129
def is_transductive(self):
@@ -138,21 +139,38 @@ def run_fitting(self, data):
138139
train_data = self.preprocessor.transform(data)
139140
else:
140141
train_data = data
141-
# Train the embedder
142-
self.embedding_table = self.embedder.fit_model(train_data)
142+
if not self.embedder:
143+
raise EmbeddingPipelineException(
144+
"Embedder object is not specified: cannot run fitting")
145+
else:
146+
# Train the embedder
147+
embedding_table = self.embedder.fit_model(train_data)
143148
# Create a similarity processor
144149
vectors =\
145-
self.embedding_table["embedding"].tolist()
150+
embedding_table["embedding"].tolist()
146151
self.similarity_processor._initialize_model(vectors)
147-
self.similarity_processor.add(vectors, self.embedding_table.index)
148-
self.similarity_processor.index = self.embedding_table.index
152+
self.similarity_processor.add(vectors, embedding_table.index)
153+
self.similarity_processor.index = embedding_table.index
149154

150155
def run_prediction(self, data):
151156
pass
152157

158+
def generate_embedding_table(self):
159+
"""Generate embedding table from similarity index."""
160+
index = self.similarity_processor.index
161+
pairs = [
162+
(ind, self.similarity_processor._model.reconstruct(i))
163+
for i, ind in enumerate(index)
164+
]
165+
return pd.DataFrame(
166+
pairs, columns=["@id", "embedding"]).set_index("@id")
167+
168+
153169
def retrieve_embeddings(self, indices):
154-
if self.embedding_table is not None:
155-
return self.embedding_table.loc[indices]["embedding"].tolist()
170+
if self.similarity_processor is None:
171+
raise EmbeddingPipelineException(
172+
"Similarity processor object is None, cannot "
173+
"retrieve embedding vectors")
156174
else:
157175
return [
158176
el.tolist()
@@ -191,12 +209,6 @@ def load(cls, path, embedder_interface=None, embedder_ext="pkl"):
191209
embedder = embedder_interface.load(
192210
os.path.join(path, "embedder.zip"))
193211

194-
# Load the embedding table
195-
embedding_table = None
196-
if os.path.isfile(os.path.join(path, "vectors.pkl")):
197-
embedding_table = pd.read_pickle(
198-
os.path.join(path, "vectors.pkl"))
199-
200212
# Load the similarity processor
201213
similarity_processor = SimilarityProcessor.load(
202214
os.path.join(path, "similarity.pkl"),
@@ -205,7 +217,6 @@ def load(cls, path, embedder_interface=None, embedder_ext="pkl"):
205217
pipeline = cls(
206218
preprocessor=encoder,
207219
embedder=embedder,
208-
embedding_table=embedding_table,
209220
similarity_processor=similarity_processor)
210221

211222
if decompressed:
@@ -214,7 +225,6 @@ def load(cls, path, embedder_interface=None, embedder_ext="pkl"):
214225
return pipeline
215226

216227
def save(self, path, compress=False):
217-
218228
if not os.path.isdir(path):
219229
os.mkdir(path)
220230

@@ -223,12 +233,12 @@ def save(self, path, compress=False):
223233
pickle.dump(self.preprocessor, f)
224234

225235
# Save the embedding model
226-
self.embedder.save(
227-
os.path.join(path, "embedder"), compress=True)
228-
229-
# Save the embedding table
230-
self.embedding_table.to_pickle(
231-
os.path.join(path, "vectors.pkl"))
236+
if self.embedder:
237+
self.embedder.save(
238+
os.path.join(path, "embedder"), compress=True)
239+
else:
240+
with open(os.path.join(path, "embedder.pkl"), "wb") as f:
241+
pickle.dump(self.preprocessor, f)
232242

233243
# Save the similarity processor
234244
if self.similarity_processor is not None:
@@ -239,3 +249,9 @@ def save(self, path, compress=False):
239249
if compress:
240250
shutil.make_archive(path, 'zip', path)
241251
shutil.rmtree(path)
252+
253+
class EmbeddingPipelineException(BlueGraphException):
254+
pass
255+
256+
class EmbeddingPipelineWarning(BlueGraphWarning):
257+
pass

bluegraph/downstream/similarity.py

Lines changed: 15 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -20,6 +20,9 @@
2020
import faiss
2121
import os
2222

23+
from bluegraph.exceptions import BlueGraphException
24+
25+
2326
# This is to avoid a wierd Faiss segmentation fault (TODO: investigate)
2427
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
2528

@@ -113,8 +116,14 @@ def _initialize_model(self, initial_vectors=None):
113116

114117
def get_vectors(self, existing_indices):
115118
if self.index is not None:
116-
existing_indices = self.index.get_indexer(existing_indices)
117-
x = [self._model.reconstruct(int(i)) for i in existing_indices]
119+
int_idices = self.index.get_indexer(existing_indices)
120+
try:
121+
x = [self._model.reconstruct(int(i)) for i in int_idices]
122+
except RuntimeError:
123+
raise SimilarityProcessor.SimilarityException(
124+
"Cannot retrieve vectors for provided elements {} ".format(
125+
existing_indices) +
126+
"make sure all the elements are in the index.")
118127
return x
119128

120129
def query_existing(self, existing_indices, k=10):
@@ -161,16 +170,16 @@ def get_similar_points(self, vectors=None, vector_indices=None,
161170
indices = int_index
162171
return indices, distance
163172

164-
class TrainException(Exception):
173+
class TrainException(BlueGraphException):
165174
pass
166175

167-
class SimilarityException(Exception):
176+
class SimilarityException(BlueGraphException):
168177
pass
169178

170-
class IndexException(Exception):
179+
class IndexException(BlueGraphException):
171180
pass
172181

173-
class QueryException(Exception):
182+
class QueryException(BlueGraphException):
174183
pass
175184

176185

0 commit comments

Comments (0)