Skip to content

Commit bca582e

Browse files
committed
Introduce GraphSAGE node embedding algorithm
1 parent 8f78e7b commit bca582e

6 files changed

Lines changed: 296 additions & 2 deletions
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Node Embeddings 0b: Prepare: Calculate Degree Property.
2+
3+
CALL gds.degree.mutate(
4+
$dependencies_projection + '-cleaned', {
5+
orientation: 'UNDIRECTED'
6+
,relationshipWeightProperty: CASE $dependencies_projection_weight_property WHEN '' THEN null ELSE $dependencies_projection_weight_property END
7+
,mutateProperty: 'degreeForNodeEmbeddings'
8+
})
9+
YIELD nodePropertiesWritten
10+
,preProcessingMillis
11+
,computeMillis
12+
,mutateMillis
13+
,postProcessingMillis
14+
,centralityDistribution
15+
RETURN nodePropertiesWritten
16+
,preProcessingMillis
17+
,computeMillis
18+
,mutateMillis
19+
,postProcessingMillis
20+
,centralityDistribution.min
21+
,centralityDistribution.mean
22+
,centralityDistribution.max
23+
,centralityDistribution.p50
24+
,centralityDistribution.p75
25+
,centralityDistribution.p90
26+
,centralityDistribution.p95
27+
,centralityDistribution.p99
28+
,centralityDistribution.p999
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Node Embeddings 0c: Prepare: Drop the GraphSAGE model if it already exists.
2+
3+
CALL gds.model.drop($dependencies_projection + '-graphSAGE', false)
4+
YIELD modelName,
5+
modelType,
6+
modelInfo,
7+
creationTime,
8+
trainConfig,
9+
graphSchema,
10+
loaded,
11+
stored,
12+
published
13+
RETURN modelName,
14+
modelType,
15+
modelInfo,
16+
creationTime,
17+
trainConfig,
18+
graphSchema,
19+
loaded,
20+
stored,
21+
published
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Node Embeddings 4c using GraphSAGE (Graph Neural Networks): Train. Requires: "Node_Embeddings_0b_Prepare_Degree.cypher".
2+
3+
CALL gds.beta.graphSage.train(
4+
$dependencies_projection + '-cleaned', {
5+
modelName: $dependencies_projection + '-graphSAGE'
6+
,featureProperties: ['degreeForNodeEmbeddings']
7+
,embeddingDimension: toInteger($dependencies_projection_embedding_dimension)
8+
,relationshipWeightProperty: CASE $dependencies_projection_weight_property WHEN '' THEN null ELSE $dependencies_projection_weight_property END
9+
,batchSize: 64
10+
,activationFunction: 'relu'
11+
,sampleSizes: [25, 20, 20, 10]
12+
//,aggregator: 'pool'
13+
//,epochs: 10
14+
//,penaltyL2: 0.0000001
15+
//,tolerance: 0.0001
16+
//,learningRate: 0.1
17+
//,searchDepth: 5
18+
,randomSeed: 47
19+
}
20+
)
21+
YIELD modelInfo AS info, trainMillis
22+
RETURN
23+
info.modelName AS modelName,
24+
info.metrics.didConverge AS didConverge,
25+
info.metrics.ranEpochs AS ranEpochs,
26+
info.metrics.epochLosses AS epochLosses,
27+
trainMillis AS trainingTimeMilliseconds
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Node Embeddings 4d using GraphSAGE: Stream. Requires "Add_file_name and_extension.cypher".
2+
3+
CALL gds.beta.graphSage.stream(
4+
$dependencies_projection + '-cleaned', {
5+
modelName: $dependencies_projection + '-graphSAGE'
6+
}
7+
)
8+
YIELD nodeId, embedding
9+
WITH gds.util.asNode(nodeId) AS codeUnit
10+
,embedding
11+
OPTIONAL MATCH (artifact:Java:Artifact)-[:CONTAINS]->(codeUnit)
12+
WITH *, artifact.name AS artifactName
13+
OPTIONAL MATCH (projectRoot:Directory)<-[:HAS_ROOT]-(proj:TS:Project)-[:CONTAINS]->(codeUnit)
14+
WITH *, last(split(projectRoot.absoluteFileName, '/')) AS projectName
15+
RETURN DISTINCT
16+
coalesce(codeUnit.fqn, codeUnit.globalFqn, codeUnit.fileName, codeUnit.signature, codeUnit.name) AS codeUnitName
17+
,codeUnit.name AS shortCodeUnitName
18+
,elementId(codeUnit) AS nodeElementId
19+
,coalesce(artifactName, projectName) AS projectName
20+
,coalesce(codeUnit.communityLeidenId, 0) AS communityId
21+
,coalesce(codeUnit.centralityPageRank, 0.01) AS centrality
22+
,embedding

