Skip to content

Commit f6ccf82

Browse files
committed
update docs
1 parent 6a64b05 commit f6ccf82

3 files changed

Lines changed: 254 additions & 9 deletions

File tree

docs/source/api.rst

Lines changed: 167 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,183 @@
11
API Reference
22
=============
33

4-
.. automodule:: torch_molecule.predictor
4+
This section documents the main modules and classes in `torch_molecule`.
5+
6+
Base Modules
7+
------------
8+
9+
.. automodule:: torch_molecule.base
10+
:members:
11+
:undoc-members:
12+
:show-inheritance:
13+
14+
.. automodule:: torch_molecule.base.encoder
15+
:members:
16+
:undoc-members:
17+
:show-inheritance:
18+
19+
.. automodule:: torch_molecule.base.generator
20+
:members:
21+
:undoc-members:
22+
:show-inheritance:
23+
24+
.. automodule:: torch_molecule.base.predictor
25+
:members:
26+
:undoc-members:
27+
:show-inheritance:
28+
29+
30+
Encoders
31+
--------
32+
33+
.. automodule:: torch_molecule.encoder.constant
34+
:members:
35+
:undoc-members:
36+
:show-inheritance:
37+
38+
.. automodule:: torch_molecule.encoder.attrmask
39+
:members:
40+
:undoc-members:
41+
:show-inheritance:
42+
43+
.. automodule:: torch_molecule.encoder.contextpred
44+
:members:
45+
:undoc-members:
46+
:show-inheritance:
47+
48+
.. automodule:: torch_molecule.encoder.edgepred
49+
:members:
50+
:undoc-members:
51+
:show-inheritance:
52+
53+
.. automodule:: torch_molecule.encoder.moama
54+
:members:
55+
:undoc-members:
56+
:show-inheritance:
57+
58+
.. automodule:: torch_molecule.encoder.supervised
59+
:members:
60+
:undoc-members:
61+
:show-inheritance:
62+
63+
.. (Repeat this pattern for any additional encoder modules, provided they are importable packages with docstrings.)
64+
65+
66+
Generators
67+
----------
68+
69+
.. (Further generator modules can be appended below once they are importable packages with docstrings.)
70+
71+
.. automodule:: torch_molecule.generator.digress
72+
:members:
73+
:undoc-members:
74+
:show-inheritance:
75+
76+
.. automodule:: torch_molecule.generator.graph_dit
77+
:members:
78+
:undoc-members:
79+
:show-inheritance:
80+
81+
.. automodule:: torch_molecule.generator.graphga
82+
:members:
83+
:undoc-members:
84+
:show-inheritance:
85+
86+
87+
Neural Network Components
88+
-------------------------
89+
90+
.. automodule:: torch_molecule.nn.attention
91+
:members:
92+
:undoc-members:
93+
:show-inheritance:
94+
95+
.. automodule:: torch_molecule.nn.embedder
96+
:members:
97+
:undoc-members:
98+
:show-inheritance:
99+
100+
.. automodule:: torch_molecule.nn.gnn
101+
:members:
102+
:undoc-members:
103+
:show-inheritance:
104+
105+
.. automodule:: torch_molecule.nn.mlp
106+
:members:
107+
:undoc-members:
108+
:show-inheritance:
109+
110+
111+
Predictors
112+
----------
113+
114+
.. (Add further predictor modules below as needed; consider a short summary per module and links to submodules.)
115+
116+
.. automodule:: torch_molecule.predictor.gnn
117+
:members:
118+
:undoc-members:
119+
:show-inheritance:
120+
121+
.. automodule:: torch_molecule.predictor.dir
122+
:members:
123+
:undoc-members:
124+
:show-inheritance:
125+
126+
.. automodule:: torch_molecule.predictor.grea
127+
:members:
128+
:undoc-members:
129+
:show-inheritance:
130+
131+
.. automodule:: torch_molecule.predictor.sgir
132+
:members:
133+
:undoc-members:
134+
:show-inheritance:
135+
136+
.. automodule:: torch_molecule.predictor.irm
137+
:members:
138+
:undoc-members:
139+
:show-inheritance:
140+
141+
.. automodule:: torch_molecule.predictor.lstm
142+
:members:
143+
:undoc-members:
144+
:show-inheritance:
145+
146+
.. automodule:: torch_molecule.predictor.rpgnn
5147
:members:
6148
:undoc-members:
7149
:show-inheritance:
8150

9-
.. automodule:: torch_molecule.encoder
151+
.. automodule:: torch_molecule.predictor.ssr
10152
:members:
11153
:undoc-members:
12154
:show-inheritance:
155+
13156

14-
.. automodule:: torch_molecule.generator
157+
Utilities
158+
---------
159+
160+
.. automodule:: torch_molecule.utils.checker
161+
:members:
162+
:undoc-members:
163+
:show-inheritance:
164+
165+
.. automodule:: torch_molecule.utils.checkpoint
15166
:members:
16167
:undoc-members:
17168
:show-inheritance:
18169

170+
.. automodule:: torch_molecule.utils.format
171+
:members:
172+
:undoc-members:
173+
:show-inheritance:
19174

175+
.. automodule:: torch_molecule.utils.hf
176+
:members:
177+
:undoc-members:
178+
:show-inheritance:
179+
180+
.. automodule:: torch_molecule.utils.search
181+
:members:
182+
:undoc-members:
183+
:show-inheritance:

docs/source/conf.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
# -- General configuration ---------------------------------------------------
1515
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
1616

