199 commits
43fc15c
added function related to training and for GNN, needed to define GNN …
Aug 19, 2025
e1695cc
Added the GNN part; hyper-params must be fine-tuned, no tests yet
Aug 19, 2025
07d014e
Removed the barriers in the creation of the DAG
Aug 19, 2025
1713ca6
Code tested and fixed; need to add a cross-validation module
Aug 20, 2025
83d4313
fixed the problem of the predict_device_for_figure_of_merits
Aug 20, 2025
63b38d9
Hellinger test done: success
Aug 20, 2025
6a8e095
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 20, 2025
470cd8f
GNN predictor fixed with optuna and tested
Aug 21, 2025
489811d
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 21, 2025
9197707
GNN predictor fixed with optuna and tested
Aug 21, 2025
5770501
Fixed problems: use TPESampler, not TYPESampler
Aug 21, 2025
a62b6ae
Fixed problems: use TPESampler, not TYPESampler
Aug 21, 2025
a67491c
Fixed problems: use TPESampler, not TYPESampler
Aug 21, 2025
69b483b
Test modified with number of epochs as parameter
Aug 21, 2025
fe5b210
Eliminated trained model
Aug 21, 2025
edff928
Changed the estimated Hellinger test for Windows
Aug 21, 2025
9ae1b9c
Changed the estimated Hellinger test for Windows
Aug 21, 2025
6ad1996
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 21, 2025
ac11b2d
Windows problem solved by eliminating a warning
Aug 21, 2025
430fcf5
Files modified according to suggestions
Aug 22, 2025
0d2a758
Fixed the comments related to test hellinger distance and utils
antotu Aug 25, 2025
b767c9f
Fixed modification also with pre-commit
antotu Aug 25, 2025
5fae012
Refactor the test ml predictor considering to join function related M…
antotu Aug 25, 2025
9911e64
Modified part of helper to solve code problems
antotu Aug 26, 2025
fc5e386
Update tests/device_selection/test_predictor_ml.py
antotu Aug 27, 2025
3574f61
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 27, 2025
8a8a6e7
first round fixes
antotu Aug 27, 2025
32ad2cd
pre-commit fixes
antotu Aug 27, 2025
22c65fa
Update src/mqt/predictor/ml/predictor.py
antotu Aug 27, 2025
366c1ee
Update src/mqt/predictor/ml/predictor.py
antotu Aug 27, 2025
ae05f74
Partial modification
antotu Aug 27, 2025
f38c9bb
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 27, 2025
74fe841
fixed comments repo
antotu Aug 27, 2025
3b1f903
Modified the gates accepted
antotu Aug 28, 2025
11e1c5f
Modified list
antotu Aug 28, 2025
bda1604
Fixed bug Swap and Cswap gates
antotu Sep 8, 2025
62e4921
Edit for saving memory GPU
antotu Sep 12, 2025
db799f6
Added patience as variable
antotu Oct 6, 2025
67f8f03
Updated GNN for predictions
antotu Nov 24, 2025
80dd146
partial modification guessed by bot
antotu Nov 24, 2025
2a9ec03
eliminated italian comments
antotu Nov 24, 2025
d5a43a8
eliminated italian comments
antotu Nov 24, 2025
40e3aea
eliminated redundancy torch_clamp
antotu Nov 24, 2025
a6500a1
removed small errors
antotu Nov 24, 2025
78d0b48
solved partial errors
antotu Nov 25, 2025
a46e1c5
Corrected for the lint
antotu Nov 25, 2025
75d7127
🎨 pre-commit fixes
pre-commit-ci[bot] Nov 25, 2025
9c8b478
🎨 pre-commit fixes
pre-commit-ci[bot] Nov 25, 2025
9fad119
Fixed deprecation warning from an unneeded library
antotu Nov 25, 2025
3daaf58
Fixed deprecation warning from an unneeded library
antotu Nov 25, 2025
94d1f9a
small error correction
antotu Nov 25, 2025
fe324ec
small error correction
antotu Nov 25, 2025
703d704
small error correction
antotu Nov 25, 2025
be6d69c
problem torch_geometric
antotu Nov 25, 2025
36d95f4
lightweight import
antotu Nov 26, 2025
92d608b
lightweight import
antotu Nov 26, 2025
8d2fc82
🎨 pre-commit fixes
pre-commit-ci[bot] Nov 26, 2025
c0ed915
lightweight import
antotu Nov 26, 2025
cbaaf14
🎨 pre-commit fixes
pre-commit-ci[bot] Nov 26, 2025
3390c41
fixed some code coverage
antotu Nov 26, 2025
c2e7feb
Consider also verbose
antotu Nov 27, 2025
cc7ea91
🎨 pre-commit fixes
pre-commit-ci[bot] Nov 26, 2025
04516ca
Consider also verbose
antotu Nov 27, 2025
87b9cfa
🎨 pre-commit fixes
pre-commit-ci[bot] Nov 27, 2025
475c9d0
Consider also verbose
antotu Nov 27, 2025
f3b0812
Consider adjusted estimated Hellinger distance as a regression problem
antotu Nov 27, 2025
446c53c
🎨 pre-commit fixes
pre-commit-ci[bot] Nov 27, 2025
2873464
Fixed regression
antotu Nov 27, 2025
f235567
Fixed regression
antotu Nov 27, 2025
844b0aa
Adjusted threshold for test regression
antotu Nov 28, 2025
8f5bb06
Fixed some code lines for more clarity
antotu Nov 28, 2025
63e747d
Modified a comment on test estimated_hellinger_distance
antotu Nov 28, 2025
d06daa8
🎨 pre-commit fixes
pre-commit-ci[bot] Nov 28, 2025
375105f
Minor fixes
Dec 8, 2025
7c47022
Minor fixes
Dec 8, 2025
fe2f08d
Minor fixes
Dec 8, 2025
8ef4977
Minor fixes
Dec 8, 2025
e9f57fc
Minor fixes
Dec 8, 2025
af30a38
Minor fixes
Dec 8, 2025
43df597
Minor fixes
Dec 8, 2025
48ba216
Minor fixes found
Dec 9, 2025
51c6d43
Minor fixes found
Dec 9, 2025
3cd3db0
Minor fixes found: error in predicting hellinger distance
Dec 9, 2025
59f3efe
Minor fixes found: error in predicting hellinger distance
Dec 9, 2025
033785a
Minor fixes found: error in predicting hellinger distance
Dec 9, 2025
d417dab
Minor modifications
Dec 9, 2025
5ec4edc
Minor modifications
Dec 9, 2025
cb9dc52
Minor modifications
Dec 9, 2025
92051e9
Minor modifications
Dec 9, 2025
fb8dafc
Minor modifications
Dec 9, 2025
f9b833f
Minor fixes code and change of libraries
Dec 10, 2025
54f0fb0
Minor fixes code and change of libraries
Dec 10, 2025
ba6351d
Reduced ignore on warning, modified doc, and minor fixes
Dec 10, 2025
aff599d
Modified documentation
Dec 10, 2025
890444f
Modified documentation
Dec 10, 2025
d7d49ed
Modified documentation
Dec 10, 2025
3990867
Modified to not use .pt for saving the dataset
Dec 10, 2025
9947e12
Modified to not use .pt for getting the dataset
Dec 10, 2025
7a7dc23
Modified to not use .pt for getting the dataset
Dec 10, 2025
add1d99
Fixes for code rabbit AI
Dec 10, 2025
6025101
Minor fixes
Dec 10, 2025
a7e632b
🎨 pre-commit fixes
pre-commit-ci[bot] Dec 10, 2025
a938f8c
Minor fixes
Dec 10, 2025
dddb3ae
Minor fixes
Dec 10, 2025
3b77a3a
Minor fixes
Dec 10, 2025
1ac5273
Minor fixes
Dec 10, 2025
354488e
Minor fixes
Dec 11, 2025
34fb8f7
Minor fixes
Dec 11, 2025
21f8c3b
🎨 pre-commit fixes
pre-commit-ci[bot] Dec 11, 2025
716f1c0
Modified documentation
Dec 12, 2025
21ffc4b
🎨 pre-commit fixes
pre-commit-ci[bot] Dec 12, 2025
bd029ea
Modified documentation
Dec 12, 2025
0ea72be
Modified tests
Dec 12, 2025
11812f8
Modified tests
Dec 12, 2025
9e53e27
Modified tests
Dec 12, 2025
69bfc14
Minor modifications
Dec 16, 2025
4983c59
Minor modifications
Dec 16, 2025
0c946d9
🎨 pre-commit fixes
pre-commit-ci[bot] Dec 16, 2025
1f1799d
Minor modifications
Dec 16, 2025
030c6cd
Minor modifications
Dec 16, 2025
4af838d
Verbose regression added for code coverage
Jan 7, 2026
6031b69
Verbose regression added for code coverage
Jan 7, 2026
a37be3e
Added multi-class classification
Jan 7, 2026
214d149
Test empty circuit DAG
Jan 7, 2026
5867a11
Fixed minor issue
Jan 7, 2026
2589728
Fixed minor issue
Jan 7, 2026
301e8b0
Minor fixes
Jan 8, 2026
a10c131
🎨 pre-commit fixes
pre-commit-ci[bot] Jan 20, 2026
d239b2d
...
antotu Jan 20, 2026
de186b3
🎨 pre-commit fixes
pre-commit-ci[bot] Mar 16, 2026
1834ebb
Merge remote-tracking branch 'origin/qce-experiments' into pr/antotu/563
flowerthrower Mar 18, 2026
6ef3b7b
update using GNN
antotu Mar 26, 2026
d74ffb0
update using GNN
antotu Mar 26, 2026
f145c31
fixes pre-commit
antotu Mar 26, 2026
25476d7
Modifications for the pre-commit
Mar 26, 2026
946e969
fixed pre-commit version problem
Mar 30, 2026
5fb424d
fixed pre-commit version problem
Mar 30, 2026
d68ecbf
Merge branch 'qce-experiments' into RL-compilation-step
antotu Mar 30, 2026
127742e
added safetensors
Mar 30, 2026
4fd7f45
excluded version
Mar 30, 2026
c73883f
torch-geometric and errors fixed
Mar 30, 2026
ae818ea
🎨 pre-commit fixes
pre-commit-ci[bot] Mar 30, 2026
666d8a0
torch-geometric and errors fixed
Mar 30, 2026
d950532
Windows errors fixed
Mar 30, 2026
698692d
Deprecation warning fixed
Mar 30, 2026
517f33f
reduced time for training gnn predictor
Mar 31, 2026
8283fcc
reduced parameters for testing
Mar 31, 2026
45c7f28
fixed problem test hellinger
Mar 31, 2026
0ce8565
Merge commit '505b7e5a46fc0df16ac1283b85e289b6bfad9bf0' into pr/antot…
flowerthrower Apr 6, 2026
301a5ff
🎨 pre-commit fixes
pre-commit-ci[bot] Apr 6, 2026
74a4d42
precommit
flowerthrower Apr 6, 2026
2c1c71f
Merge commit '301a5ff15ce31cd6de623c9bf88f4569eaac1791' into pr/antot…
flowerthrower Apr 6, 2026
ec7965c
Merge commit '0f8c197d1b43dd6c4a951ef72e1f0f55101e5188' into pr/antot…
flowerthrower Apr 12, 2026
74dab1f
🔀 resolve merge conflicts
flowerthrower Apr 12, 2026
3a48ba8
🎨 precommit
flowerthrower Apr 12, 2026
e98e6e7
⏪ revert precommit config
flowerthrower Apr 12, 2026
662b9ed
🚧 add gnn hyperparam sweep
flowerthrower Apr 12, 2026
9c52792
🚧 fix gnn hyperparam sweep over multiple circs
flowerthrower Apr 12, 2026
fb5a112
modified parameters
Apr 13, 2026
c18343d
modified hyper-parameters GNN
Apr 13, 2026
1403266
🎨 pre commit
flowerthrower Apr 14, 2026
a3b28af
✨ add optuna checkpoints
flowerthrower Apr 14, 2026
9eb701c
Merge commit 'c18343d823174651948256c1f5774b27b0617808' into pr/antot…
flowerthrower Apr 14, 2026
d9dc34b
🚧 fix gnn hyperparam sweep using passed mdp
flowerthrower Apr 14, 2026
b31d45e
🎨 pre-commit fixes
pre-commit-ci[bot] Apr 14, 2026
c60adfe
Merge commit 'b31d45e1f5f20df67039d9da23877623c903eaf9' into pr/antot…
flowerthrower Apr 14, 2026
d739bf5
✨ add max-episode-steps cap
flowerthrower Apr 15, 2026
b1863b6
🐛 fix ai routing download issue
flowerthrower Apr 15, 2026
45ed380
🚧 add atomic bqskit passes
flowerthrower Apr 16, 2026
5baca00
🎨 update defaults for narrow optuna experiments
flowerthrower Apr 16, 2026
e5ac8ca
🎨 pre-commit fixes
pre-commit-ci[bot] Apr 16, 2026
760f410
🔥 remove too complex passes
flowerthrower Apr 16, 2026
6cb4de6
🔥 remove overkill helpers
flowerthrower Apr 16, 2026
f7edc65
🎨 use assert
flowerthrower Apr 16, 2026
c3d5592
🎨 remove llm bloated files
flowerthrower Apr 16, 2026
35c6237
🎨 precommit
flowerthrower Apr 16, 2026
9c8f59e
Merge commit '35c6237b58f4bd774245e201b49b1b7537f78284' into add-indi…
flowerthrower Apr 16, 2026
03dac27
🎨 improve action logic
flowerthrower Apr 16, 2026
2e5d2e4
⏪ bring back bqskit actions
flowerthrower Apr 17, 2026
97e7b2a
🎨 refactor action file and separate pre- post-processing from environ…
flowerthrower Apr 17, 2026
291ebc3
🎨 streamline action calls
flowerthrower Apr 17, 2026
2ad9767
🎨 add invariant flags
flowerthrower Apr 17, 2026
68b977e
🎨 add checkpoints for gnn
flowerthrower Apr 17, 2026
caa7a97
🎨 pre-commit fixes
pre-commit-ci[bot] Apr 17, 2026
3deed0c
fix batch script
flowerthrower Apr 18, 2026
6afea89
🎨 pre-commit fixes
pre-commit-ci[bot] Apr 18, 2026
7c85508
✨ add gnn evaluation
flowerthrower Apr 18, 2026
ec5f157
🎨 align ppo error handling with gnn
flowerthrower Apr 18, 2026
799d51c
✨ add timeout for long passes
flowerthrower Apr 21, 2026
c91ea8a
🐛 fix param typo in qiskit_to_tk
flowerthrower Apr 21, 2026
d864780
🐛 handle bqskit error in gate translation gracefully
flowerthrower Apr 23, 2026
85e101f
✨ average over rollouts instead of deterministic
flowerthrower Apr 23, 2026
eacabb1
🎨 only calc delta for optimization passes
flowerthrower Apr 23, 2026
2ea6f7a
🎨 add evaluation output
flowerthrower Apr 23, 2026
77b080d
🐛 zero reward for cross-phase passes
flowerthrower Apr 25, 2026
f0e6bb9
🐛 gnn updates
flowerthrower Apr 25, 2026
d51be20
🚧 add checkpoint
flowerthrower Apr 25, 2026
bc0c158
🐛 fix elide permutation invariant
flowerthrower Apr 26, 2026
80366c6
🐛 fix invariant flags
flowerthrower Apr 26, 2026
16 changes: 13 additions & 3 deletions pyproject.toml
@@ -48,6 +48,9 @@ dependencies = [
"typing-extensions>=4.1", # for `assert_never`
"qiskit-ibm-transpiler>=0.15.0; sys_platform != 'win32' or python_version != '3.13'", # pulls qiskit-serverless, which pulls ray[default], and the resolved ray 2.54.0 has no win_amd64 cp313 wheel
"qiskit-ibm-ai-local-transpiler>=0.5.5",
"torch-geometric>=2.7.0",
"optuna>=3.0.0",
"safetensors>=0.7.0"
]

classifiers = [
@@ -134,21 +137,24 @@ filterwarnings = [
'ignore:.*The property ``qiskit.circuit.instruction.Instruction.*`` is deprecated as of qiskit 1.3.0.*:DeprecationWarning:',
# Windows: Python 3.13 can emit a RuntimeWarning about unsupported timeouts; keep tests strict otherwise.
'ignore:.*Timeout is not supported on Windows\\.:RuntimeWarning',
'ignore:.*torch_geometric.distributed.*:DeprecationWarning:',
'ignore:.*torch.jit.script.*:DeprecationWarning:',
"ignore:Failing to pass a value to the 'type_params' parameter of 'typing\\._eval_type':DeprecationWarning"

]


[tool.coverage]
run.source = ["mqt.predictor"]
run.disable_warnings = [
"no-sysmon",
]
report.exclude_also = [
'\.\.\.',
'if TYPE_CHECKING:',
'raise AssertionError',
'raise NotImplementedError',
]
run.disable_warnings = [
"no-sysmon",
]
show_missing = true
skip_empty = true
precision = 1
@@ -251,6 +257,10 @@ wille = "wille"
anc = "anc"
aer = "aer"
fom = "fom"
TPE = "TPE"
TPESampler = "TPESampler"
gae = "gae"
GAE = "GAE"


[tool.repo-review]
30 changes: 21 additions & 9 deletions src/mqt/predictor/hellinger/utils.py
@@ -63,7 +63,18 @@ def calc_device_specific_features(
- The single and multi qubit gate ratio
"""
if ignore_gates is None:
ignore_gates = ["barrier", "id", "measure"]
ignore_gates = [
"barrier",
"id",
"measure",
"if_else",
"while_loop",
"for_loop",
"switch_case",
"box",
"break",
"continue",
]
ignored_ops = set(ignore_gates)

# Targets may advertise control-flow ops like ``if_else``; keep only actual gate features.
@@ -141,12 +152,13 @@ def calc_device_specific_features(
return np.array(list(feature_dict.values()))
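The extended ignore set above can be exercised with a minimal, library-free sketch; the operation counts below are made up for illustration:

```python
# Toy sketch of the ignore-set filtering in calc_device_specific_features.
ignore_gates = [
    "barrier", "id", "measure", "if_else", "while_loop",
    "for_loop", "switch_case", "box", "break", "continue",
]
ignored_ops = set(ignore_gates)

# Hypothetical operation counts as a circuit might report them.
op_counts = {"cx": 12, "rz": 30, "measure": 5, "barrier": 2, "if_else": 1}

# Keep only actual gate features; drop control-flow and bookkeeping ops.
gate_counts = {name: n for name, n in op_counts.items() if name not in ignored_ops}
print(gate_counts)  # {'cx': 12, 'rz': 30}
```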


def get_hellinger_model_path(device: Target) -> Path:
"""Returns the path to the trained model folder resulting from the machine learning training."""
training_data_path = Path(str(resources.files("mqt.predictor"))) / "ml" / "training_data"
model_path = (
training_data_path
/ "trained_model"
/ ("trained_hellinger_distance_regressor_" + device.description + ".joblib")
def get_hellinger_model_path(device: Target, gnn: bool = False) -> Path:
"""Returns the path to the trained model file resulting from the machine learning training."""
training_data_path = Path(str(resources.files("mqt.predictor"))) / "ml" / "training_data" / "trained_model"
device_description = str(device.description)
filename = (
(f"trained_hellinger_distance_regressor_gnn_{device_description}.pth")
if gnn
else (f"trained_hellinger_distance_regressor_{device_description}.joblib")
)
return Path(model_path)
return training_data_path / filename
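The filename selection above can be sketched standalone; the root directory and device name here are illustrative, not the package's real paths:

```python
from pathlib import Path

def hellinger_model_filename(device_description: str, gnn: bool = False) -> Path:
    # Mirrors the selection logic above: a .pth file for the GNN regressor,
    # a .joblib file for the classical one.
    base = Path("training_data") / "trained_model"  # illustrative root
    if gnn:
        return base / f"trained_hellinger_distance_regressor_gnn_{device_description}.pth"
    return base / f"trained_hellinger_distance_regressor_{device_description}.joblib"

print(hellinger_model_filename("ibm_torino", gnn=True).name)
# trained_hellinger_distance_regressor_gnn_ibm_torino.pth
```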
7 changes: 6 additions & 1 deletion src/mqt/predictor/ml/__init__.py
@@ -13,4 +13,9 @@
from mqt.predictor.ml import helper
from mqt.predictor.ml.predictor import Predictor, predict_device_for_figure_of_merit, setup_device_predictor

__all__ = ["Predictor", "helper", "predict_device_for_figure_of_merit", "setup_device_predictor"]
__all__ = [
"Predictor",
"helper",
"predict_device_for_figure_of_merit",
"setup_device_predictor",
]
297 changes: 297 additions & 0 deletions src/mqt/predictor/ml/gnn.py
@@ -0,0 +1,297 @@
# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM
# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH
# All rights reserved.
#
# SPDX-License-Identifier: MIT
#
# Licensed under the MIT License

"""Graph neural network models using SAGEConv layers."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
import torch.nn as nn
import torch.nn.functional as functional
from torch_geometric.nn import (
GraphNorm,
SAGEConv,
SAGPooling,
global_mean_pool,
)

if TYPE_CHECKING:
from collections.abc import Callable

from torch_geometric.data import Data


class GraphConvolutionSage(nn.Module):
"""Graph convolutional encoder using SAGEConv layers."""

def __init__(
self,
in_feats: int,
hidden_dim: int,
num_conv_wo_resnet: int,
num_resnet_layers: int,
*,
conv_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
conv_act_kwargs: dict[str, Any] | None = None,
dropout_p: float = 0.2,
bidirectional: bool = True,
use_sag_pool: bool = False,
sag_ratio: float = 0.7,
sag_nonlinearity: Callable[..., torch.Tensor] = torch.tanh,
) -> None:
"""Initialize the graph convolutional encoder.

The encoder consists of a stack of SAGEConv layers followed by
optional SAGPooling before the global readout.

Args:
in_feats: Dimensionality of the node features.
hidden_dim: Output dimensionality of the first SAGEConv layer.
num_conv_wo_resnet: Number of SAGEConv layers before residual
connections are introduced.
num_resnet_layers: Number of SAGEConv layers with residual
connections.
conv_activation: Activation function applied after each graph
convolution. Defaults to torch.nn.functional.leaky_relu.
conv_act_kwargs: Additional keyword arguments passed to
conv_activation. Defaults to None.
dropout_p: Dropout probability applied after each graph layer.
Defaults to 0.2.
bidirectional: If True, apply message passing in both
directions (forward and reversed edges) and average the
results. Defaults to True.
use_sag_pool: If True, apply a single SAGPooling layer after
the convolutions and before readout. Defaults to False.
sag_ratio: Fraction of nodes to keep in SAGPooling. Must be in
(0, 1]. Defaults to 0.7.
sag_nonlinearity: Nonlinearity used inside SAGPooling for score
computation. Defaults to torch.tanh.
"""
super().__init__()

if num_conv_wo_resnet < 1:
msg = "num_conv_wo_resnet must be at least 1"
raise ValueError(msg)

self.conv_activation = conv_activation
self.conv_act_kwargs = conv_act_kwargs or {}
self.bidirectional = bidirectional
self.use_sag_pool = use_sag_pool

# --- GRAPH ENCODER ---
self.convs: nn.ModuleList[SAGEConv] = nn.ModuleList()
self.norms: nn.ModuleList[GraphNorm] = nn.ModuleList()

# First layer: SAGE
self.convs.append(SAGEConv(in_feats, hidden_dim))
out_dim = hidden_dim
self.graph_emb_dim = out_dim
self.norms.append(GraphNorm(out_dim))

# Subsequent layers: SAGE with fixed width == out_dim
for _ in range(num_conv_wo_resnet - 1):
self.convs.append(SAGEConv(out_dim, out_dim))
self.norms.append(GraphNorm(out_dim))
for _ in range(num_resnet_layers):
self.convs.append(SAGEConv(out_dim, out_dim))
self.norms.append(GraphNorm(out_dim))

self.drop = nn.Dropout(dropout_p)
# Start residuals after the initial non-residual stack
self._residual_start = num_conv_wo_resnet
# Expose the final node embedding width
self.out_dim = out_dim

# --- SAGPooling layer (applied once, after all convs) ---
# Uses SAGEConv internally for attention scoring to match the stack.
if self.use_sag_pool:
if not (0.0 < sag_ratio <= 1.0):
msg = "sag_ratio must be in (0, 1]"
raise ValueError(msg)
self.sag_pool: SAGPooling | None = SAGPooling(
in_channels=self.out_dim,
ratio=sag_ratio,
GNN=SAGEConv, # ty: ignore[invalid-argument-type]
nonlinearity=sag_nonlinearity,
)
else:
self.sag_pool = None

def _apply_conv_bidir(
self,
conv: SAGEConv,
x: torch.Tensor,
edge_index: torch.Tensor,
) -> torch.Tensor:
"""Apply a SAGEConv layer in forward and backward directions and average.

Args:
conv: Convolution layer taken from self.convs.
x: Node feature matrix of shape [num_nodes, in_channels].
edge_index: Edge index tensor of shape [2, num_edges].

Returns:
Tensor with updated node features of shape
[num_nodes, out_channels].
"""
x_f = conv(x, edge_index)
if not self.bidirectional:
return x_f
x_b = conv(x, edge_index.flip(0))
return (x_f + x_b) / 2
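The forward/backward averaging can be illustrated without torch on a toy graph, using plain mean aggregation as a stand-in for SAGEConv (no learned weights):

```python
# Toy stand-in for a graph convolution: mean-aggregate neighbor features.
def mean_aggregate(x, edge_index):
    # x: scalar feature per node; edge_index: (src, dst) pairs, src -> dst.
    sums = [0.0] * len(x)
    counts = [0] * len(x)
    for src, dst in edge_index:
        sums[dst] += x[src]
        counts[dst] += 1
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]

x = [1.0, 2.0, 4.0]       # node features
edges = [(0, 1), (1, 2)]  # a directed path 0 -> 1 -> 2

fwd = mean_aggregate(x, edges)                       # forward direction
bwd = mean_aggregate(x, [(d, s) for s, d in edges])  # flipped edges
bidir = [(f + b) / 2 for f, b in zip(fwd, bwd)]      # average, as in the model
print(bidir)  # [1.0, 2.5, 1.0]
```

Averaging the two directions lets information flow both ways on a DAG-shaped circuit graph while reusing a single set of convolution weights.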

def forward(self, data: Data) -> torch.Tensor:
"""Encode a batch of graphs and return pooled graph embeddings.

The input batch of graphs is processed by the SAGEConv stack,
optionally followed by SAGPooling, and finally aggregated with
global mean pooling.

Args:
data: Batched torch_geometric.data.Data object.
Expected attributes:
- x: Node features of shape [num_nodes, in_feats].
- edge_index: Edge indices of shape [2, num_edges].
- batch: Graph indices for each node of shape
[num_nodes].

Returns:
Tensor of shape [num_graphs, out_dim] containing one embedding
per input graph.
"""
x, edge_index, batch = data.x, data.edge_index, data.batch
assert x is not None
assert edge_index is not None

for i, conv in enumerate(self.convs):
x_new = self._apply_conv_bidir(conv, x, edge_index)
x_new = self.norms[i](x_new, batch=batch)
x_new = self.conv_activation(x_new, **self.conv_act_kwargs)
x_new = self.drop(x_new)

x = x_new if i < self._residual_start else x + x_new

# --- SAGPooling (hierarchical pooling before readout) ---
if self.sag_pool is not None:
# SAGPooling may also return edge_attr, perm, score; we ignore those here.
x, edge_index, _, batch, _, _ = self.sag_pool(
x,
edge_index,
batch=batch,
)

return global_mean_pool(x, batch)
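The final readout, `global_mean_pool`, averages node embeddings per graph using the batch vector; a stdlib sketch of that behavior (scalar features for brevity):

```python
# Minimal sketch of global mean pooling over a batched graph.
def global_mean_pool_sketch(x, batch):
    # x: per-node features; batch[i]: index of the graph node i belongs to.
    num_graphs = max(batch) + 1
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for feat, g in zip(x, batch):
        sums[g] += feat
        counts[g] += 1
    return [s / c for s, c in zip(sums, counts)]

# Two graphs in one batch: nodes 0-1 form graph 0, nodes 2-4 form graph 1.
x = [1.0, 3.0, 2.0, 4.0, 6.0]
batch = [0, 0, 1, 1, 1]
print(global_mean_pool_sketch(x, batch))  # [2.0, 4.0]
```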


class GNN(nn.Module):
"""Graph neural network with a SAGE-based encoder and MLP head.

This model first encodes each input graph using GraphConvolutionSage
and then applies a feed-forward neural network to the resulting graph
embeddings to produce the final prediction.
"""

def __init__(
self,
in_feats: int,
hidden_dim: int,
num_conv_wo_resnet: int,
num_resnet_layers: int,
mlp_units: list[int],
*,
conv_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
conv_act_kwargs: dict[str, Any] | None = None,
mlp_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
mlp_act_kwargs: dict[str, Any] | None = None,
dropout_p: float = 0.2,
bidirectional: bool = True,
output_dim: int = 1,
use_sag_pool: bool = False,
sag_ratio: float = 0.7,
sag_nonlinearity: Callable[..., torch.Tensor] = torch.tanh,
) -> None:
"""Initialize the GNN model.

Args:
in_feats: Dimensionality of the input node features.
hidden_dim: Hidden dimensionality of the SAGEConv layers.
num_conv_wo_resnet: Number of SAGEConv layers before residual
connections are introduced in the encoder.
num_resnet_layers: Number of SAGEConv layers with residual
connections in the encoder.
mlp_units: List specifying the number of units in each hidden
layer of the MLP head.
conv_activation: Activation function applied after each graph
convolution. Defaults to torch.nn.functional.leaky_relu.
conv_act_kwargs: Additional keyword arguments passed to
conv_activation. Defaults to None.
mlp_activation: Activation function applied after each MLP layer.
Defaults to torch.nn.functional.leaky_relu.
mlp_act_kwargs: Additional keyword arguments passed to
mlp_activation. Defaults to None.
dropout_p: Dropout probability applied in the model (graph encoder and the MLP).
Defaults to 0.2.
bidirectional: If True, apply bidirectional message passing in
the encoder. Defaults to True.
output_dim: Dimensionality of the model output (e.g. number of
targets per graph). Defaults to 1.
use_sag_pool: If True, enable SAGPooling in the encoder.
Defaults to False.
sag_ratio: Fraction of nodes to keep in SAGPooling. Must be in
(0, 1]. Defaults to 0.7.
sag_nonlinearity: Nonlinearity used inside SAGPooling for score
computation. Defaults to torch.tanh.
"""
super().__init__()

# Graph encoder
self.graph_conv = GraphConvolutionSage(
in_feats=in_feats,
hidden_dim=hidden_dim,
num_conv_wo_resnet=num_conv_wo_resnet,
num_resnet_layers=num_resnet_layers,
conv_activation=conv_activation,
conv_act_kwargs=conv_act_kwargs,
dropout_p=dropout_p,
bidirectional=bidirectional,
use_sag_pool=use_sag_pool,
sag_ratio=sag_ratio,
sag_nonlinearity=sag_nonlinearity,
)

self.mlp_activation = mlp_activation
self.mlp_act_kwargs = mlp_act_kwargs or {}
last_dim = self.graph_conv.graph_emb_dim
self.mlp_drop = nn.Dropout(dropout_p)
self.fcs: nn.ModuleList[nn.Linear] = nn.ModuleList()
for out_dim_ in mlp_units:
self.fcs.append(nn.Linear(last_dim, out_dim_))
last_dim = out_dim_
self.out = nn.Linear(last_dim, output_dim)

def forward(self, data: Data) -> torch.Tensor:
"""Compute predictions for a batch of graphs.

The input graphs are encoded into graph embeddings by the
GraphConvolutionSage encoder, then passed through the MLP head
to obtain final predictions.

Args:
data: Batched torch_geometric.data.Data object
containing the graphs to be evaluated.

Returns:
Tensor of shape [num_graphs, output_dim] with the model
predictions for each graph in the batch.
"""
x = self.graph_conv(data)
for fc in self.fcs:
x = self.mlp_drop(self.mlp_activation(fc(x), **self.mlp_act_kwargs))
return self.out(x)
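The head's layer widths follow directly from `mlp_units` and `output_dim`; a small sketch of the dimension bookkeeping performed in `__init__` (the numbers are illustrative):

```python
# Illustrative: how mlp_units and output_dim determine the Linear layer shapes.
graph_emb_dim = 64   # width produced by the graph encoder (example value)
mlp_units = [32, 16]
output_dim = 1

shapes = []
last_dim = graph_emb_dim
for out_dim_ in mlp_units:             # hidden layers of the MLP head
    shapes.append((last_dim, out_dim_))
    last_dim = out_dim_
shapes.append((last_dim, output_dim))  # final projection, self.out
print(shapes)  # [(64, 32), (32, 16), (16, 1)]
```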