Skip to content

Commit d091a85

Browse files
eugeniashurkoIevgeniia Oshurko
andauthored
Embedding pipeline / similarity processing and embedding service fixes (#71)
* Added model loading from local fs * Updated to python3.7 * Enabled embedder fetching from local models * Fixes to EmbeddingPipeline: added bugfixes and prediction to EmbeddingPipeline * App fixes * Changed prediction API of the embedder app * Added service prediction for json and nexus pgframes * Updated notebooks Co-authored-by: Ievgeniia Oshurko <eugenia.oshurko@epfl.ch>
1 parent 6c13e23 commit d091a85

33 files changed

Lines changed: 26134 additions & 1502 deletions

.dockerignore

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
1-
tests/*
21
examples/*
2+
cord19kg/examples/*
3+
tests/*
34
docs/*
4-
cord19kg/*
5-
.pytest_cache/*
65
build/*
76
dist/*
8-
bluegraph.egg-info/*
7+
.pytest_cache/*
8+
.tox/*
9+
.eggs/*
10+
*.egg-info/
11+
*.ipynb_checkpoints/
12+
*.DS_Store
13+
bluegraph.egg-info/*
14+
bluegraph/version.py
15+
tests/data/*
16+
__MACOSX/
17+
build/
18+
dist/
19+
.pytest_cache
20+
__pycache__
21+
*/__pycache__
22+
attri2vec_test_model.zip
23+
neo4j_sage_embedder.zip
24+
stellar_sage_embedder.zip
25+
services/embedder/examples/*

.gitignore

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ build/
33
dist/
44
.pytest_cache
55
__pycache__
6-
kganalytics/notebooks/data/gephi*
76
*.ipynb_checkpoints/
87
docs/_build/
98
*.DS_Store
@@ -16,17 +15,24 @@ examples/data/attri2vec_test_model.zip
1615
examples/data/nasa.json
1716
examples/data/nasa_comention.json
1817
cord19kg/examples/data/Glucose_risk_3000_papers.csv
18+
cord19kg/examples/data/Glucose_risk_3000_paper_meta_data.csv
1919
cord19kg/examples/data/NCIT_ontology_linking.csv
2020
cord19kg/examples/data/NCIT_ontology_linking_3000_papers.csv
2121
cord19kg/examples/data/CORD_19_v47_occurrence_top_10000.json
2222
cord19kg/examples/data/_MACOSX/*
2323
cord19kg/examples/config/forge-config.yml
2424
cord19kg/examples/models/neuroshapes/
2525
cord19kg/examples/data/output_graphs/*
26-
cord19kg/examples/data/Glucose_risk_3000_paper_meta_data.csv
26+
cord19kg/examples/data/*.zip
27+
attri2vec_test_model.zip
28+
neo4j_sage_emedder.zip
29+
stellar_sage_emedder.zip
2730
__MACOSX/
28-
*.zip
31+
# *.zip
2932
.tox/
3033
.eggs
3134
bluegraph/version.py
3235
tests/data/*
36+
test_sim_proc.faiss
37+
test_sim_proc.pkl
38+
.coverage

bluegraph/backends/neo4j/embed/embedders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def _fit_inductive_embedder(self, train_graph):
206206
train_graph.execute(train_query)
207207
return model_id
208208

209-
def _predict_embeddings(self, graph,
209+
def _predict_embeddings(self, graph, nodes=None,
210210
write=False, write_property=None):
211211
node_edge_selector = graph.get_projection_query(
212212
self.graph_configs["edge_weight"],

bluegraph/backends/neo4j/io.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,8 @@ def subgraph(self, nodes_to_include=None, edges_to_include=None,
487487
"""Get a node/edge induced subgraph."""
488488
return Neo4jGraphView(
489489
self.driver, self.node_label, self.edge_label,
490+
nodes_to_include=nodes_to_include,
491+
edges_to_include=edges_to_include,
490492
nodes_to_exclude=nodes_to_exclude,
491493
edges_to_exclude=edges_to_exclude,
492494
directed=self.directed)
@@ -507,11 +509,16 @@ class Neo4jGraphView(object):
507509
TODO: make methods public
508510
"""
509511
def __init__(self, driver, node_label,
510-
edge_label, nodes_to_exclude=None,
511-
edges_to_exclude=None, directed=True):
512+
edge_label,
513+
nodes_to_include=None, edges_to_include=None,
514+
nodes_to_exclude=None, edges_to_exclude=None,
515+
directed=True):
516+
"""Initialize an instance of Neo4jGraphView."""
512517
self.driver = driver
513518
self.node_label = node_label
514519
self.edge_label = edge_label
520+
self.nodes_to_include = nodes_to_include
521+
self.edges_to_include = edges_to_include
515522
self.nodes_to_exclude = nodes_to_exclude if nodes_to_exclude else []
516523
self.edges_to_exclude = edges_to_exclude if edges_to_exclude else []
517524
self.directed = directed

bluegraph/backends/stellargraph/embed/embedders.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,14 @@ def _fit_inductive_embedder(self, train_graph):
220220
embedding_model = Model(inputs=x_inp_src, outputs=x_out_src)
221221
return embedding_model
222222

223-
def _predict_embeddings(self, graph):
223+
def _predict_embeddings(self, graph, nodes=None):
224+
if nodes is None:
225+
nodes = graph.nodes()
224226
node_generator = _dispatch_generator(
225-
graph, self.model_name, self.params).flow(graph.nodes())
227+
graph, self.model_name, self.params).flow(nodes)
226228

227229
node_embeddings = self._embedding_model.predict(node_generator)
228-
res = dict(zip(graph.nodes(), node_embeddings.tolist()))
230+
res = dict(zip(nodes, node_embeddings.tolist()))
229231
return res
230232

231233
@staticmethod

bluegraph/core/embed/embedders.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _fit_inductive_embedder(self, train_graph):
6161
pass
6262

6363
@abstractmethod
64-
def _predict_embeddings(self, graph):
64+
def _predict_embeddings(self, graph, nodes=None):
6565
pass
6666

6767
@staticmethod
@@ -159,16 +159,17 @@ def fit_model(self, pgframe):
159159
embeddings = embeddings.set_index("@id")
160160
return embeddings
161161

162-
def predict_embeddings(self, pgframe):
162+
def predict_embeddings(self, pgframe, nodes=None):
163163
"""Predict embeddings of out-sample elements."""
164+
if nodes is None:
165+
nodes = pgframe.nodes()
164166
if self._embedding_model is None:
165167
raise ElementEmbedder.PredictionException(
166168
"Embedder does not have a predictive model")
167169

168-
input_graph = self._generate_graph(
169-
pgframe, self.graph_configs)
170+
input_graph = self._generate_graph(pgframe, self.graph_configs)
170171

171-
node_embeddings = self._predict_embeddings(input_graph)
172+
node_embeddings = self._predict_embeddings(input_graph, nodes=nodes)
172173
node_embeddings = pd.DataFrame(
173174
node_embeddings.items(), columns=["@id", "embedding"])
174175
node_embeddings = node_embeddings.set_index("@id")

0 commit comments

Comments
 (0)