17-
extensions = []
17+
extensions = [
18+
"sphinx.ext.autodoc",
19+
"sphinx.ext.napoleon", # for NumPy/Google-style docstrings
20+
"sphinx.ext.viewcode", # optional: adds [source] links
21+
]
22+
1823

1924
templates_path = ['_templates']
2025
exclude_patterns = []
@@ -30,3 +35,4 @@
3035
import os
3136
import sys
3237
sys.path.insert(0, os.path.abspath('../../torch_molecule'))  # NOTE(review): the ``automodule:: torch_molecule.*`` directives need the package's *parent* directory on sys.path — this should likely be os.path.abspath('../..'); verify the autodoc build imports succeed
38+

torch_molecule/nn/embedder.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,17 @@
44

55
class TimestepEmbedder(nn.Module):
66
"""
7-
Embeds scalar timesteps into vector representations.
7+
Embeds scalar timesteps into vector representations using a sinusoidal embedding
8+
followed by a multilayer perceptron (MLP).
9+
10+
Parameters
11+
----------
12+
hidden_size : int
13+
Output dimension of the MLP embedding.
14+
frequency_embedding_size : int, optional
15+
Size of the input frequency embedding, by default 256.
816
"""
17+
918
def __init__(self, hidden_size, frequency_embedding_size=256):
1019
super().__init__()
1120
self.mlp = nn.Sequential(
@@ -19,8 +28,9 @@ def __init__(self, hidden_size, frequency_embedding_size=256):
1928
def timestep_embedding(t, dim, max_period=10000):
2029
"""
2130
Create sinusoidal timestep embeddings.
31+
2232
:param t: a 1-D Tensor of N indices, one per batch element.
23-
These may be fractional.
33+
These may be fractional.
2434
:param dim: the dimension of the output.
2535
:param max_period: controls the minimum frequency of the embeddings.
2636
:return: an (N, D) Tensor of positional embeddings.
@@ -37,15 +47,37 @@ def timestep_embedding(t, dim, max_period=10000):
3747
return embedding
3848

3949
def forward(self, t):
50+
"""
51+
Forward pass for timestep embedding.
52+
53+
Parameters
54+
----------
55+
t : torch.Tensor
56+
1D tensor of scalar timesteps.
57+
58+
Returns
59+
-------
60+
torch.Tensor
61+
The final embedded representation of shape (N, hidden_size).
62+
"""
4063
t = t.view(-1)
4164
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
4265
t_emb = self.mlp(t_freq)
4366
return t_emb
4467

4568
class CategoricalEmbedder(nn.Module):
4669
"""
47-
Embeds categorical conditions such as data sources into vector representations.
48-
Also handles label dropout for classifier-free guidance.
70+
Embeds categorical conditions (e.g., data source labels) into vector representations.
71+
Supports label dropout for classifier-free guidance.
72+
73+
Parameters
74+
----------
75+
num_classes : int
76+
Number of distinct label categories.
77+
hidden_size : int
78+
Size of the embedding vectors.
79+
dropout_prob : float
80+
Probability of label dropout.
4981
"""
5082
def __init__(self, num_classes, hidden_size, dropout_prob):
5183
super().__init__()
@@ -57,6 +89,18 @@ def __init__(self, num_classes, hidden_size, dropout_prob):
5789
def token_drop(self, labels, force_drop_ids=None):
5890
"""
5991
Drops labels to enable classifier-free guidance.
92+
93+
Parameters
94+
----------
95+
labels : torch.Tensor
96+
Tensor of integer labels.
97+
force_drop_ids : torch.Tensor or None, optional
98+
Boolean mask to force specific labels to be dropped.
99+
100+
Returns
101+
-------
102+
torch.Tensor
103+
Labels with some entries replaced by a dropout token.
60104
"""
61105
if force_drop_ids is None:
62106
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
@@ -65,7 +109,24 @@ def token_drop(self, labels, force_drop_ids=None):
65109
labels = torch.where(drop_ids, self.num_classes, labels)
66110
return labels
67111

68-
def forward(self, labels, train, force_drop_ids=None, t=None):
112+
def forward(self, labels, train, force_drop_ids=None):
113+
"""
114+
Forward pass for categorical embedding with optional label dropout.
115+
116+
Parameters
117+
----------
118+
labels : torch.Tensor
119+
Tensor of categorical labels.
120+
train : bool
121+
Whether the model is in training mode.
122+
force_drop_ids : torch.Tensor or None, optional
123+
Explicit mask for which labels to drop.
124+
125+
Returns
126+
-------
127+
torch.Tensor
128+
Embedded label representations, with optional noise added during training.
129+
"""
69130
labels = labels.long().view(-1)
70131
use_dropout = self.dropout_prob > 0
71132
if (train and use_dropout) or (force_drop_ids is not None):
@@ -77,6 +138,20 @@ def forward(self, labels, train, force_drop_ids=None, t=None):
77138
return embeddings
78139

79140
class ClusterContinuousEmbedder(nn.Module):
141+
"""
142+
Embeds continuous input features into vector representations using a multilayer perceptron (MLP).
143+
Supports optional embedding dropout for classifier-free guidance.
144+
145+
Parameters
146+
----------
147+
input_size : int
148+
The size of the input features.
149+
hidden_size : int
150+
The size of the output embedding vectors.
151+
dropout_prob : float
152+
Probability of embedding dropout, used for classifier-free guidance.
153+
154+
"""
80155
def __init__(self, input_size, hidden_size, dropout_prob):
81156
super().__init__()
82157
use_cfg_embedding = dropout_prob > 0

0 commit comments

Comments
 (0)