# ... (import statements elided in this excerpt; the code below assumes `import torch` and `import torch.nn as nn`)
# Define the PyTorch-based LSTM model
class LSTM(nn.Module):
    """LSTM-based model for molecular property prediction.

    Parameters
    ----------
    num_task : int
        Number of prediction tasks.
    input_dim : int
        Size of vocabulary for SMILES tokenization.
    output_dim : int
        Dimension of embedding vectors.
    LSTMunits : int
        Number of hidden units in LSTM layers.
    max_input_len : int
        Maximum length of input sequences.

    Attributes
    ----------
    embedding : nn.Embedding
        Embedding layer that converts token indices to dense vectors.
    lstm1 : nn.LSTM
        First bidirectional LSTM layer.
    lstm2 : nn.LSTM
        Second bidirectional LSTM layer.
    timedist_dense : nn.Linear
        Time-distributed dense layer for feature transformation.
    relu : nn.ReLU
        ReLU activation function.
    fc : nn.Linear
        Final fully connected layer for prediction.

    Notes
    -----
    The model architecture consists of:

    1. An embedding layer to convert tokens to vectors
    2. Two stacked bidirectional LSTM layers
    3. A time-distributed dense layer with ReLU activation
    4. A final fully connected layer for prediction
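
    Examples
    --------
    A minimal usage sketch. The vocabulary size, batch size, and other
    dimensions below are illustrative assumptions, not values taken from
    this project:

    >>> model = LSTM(num_task=1, input_dim=40, output_dim=128,
    ...              LSTMunits=64, max_input_len=100)
    >>> tokens = torch.randint(0, 40, (8, 100))  # (batch_size, seq_len)
    >>> model(tokens)["prediction"].shape
    torch.Size([8, 1])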
    """

    def __init__(self, num_task, input_dim, output_dim, LSTMunits, max_input_len):
        """
        input_dim: Vocabulary size
        ...
        """
        # ... (embedding, lstm1, lstm2, timedist_dense, and hidden_dim definitions elided in this excerpt)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(hidden_dim * max_input_len, 1)

    def initialize_parameters(self, seed=None):
        """Initialize model parameters randomly.

        Parameters
        ----------
        seed : int, optional
            Random seed for reproducibility.
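
        Examples
        --------
        A one-line sketch; ``model`` is assumed to be an ``LSTM``
        instance as in the class-level example:

        >>> model.initialize_parameters(seed=42)  # reproducible re-init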
        """
        if seed is not None:
            torch.manual_seed(seed)
        # ... (body of the nested reset_parameters helper elided in this excerpt)
        self.apply(reset_parameters)

    def compute_loss(self, batched_input, batched_label, criterion):
        """Compute the loss for a batch of data.

        Parameters
        ----------
        batched_input : torch.Tensor
            Batch of input sequences, shape (batch_size, seq_len).
        batched_label : torch.Tensor
            Batch of target values, shape (batch_size, 1).
        criterion : callable
            Loss function to use.

        Returns
        -------
        torch.Tensor
            Scalar loss value.
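
        Examples
        --------
        A hedged sketch: ``nn.MSELoss`` and the shapes below are
        illustrative assumptions, not choices prescribed by this class.
        ``model`` is the instance from the class-level example:

        >>> tokens = torch.randint(0, 40, (8, 100))  # (batch_size, seq_len)
        >>> labels = torch.rand(8, 1)                # (batch_size, 1)
        >>> loss = model.compute_loss(tokens, labels, nn.MSELoss())
        >>> loss.dim()  # scalar
        0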
        """
        emb = self.embedding(batched_input)
        emb, _ = self.lstm1(emb)
        emb, _ = self.lstm2(emb)
        # ... (dense projection, flattening, prediction, and loss computation elided in this excerpt)
        return loss

    def forward(self, batched_input):
        """Forward pass of the model.

        Parameters
        ----------
        batched_input : torch.Tensor
            Batch of input sequences, shape (batch_size, seq_len).

        Returns
        -------
        dict
            Dictionary containing:

            - prediction: Model predictions (shape: [batch_size, 1])
        """
        # batched_data: (batch_size, seq_len)
        emb = self.embedding(batched_input)  # -> (batch, seq_len, output_dim)
        emb, _ = self.lstm1(emb)             # -> (batch, seq_len, 2*LSTMunits)
        # ... (remainder of forward elided in this excerpt)