Skip to content

Commit 9b0fe18

Browse files
authored
Merge pull request #308 from zaz/fix-deprecations
Fix deprecations
2 parents 02bcc2c + 70959fc commit 9b0fe18

6 files changed

Lines changed: 50 additions & 18 deletions

File tree

test/transforms/liftings/graph2simplicial/test_latentclique_lifting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_lift_topology(self):
7373
# (or, equivalently, the SC has a simplex in its facets set if complex_dim = |maximal_clique|-1)
7474

7575
# Convert adjacency matrix to NetworkX graph
76-
G_from_latent_complex = nx.from_numpy_matrix(
76+
G_from_latent_complex = nx.from_numpy_array(
7777
edge_prob_one_adj.to_dense().numpy()
7878
)
7979
G_input = nx.Graph()
@@ -95,7 +95,7 @@ def test_lift_topology(self):
9595
# (or, equivalently, there is no subset of the 1-skeleton of the SC isomorphic to the input graph)
9696

9797
# Convert adjacency matrix to NetworkX graph
98-
G_from_latent_complex = nx.from_numpy_matrix(
98+
G_from_latent_complex = nx.from_numpy_array(
9999
edge_prob_any_adj.to_dense().numpy()
100100
)
101101
G_input = nx.Graph()

test/transforms/liftings/pointcloud2hypergraph/test_mogmst_lifting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def setup_method(self):
2828
[0.16, 0.45],
2929
]
3030
)
31-
self.data = Data(x=pos, y=torch.tensor(y))
31+
self.data = Data(x=pos, y=y)
3232

3333
# Initialise the HypergraphKHopLifting class
3434
self.lifting = MoGMSTLifting(min_components=3, random_state=0)

topobench/__init__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,38 @@
11
"""TopoBench: A library for benchmarking of topological models."""
22

3+
# torch >= 2.6 defaults to weights_only=True in torch.load, but OGB and
4+
# older PyG code serialize these classes. Register them as safe so that
5+
# torch.load works without weights_only=False everywhere.
6+
import numpy as np
7+
import torch
8+
9+
if hasattr(torch.serialization, "add_safe_globals"):
10+
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
11+
from torch_geometric.data.storage import (
12+
EdgeStorage,
13+
GlobalStorage,
14+
NodeStorage,
15+
)
16+
17+
safe_globals = [
18+
DataEdgeAttr,
19+
DataTensorAttr,
20+
GlobalStorage,
21+
NodeStorage,
22+
EdgeStorage,
23+
np.core.multiarray.scalar,
24+
np.dtype,
25+
]
26+
# numpy >= 1.25 uses typed DType subclasses (e.g. Int64DType) in pickle
27+
# streams; register all of them so weights_only=True succeeds.
28+
import numpy.dtypes
29+
30+
for name in dir(numpy.dtypes):
31+
obj = getattr(numpy.dtypes, name)
32+
if isinstance(obj, type) and name.endswith("DType"):
33+
safe_globals.append(obj)
34+
torch.serialization.add_safe_globals(safe_globals)
35+
336
# Import submodules
437
from . import (
538
data,

topobench/transforms/data_manipulations/group_homophily.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ def forward(self, data: torch_geometric.data.Data):
8585
if max_k != 1:
8686
H_k = H[:, torch.where(he_cardinalities == max_k)[0]].clone()
8787

88-
he_cardinalities_k = torch.tensor(H_k.sum(0), dtype=torch.long)
88+
he_cardinalities_k = (
89+
H_k.sum(0).detach().clone().to(dtype=torch.long)
90+
)
8991
Dt, D = self.calculate_D_matrix(
9092
H_k,
9193
labels,

topobench/transforms/liftings/pointcloud2simplicial/random_flag_complex.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def lift_topology(self, data: Data) -> dict:
8585
for i in range(n):
8686
st.insert([i])
8787

88-
graph: nx.Graph = nx.from_numpy_matrix(adj_mat).to_undirected()
88+
graph: nx.Graph = nx.from_numpy_array(adj_mat).to_undirected()
8989

9090
# Insert all edges
9191
for v, u in graph.edges:

uv_env_setup.sh

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,16 @@ echo "======================================================="
1919
# ------------------------------------------------------------------------------
2020
TORCH_VER="2.3.0"
2121

22-
if [ "$PLATFORM" == "cpu" ]; then
23-
TARGET_INDEX="pytorch-cpu"
24-
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}+cpu.html"
25-
elif [ "$PLATFORM" == "cu118" ]; then
26-
TARGET_INDEX="pytorch-cu118"
27-
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}+cu118.html"
28-
elif [ "$PLATFORM" == "cu121" ]; then
29-
TARGET_INDEX="pytorch-cu121"
30-
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}+cu121.html"
31-
else
32-
echo "❌ Error: Invalid platform '$PLATFORM'. Use: cpu, cu118, or cu121."
33-
return 1 2>/dev/null || exit 1
34-
fi
22+
case "$PLATFORM" in
23+
cpu|cu118|cu121)
24+
TARGET_INDEX="pytorch-${PLATFORM}"
25+
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}+${PLATFORM}.html"
26+
;;
27+
*)
28+
echo "❌ Error: Invalid platform '$PLATFORM'. Use: cpu, cu118, or cu121."
29+
return 1 2>/dev/null || exit 1
30+
;;
31+
esac
3532

3633
echo "⚙️ Updating pyproject.toml..."
3734

0 commit comments

Comments
 (0)