Skip to content

Commit e586e24

Browse files
committed
lstm cuda test
1 parent d4a0a3c commit e586e24

3 files changed

Lines changed: 27 additions & 6 deletions

File tree

docs/source/api/utils.rst

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,34 @@ Utility Functions
2626
:undoc-members:
2727
:show-inheritance:
2828

29-
.. automodule:: torch_molecule.utils.graph
29+
.. automodule:: torch_molecule.utils.graph.features
3030
:members:
3131
:undoc-members:
3232
:show-inheritance:
3333

34-
.. automodule:: torch_molecule.utils.generic
34+
.. automodule:: torch_molecule.utils.graph.graph_from_smiles
35+
:members:
36+
:undoc-members:
37+
:show-inheritance:
38+
39+
.. automodule:: torch_molecule.utils.graph.graph_to_smiles
40+
:members:
41+
:undoc-members:
42+
:show-inheritance:
43+
44+
.. automodule:: torch_molecule.utils.generic.metrics
45+
:members:
46+
:undoc-members:
47+
:show-inheritance:
48+
49+
.. automodule:: torch_molecule.utils.generic.pseudo_tasks
3550
:members:
3651
:undoc-members:
3752
:show-inheritance:
3853

39-
54+
.. automodule:: torch_molecule.utils.generic.weights
55+
:members:
56+
:undoc-members:
57+
:show-inheritance:
58+
59+

docs/source/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
templates_path = ['_templates']
2525
exclude_patterns = []
2626

27-
2827
# -- Options for HTML output -------------------------------------------------
2928
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
3029

tests/predictor/run_lstm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from torch_molecule import LSTMMolecularPredictor
33
from torch_molecule.utils.search import ParameterType, ParameterSpec
4-
4+
import torch
55
def test_lstm_predictor():
66
# Test data
77
smiles_list = [
@@ -67,13 +67,15 @@ def test_lstm_predictor():
6767
model.save_to_local(save_path)
6868
print(f"Model saved to {save_path}")
6969

70+
print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
7071
new_model = LSTMMolecularPredictor(
7172
task_type="regression",
7273
output_dim=15,
7374
LSTMunits=60,
7475
batch_size=2,
7576
epochs=2,
76-
device="cpu"
77+
# device="cpu"
78+
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
7779
)
7880
new_model.load_from_local(save_path)
7981
print("Model loaded successfully")

0 commit comments

Comments
 (0)