Skip to content

Commit d4a0a3c

Browse files
author
Yihan Zhu
committed
Update LSTM
1 parent 50abfea commit d4a0a3c

5 files changed

Lines changed: 382 additions & 67 deletions

File tree

tests/predictor/run_lstm.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def test_lstm_predictor():
2020
LSTMunits=60,
2121
batch_size=2,
2222
epochs=2,
23-
device="cpu"
23+
device="cpu",
24+
verbose=True
2425
)
2526
print("Model initialized successfully")
2627

@@ -49,7 +50,7 @@ def test_lstm_predictor():
4950
task_type="regression",
5051
epochs=3, # Small number for testing
5152
# verbose=True
52-
verbose=False
53+
verbose=True
5354
)
5455

5556
model_auto.autofit(
@@ -67,8 +68,12 @@ def test_lstm_predictor():
6768
print(f"Model saved to {save_path}")
6869

6970
new_model = LSTMMolecularPredictor(
70-
num_task=1,
71-
task_type="regression"
71+
task_type="regression",
72+
output_dim=15,
73+
LSTMunits=60,
74+
batch_size=2,
75+
epochs=2,
76+
device="cpu"
7277
)
7378
new_model.load_from_local(save_path)
7479
print("Model loaded successfully")

tests/predictor/run_rpgnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def train_rpgnn_predictor():
2929
fixed_size=10,
3030
hidden_size=300,
3131
batch_size=128,
32-
epochs=300,
32+
epochs=2,
3333
verbose=True
3434
)
3535
print("RPGNN model initialized successfully")

torch_molecule/predictor/lstm/model.py

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,45 @@
66

77
# Define the PyTorch-based LSTM model
88
class LSTM(nn.Module):
9+
"""LSTM-based model for molecular property prediction.
10+
11+
Parameters
12+
----------
13+
num_task : int
14+
Number of prediction tasks.
15+
input_dim : int
16+
Size of vocabulary for SMILES tokenization.
17+
output_dim : int
18+
Dimension of embedding vectors.
19+
LSTMunits : int
20+
Number of hidden units in LSTM layers.
21+
max_input_len : int
22+
Maximum length of input sequences.
23+
24+
Attributes
25+
----------
26+
embedding : nn.Embedding
27+
Embedding layer that converts token indices to dense vectors.
28+
lstm1 : nn.LSTM
29+
First bidirectional LSTM layer.
30+
lstm2 : nn.LSTM
31+
Second bidirectional LSTM layer.
32+
timedist_dense : nn.Linear
33+
Time-distributed dense layer for feature transformation.
34+
relu : nn.ReLU
35+
ReLU activation function.
36+
fc : nn.Linear
37+
Final fully connected layer for prediction.
38+
39+
Notes
40+
-----
41+
The model architecture consists of:
42+
1. An embedding layer to convert tokens to vectors
43+
2. Two stacked bidirectional LSTM layers
44+
3. A time-distributed dense layer with ReLU activation
45+
4. A final fully connected layer for prediction
46+
"""
47+
948
def __init__(self, num_task, input_dim, output_dim, LSTMunits, max_input_len):
1049
"""
1150
input_dim: Vocabulary size
@@ -23,12 +62,14 @@ def __init__(self, num_task, input_dim, output_dim, LSTMunits, max_input_len):
2362
self.relu = nn.ReLU()
2463
self.fc = nn.Linear(hidden_dim * max_input_len, 1)
2564

65+
2666
def initialize_parameters(self, seed=None):
27-
"""
28-
Randomly initialize all model parameters using the init_weights function.
67+
"""Initialize model parameters randomly.
2968
30-
Args:
31-
seed (int, optional): Random seed for reproducibility. Defaults to None.
69+
Parameters
70+
----------
71+
seed : int, optional
72+
Random seed for reproducibility.
3273
"""
3374
if seed is not None:
3475
torch.manual_seed(seed)
@@ -51,6 +92,22 @@ def reset_parameters(module):
5192
self.apply(reset_parameters)
5293

5394
def compute_loss(self, batched_input, batched_label, criterion):
95+
"""Compute the loss for a batch of data.
96+
97+
Parameters
98+
----------
99+
batched_input : torch.Tensor
100+
Batch of input sequences, shape (batch_size, seq_len).
101+
batched_label : torch.Tensor
102+
Batch of target values, shape (batch_size, 1).
103+
criterion : callable
104+
Loss function to use.
105+
106+
Returns
107+
-------
108+
torch.Tensor
109+
Scalar loss value.
110+
"""
54111
emb = self.embedding(batched_input)
55112
emb, _ = self.lstm1(emb)
56113
emb, _ = self.lstm2(emb)
@@ -63,6 +120,19 @@ def compute_loss(self, batched_input, batched_label, criterion):
63120
return loss
64121

65122
def forward(self, batched_input):
123+
"""Forward pass of the model.
124+
125+
Parameters
126+
----------
127+
batched_input : torch.Tensor
128+
Batch of input sequences, shape (batch_size, seq_len).
129+
130+
Returns
131+
-------
132+
dict
133+
Dictionary containing:
134+
- prediction: Model predictions (shape: [batch_size, 1])
135+
"""
66136
# batched_data: (batch_size, seq_len)
67137
emb = self.embedding(batched_input) # -> (batch, seq_len, output_dim)
68138
emb, _ = self.lstm1(emb) # -> (batch, seq_len, 2*LSTMunits)

0 commit comments

Comments
 (0)