jupyter/NodeEmbeddingsJava.ipynb

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,52 @@
233233
" return embeddings"
234234
]
235235
},
236+
{
237+
"cell_type": "code",
238+
"execution_count": null,
239+
"id": "48cb52c6",
240+
"metadata": {},
241+
"outputs": [],
242+
"source": [
243+
"def create_node_embeddings_with_GraphSAGE(parameters: dict) -> pd.DataFrame: \n",
244+
" \"\"\"\n",
245+
" Creates an in-memory Graph projection by calling \"create_undirected_projection\", \n",
246+
" enriches it with a degree centrality property for every node, trains GraphSAGE \n",
247+
" and returns the resulting node embeddings as DataFrame.\n",
248+
" \n",
249+
" parameters\n",
250+
" ----------\n",
251+
" dependencies_projection : str\n",
252+
" The name prefix for the in-memory projection for dependencies. Example: \"java-package-embeddings-notebook\"\n",
253+
" dependencies_projection_node : str\n",
254+
" The label of the nodes that will be used for the projection. Example: \"Package\"\n",
255+
" dependencies_projection_weight_property : str\n",
256+
" The name of the relationship property that contains the dependency weight. Example: \"weight25PercentInterfaces\"\n",
257+
" dependencies_projection_embedding_dimension : str\n",
258+
" The number of the dimensions and therefore size of the resulting array of floating point numbers\n",
259+
" \"\"\"\n",
260+
" \n",
261+
" is_data_available=create_undirected_projection(parameters)\n",
262+
" \n",
263+
" if not is_data_available:\n",
264+
" print(\"No projected data for node embeddings calculation available\")\n",
265+
" empty_result = pd.DataFrame(columns=[\"codeUnitName\", \"shortCodeUnitName\", 'projectName', 'communityId', 'centrality', 'embedding'])\n",
266+
" return empty_result\n",
267+
" \n",
268+
" existing_embeddings_query_filename=\"../cypher/Node_Embeddings/Node_Embeddings_0a_Query_Calculated.cypher\"\n",
269+
" embeddings=query_cypher_to_data_frame(existing_embeddings_query_filename, parameters)\n",
270+
" if embeddings.empty:\n",
271+
" query_cypher_to_data_frame(\"../cypher/Node_Embeddings/Node_Embeddings_0b_Prepare_Degree.cypher\", parameters)\n",
272+
" query_cypher_to_data_frame(\"../cypher/Node_Embeddings/Node_Embeddings_0c_Drop_Model.cypher\", parameters)\n",
273+
" display(query_cypher_to_data_frame(\"../cypher/Node_Embeddings/Node_Embeddings_4b_GraphSAGE_Train.cypher\", parameters))\n",
274+
" embeddings=query_cypher_to_data_frame(\"../cypher/Node_Embeddings/Node_Embeddings_4d_GraphSAGE_Stream.cypher\", parameters)\n",
275+
" else:\n",
276+
" print(\"The results have been provided by the query filename: \" + existing_embeddings_query_filename)\n",
277+
" \n",
278+
" display(embeddings.head()) # Display the first entries of the table\n",
279+
" return embeddings"
280+
]
281+
},
236282
{
237283
"cell_type": "code",
238284
"execution_count": null,
@@ -696,6 +742,34 @@
696742
"plot_2d_node_embeddings(embeddings_node2vec, get_plot_title(\"Java Packages\", \"node2vec\", scores_node2vec))"
697743
]
698744
},
745+
{
746+
"cell_type": "markdown",
747+
"id": "873d6a4e",
748+
"metadata": {},
749+
"source": [
750+
"### 1.6 Node Embeddings for Java Packages using GraphSAGE"
751+
]
752+
},
753+
{
754+
"cell_type": "code",
755+
"execution_count": null,
756+
"id": "f25a062f",
757+
"metadata": {},
758+
"outputs": [],
759+
"source": [
760+
"java_package_embeddings_parameters={\n",
761+
" \"dependencies_projection\": \"java-package-embeddings-notebook\",\n",
762+
" \"dependencies_projection_node\": \"Package\",\n",
763+
" \"dependencies_projection_weight_property\": \"weight25PercentInterfaces\",\n",
764+
" \"dependencies_projection_write_property\": \"embeddingsGraphSAGE\",\n",
765+
" \"dependencies_projection_embedding_dimension\":\"32\"\n",
766+
"}\n",
767+
"embeddings_graphSAGE= create_node_embeddings_with_GraphSAGE(java_package_embeddings_parameters)\n",
768+
"embeddings_graphSAGE = prepare_node_embeddings_for_2d_visualization(embeddings_graphSAGE)\n",
769+
"scores_graphSAGE = CommunityScores.calculate(embeddings_graphSAGE)\n",
770+
"plot_2d_node_embeddings(embeddings_graphSAGE, get_plot_title(\"Java Packages\", \"GraphSAGE\", scores_graphSAGE))"
771+
]
772+
},
699773
{
700774
"cell_type": "markdown",
701775
"id": "b9a5d57b",
@@ -714,14 +788,38 @@
714788
"outputs": [],
715789
"source": [
716790
"plot_all_2d_node_embeddings_in_grid(\n",
717-
" embeddings=[embeddings_fastRP, embeddings_hashGNN, embeddings_node2vec],\n",
791+
" embeddings=[embeddings_fastRP, embeddings_hashGNN, embeddings_node2vec, embeddings_graphSAGE],\n",
718792
" titles=[\n",
719793
" get_plot_title(\"Java Packages\", \"Fast Random Projection\", scores_fastRP),\n",
720794
" get_plot_title(\"Java Packages\", \"HashGNN\", scores_hashGNN),\n",
721795
" get_plot_title(\"Java Packages\", \"node2vec\", scores_node2vec),\n",
796+
" get_plot_title(\"Java Packages\", \"GraphSAGE\", scores_graphSAGE),\n",
722797
" ],\n",
723798
")"
724799
]
800+
},
801+
{
802+
"cell_type": "markdown",
803+
"id": "6d55b6f2",
804+
"metadata": {},
805+
"source": [
806+
"#### Interpreting Node Embedding Results\n",
807+
"\n",
808+
"##### Summary of Observations\n",
809+
"\n",
810+
"- **FastRP** and **node2vec** show clear, well-separated clusters\n",
811+
"- **HashGNN** and **GraphSAGE** produce more diffuse embeddings\n",
812+
"- Silhouette scores are high for FastRP / node2vec and low for HashGNN / GraphSAGE\n",
813+
"\n",
814+
"These differences are expected and stem from the **fundamentally different objectives** of the algorithms.\n",
815+
"\n",
816+
"##### Key Takeaways\n",
817+
"\n",
818+
"- **FastRP and node2vec** are well-suited for **community discovery and visualization**\n",
819+
"- **HashGNN** is best viewed as a **fast structural fingerprint**, not a clustering embedding\n",
820+
"- **GraphSAGE** requires meaningful node features or labels and performs poorly in dense, feature-poor settings (here, the node degree is its only input feature)\n",
821+
"- Poor silhouette scores for HashGNN and GraphSAGE are **expected and theoretically consistent**"
822+
]
725823
}
726824
],
727825
"metadata": {

jupyter/NodeEmbeddingsTypescript.ipynb

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,52 @@
233233
" return embeddings"
234234
]
235235
},
236+
{
237+
"cell_type": "code",
238+
"execution_count": null,
239+
"id": "e2b52e51",
240+
"metadata": {},
241+
"outputs": [],
242+
"source": [
243+
"def create_node_embeddings_with_GraphSAGE(parameters: dict) -> pd.DataFrame: \n",
244+
" \"\"\"\n",
245+
" Creates an in-memory Graph projection by calling \"create_undirected_projection\", \n",
246+
" enriches it with a degree centrality property for every node, trains GraphSAGE \n",
247+
" and returns the resulting node embeddings as DataFrame.\n",
248+
" \n",
249+
" parameters\n",
250+
" ----------\n",
251+
" dependencies_projection : str\n",
252+
" The name prefix for the in-memory projection for dependencies. Example: \"typescript-module-embeddings-notebook\"\n",
253+
" dependencies_projection_node : str\n",
254+
" The label of the nodes that will be used for the projection. Example: \"Module\"\n",
255+
" dependencies_projection_weight_property : str\n",
256+
" The name of the relationship property that contains the dependency weight. Example: \"lowCouplingElement25PercentWeight\"\n",
257+
" dependencies_projection_embedding_dimension : str\n",
258+
" The number of the dimensions and therefore size of the resulting array of floating point numbers\n",
259+
" \"\"\"\n",
260+
" \n",
261+
" is_data_available=create_undirected_projection(parameters)\n",
262+
" \n",
263+
" if not is_data_available:\n",
264+
" print(\"No projected data for node embeddings calculation available\")\n",
265+
" empty_result = pd.DataFrame(columns=[\"codeUnitName\", \"shortCodeUnitName\", 'projectName', 'communityId', 'centrality', 'embedding'])\n",
266+
" return empty_result\n",
267+
" \n",
268+
" existing_embeddings_query_filename=\"../cypher/Node_Embeddings/Node_Embeddings_0a_Query_Calculated.cypher\"\n",
269+
" embeddings=query_cypher_to_data_frame(existing_embeddings_query_filename, parameters)\n",
270+
" if embeddings.empty:\n",
271+
" query_cypher_to_data_frame(\"../cypher/Node_Embeddings/Node_Embeddings_0b_Prepare_Degree.cypher\", parameters)\n",
272+
" query_cypher_to_data_frame(\"../cypher/Node_Embeddings/Node_Embeddings_0c_Drop_Model.cypher\", parameters)\n",
273+
" display(query_cypher_to_data_frame(\"../cypher/Node_Embeddings/Node_Embeddings_4b_GraphSAGE_Train.cypher\", parameters))\n",
274+
" embeddings=query_cypher_to_data_frame(\"../cypher/Node_Embeddings/Node_Embeddings_4d_GraphSAGE_Stream.cypher\", parameters)\n",
275+
" else:\n",
276+
" print(\"The results have been provided by the query filename: \" + existing_embeddings_query_filename)\n",
277+
" \n",
278+
" display(embeddings.head()) # Display the first entries of the table\n",
279+
" return embeddings"
280+
]
281+
},
236282
{
237283
"cell_type": "code",
238284
"execution_count": null,
@@ -699,6 +745,34 @@
699745
"plot_2d_node_embeddings(embeddings_node2vec, get_plot_title(\"TypeScript Modules\", \"node2vec\", scores_node2vec))"
700746
]
701747
},
748+
{
749+
"cell_type": "markdown",
750+
"id": "059d162c",
751+
"metadata": {},
752+
"source": [
753+
"### 1.6 Node Embeddings for TypeScript Modules using GraphSAGE"
754+
]
755+
},
756+
{
757+
"cell_type": "code",
758+
"execution_count": null,
759+
"id": "2c5664b9",
760+
"metadata": {},
761+
"outputs": [],
762+
"source": [
763+
"typescript_module_embeddings_parameters={\n",
764+
" \"dependencies_projection\": \"typescript-module-embeddings-notebook\",\n",
765+
" \"dependencies_projection_node\": \"Module\",\n",
766+
" \"dependencies_projection_weight_property\": \"lowCouplingElement25PercentWeight\",\n",
767+
" \"dependencies_projection_write_property\": \"embeddingsGraphSAGE\",\n",
768+
" \"dependencies_projection_embedding_dimension\":\"32\"\n",
769+
"}\n",
770+
"embeddings_graphSAGE= create_node_embeddings_with_GraphSAGE(typescript_module_embeddings_parameters)\n",
771+
"embeddings_graphSAGE = prepare_node_embeddings_for_2d_visualization(embeddings_graphSAGE)\n",
772+
"scores_graphSAGE = CommunityScores.calculate(embeddings_graphSAGE)\n",
773+
"plot_2d_node_embeddings(embeddings_graphSAGE, get_plot_title(\"TypeScript Modules\", \"GraphSAGE\", scores_graphSAGE))"
774+
]
775+
},
702776
{
703777
"cell_type": "markdown",
704778
"id": "c5c73bd3",
@@ -717,14 +791,38 @@
717791
"outputs": [],
718792
"source": [
719793
"plot_all_2d_node_embeddings_in_grid(\n",
720-
" embeddings=[embeddings_fastRP, embeddings_hashGNN, embeddings_node2vec],\n",
794+
" embeddings=[embeddings_fastRP, embeddings_hashGNN, embeddings_node2vec, embeddings_graphSAGE],\n",
721795
" titles=[\n",
722796
" get_plot_title(\"TypeScript Modules\", \"Fast Random Projection\", scores_fastRP),\n",
723797
" get_plot_title(\"TypeScript Modules\", \"HashGNN\", scores_hashGNN),\n",
724798
" get_plot_title(\"TypeScript Modules\", \"node2vec\", scores_node2vec),\n",
799+
" get_plot_title(\"TypeScript Modules\", \"GraphSAGE\", scores_graphSAGE),\n",
725800
" ],\n",
726801
")"
727802
]
803+
},
804+
{
805+
"cell_type": "markdown",
806+
"id": "75acc17d",
807+
"metadata": {},
808+
"source": [
809+
"#### Interpreting Node Embedding Results\n",
810+
"\n",
811+
"##### Summary of Observations\n",
812+
"\n",
813+
"- **FastRP** and **node2vec** show clear, well-separated clusters\n",
814+
"- **HashGNN** and **GraphSAGE** produce more diffuse embeddings\n",
815+
"- Silhouette scores are high for FastRP / node2vec and low for HashGNN / GraphSAGE\n",
816+
"\n",
817+
"These differences are expected and stem from the **fundamentally different objectives** of the algorithms.\n",
818+
"\n",
819+
"##### Key Takeaways\n",
820+
"\n",
821+
"- **FastRP and node2vec** are well-suited for **community discovery and visualization**\n",
822+
"- **HashGNN** is best viewed as a **fast structural fingerprint**, not a clustering embedding\n",
823+
"- **GraphSAGE** requires meaningful node features or labels and performs poorly in dense, feature-poor settings (here, the node degree is its only input feature)\n",
824+
"- Poor silhouette scores for HashGNN and GraphSAGE are **expected and theoretically consistent**"
825+
]
728826
}
729827
],
730828
"metadata": {

0 commit comments

Comments
 (0)