Skip to content

Commit 8973795

Browse files
committed
remove dependency on torch_scatter
1 parent 0f335db commit 8973795

8 files changed

Lines changed: 112 additions & 15 deletions

File tree

docs/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ torch==2.2.0+cu118
55
# Install PyTorch Geometric and related packages
66
-f https://data.pyg.org/whl/torch-2.2.0+cu118.html
77
torch_geometric==2.6.1
8-
torch_cluster
9-
torch_scatter
8+
# torch_cluster
9+
# torch_scatter
1010

1111
# Other dependencies
1212
huggingface_hub

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ classifiers = [
1818
dependencies = [
1919
"torch>=2.2.0",
2020
"torch-geometric>=2.6.1",
21-
"torch-scatter>=2.1.2",
2221
"numpy",
2322
"pandas>=2.2.3",
2423
"click",
@@ -36,6 +35,7 @@ dependencies = [
3635
[project.urls]
3736
"Homepage" = "https://github.com/liugangcode/torch-molecule"
3837
"Bug Tracker" = "https://github.com/liugangcode/torch-molecule"
38+
"Documentation" = "https://liugangcode.github.io/torch-molecule/"
3939

4040
[tool.setuptools.packages.find]
4141
include = ["torch_molecule*"]

requirements.txt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@ torch==2.2.0+cu118
55
# Install PyTorch Geometric and related packages
66
-f https://data.pyg.org/whl/torch-2.2.0+cu118.html
77
torch_geometric==2.6.1
8-
torch_cluster
9-
torch_scatter
8+
# torch_cluster
9+
# torch_scatter
1010

1111
# Other dependencies
12-
huggingface_hub
1312
joblib==1.3.2
1413
networkx==3.2.1
1514
pandas==2.2.3
@@ -19,11 +18,13 @@ scikit_learn==1.4.1.post1
1918
scipy==1.14.1
2019
tqdm==4.66.2
2120

21+
# huggingface
22+
huggingface_hub
2223
optuna
23-
# ogb
24-
25-
# pytest
2624

2725
# docs
2826
sphinx
2927
furo
28+
29+
# ogb
30+
# pytest

tests/predictor/run_grea.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def test_grea_upload():
218218
print("\n=== Testing model loading from Hugging Face Hub ===")
219219
# Load model
220220
downloaded_model = GREAMolecularPredictor()
221-
downloaded_model.load_from_hf(repo_id=repo_id, path="./downloaded_model/GREA_O2.pt")
221+
downloaded_model.load_from_hf(repo_id=repo_id, local_cache="./downloaded_model/GREA_O2.pt")
222222

223223
# Test prediction with downloaded model
224224
test_pred = downloaded_model.predict(smiles_list[3:])

torch_molecule/base/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def load_from_hf(self, repo_id: str, local_cache: Optional[str] = None, config_f
280280
"""
281281
if local_cache is None:
282282
local_cache = 'model.pt'
283-
HuggingFaceCheckpointManager.load_model_from_hf(self, repo_id, local_cache, config_filename)
283+
HuggingFaceCheckpointManager.load_model_from_hf(self, repo_id, local_cache, config_filename=config_filename)
284284

285285
def save(self, path: Optional[str] = None, repo_id: Optional[str] = None, **kwargs) -> None:
286286
"""Automatic save to either local disk or Hugging Face Hub.

torch_molecule/predictor/grea/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import os
22
import torch
33
import torch.nn as nn
4-
from torch_scatter import scatter_add
4+
# from torch_scatter import scatter_add
5+
from .utils import scatter_add
56

67
from ...nn import GNN_node, GNN_node_Virtualnode, MLP
78
from ...utils import init_weights

torch_molecule/predictor/grea/modeling_grea.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import os
21
import numpy as np
32
import warnings
4-
import datetime
53
from tqdm import tqdm
6-
from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Literal, Type
4+
from typing import Optional, Union, Dict, Any, List, Type
75
from dataclasses import dataclass, field
86

97
import torch
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# the code is torch_scatter: https://github.com/rusty1s/pytorch_scatter/blob/1.3.0/torch_scatter/add.py
2+
from itertools import repeat
3+
4+
def maybe_dim_size(index, dim_size=None):
5+
if dim_size is not None:
6+
return dim_size
7+
return index.max().item() + 1 if index.numel() > 0 else 0
8+
9+
10+
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
11+
dim = range(src.dim())[dim] # Get real dim value.
12+
13+
# Automatically expand index tensor to the right dimensions.
14+
if index.dim() == 1:
15+
index_size = list(repeat(1, src.dim()))
16+
index_size[dim] = src.size(dim)
17+
index = index.view(index_size).expand_as(src)
18+
19+
# Generate output tensor if not given.
20+
if out is None:
21+
out_size = list(src.size())
22+
dim_size = maybe_dim_size(index, dim_size)
23+
out_size[dim] = dim_size
24+
out = src.new_full(out_size, fill_value)
25+
26+
return src, out, index, dim
27+
28+
def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
29+
r"""
30+
|
31+
32+
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
33+
master/docs/source/_figures/add.svg?sanitize=true
34+
:align: center
35+
:width: 400px
36+
37+
|
38+
39+
Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
40+
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
41+
each value in :attr:`src`, its output index is specified by its index in
42+
:attr:`input` for dimensions outside of :attr:`dim` and by the
43+
corresponding value in :attr:`index` for dimension :attr:`dim`. If
44+
multiple indices reference the same location, their **contributions add**.
45+
46+
Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with
47+
size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
48+
:attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
49+
size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
50+
values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
51+
52+
For one-dimensional tensors, the operation computes
53+
54+
.. math::
55+
\mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j
56+
57+
where :math:`\sum_j` is over :math:`j` such that
58+
:math:`\mathrm{index}_j = i`.
59+
60+
Args:
61+
src (Tensor): The source tensor.
62+
index (LongTensor): The indices of elements to scatter.
63+
dim (int, optional): The axis along which to index.
64+
(default: :obj:`-1`)
65+
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
66+
dim_size (int, optional): If :attr:`out` is not given, automatically
67+
create output with size :attr:`dim_size` at dimension :attr:`dim`.
68+
If :attr:`dim_size` is not given, a minimal sized output tensor is
69+
returned. (default: :obj:`None`)
70+
fill_value (int, optional): If :attr:`out` is not given, automatically
71+
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
72+
73+
:rtype: :class:`Tensor`
74+
75+
.. testsetup::
76+
77+
import torch
78+
79+
.. testcode::
80+
81+
from torch_scatter import scatter_add
82+
83+
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
84+
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
85+
out = src.new_zeros((2, 6))
86+
87+
out = scatter_add(src, index, out=out)
88+
89+
print(out)
90+
91+
.. testoutput::
92+
93+
tensor([[0., 0., 4., 3., 3., 0.],
94+
[2., 4., 4., 0., 0., 0.]])
95+
"""
96+
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
97+
return out.scatter_add_(dim, index, src)

0 commit comments

Comments
 (0)