Skip to content

Commit 9675959

Browse files
committed
update doc
1 parent 164b43c commit 9675959

4 files changed

Lines changed: 46 additions & 22 deletions

File tree

docs/source/api/generator.rst

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,25 @@ Molecular Generation Models
33

44
The generator models inherit from the :class:`torch_molecule.base.generator.BaseMolecularGenerator` class and share common methods for model training, generation and persistence.
55

6+
The following models support conditional generation (click model name to jump to details):
7+
8+
.. list-table:: Conditional Generation Models
9+
:header-rows: 1
10+
:widths: 30 70
11+
12+
* - Model
13+
- Description
14+
* - :class:`GraphDITMolecularGenerator <torch_molecule.generator.graph_dit.modeling_graph_dit.GraphDITMolecularGenerator>`
15+
- `Graph Diffusion Transformers for Multi-Conditional Molecular Generation <https://arxiv.org/abs/2401.13858>`_
16+
* - :class:`DeFoGMolecularGenerator <torch_molecule.generator.defog.modeling_defog.DeFoGMolecularGenerator>`
17+
- `Discrete Flow Matching for Graph Generation <https://openreview.net/forum?id=KPRIwWhqAZ>`_
18+
* - :class:`GraphGAMolecularGenerator <torch_molecule.generator.graph_ga.modeling_graph_ga.GraphGAMolecularGenerator>`
19+
- Graph Genetic Algorithm with Random Forests
20+
* - :class:`MolGPTMolecularGenerator <torch_molecule.generator.molgpt.modeling_molgpt.MolGPTMolecularGenerator>`
21+
- `MolGPT: Molecular Generation Using a Transformer-Decoder Model <https://pubs.acs.org/doi/10.1021/acs.jcim.1c00600>`_
22+
* - :class:`LSTMMolecularGenerator <torch_molecule.generator.lstm.modeling_lstm.LSTMMolecularGenerator>`
23+
- LSTM
24+
625
.. rubric:: Training and Generation
726

827
- ``fit(X, **kwargs)``: Train the model on given data, where X contains SMILES strings (y should be provided for conditional generation)
@@ -30,21 +49,40 @@ inherited from :class:`torch_molecule.base.base.BaseModel`
3049
Modeling Molecules as Graphs
3150
---------------------------------------------------------------------
3251

33-
.. rubric:: GraphDiT for Un/Multi-conditional Molecular Generation
52+
.. rubric:: GraphDiT for Unconditional/Multi-Conditional Molecular Generation
3453
.. autoclass:: torch_molecule.generator.graph_dit.modeling_graph_dit.GraphDITMolecularGenerator
3554
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
3655
:members: fit, generate
3756
:undoc-members:
3857
:show-inheritance:
3958

59+
.. rubric:: Discrete Flow Matching for Graph Generation for Unconditional/Multi-Conditional Molecular Generation
60+
.. autoclass:: torch_molecule.generator.defog.modeling_defog.DeFoGMolecularGenerator
61+
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
62+
:members: fit, generate
63+
:undoc-members:
64+
:show-inheritance:
65+
66+
.. rubric:: Graph Genetic Algorithm for Unconditional/Multi-Conditional Molecular Generation
67+
.. autoclass:: torch_molecule.generator.graph_ga.modeling_graph_ga.GraphGAMolecularGenerator
68+
:exclude-members: fitting_epoch, fitting_loss, save_to_hf, load_from_hf
69+
:members: fit, generate
70+
:undoc-members:
71+
:show-inheritance:
72+
73+
.. automodule:: torch_molecule.generator.graph_ga.oracle
74+
:members:
75+
:undoc-members:
76+
:show-inheritance:
77+
4078
.. rubric:: DiGress for Unconditional Molecular Generation
4179
.. autoclass:: torch_molecule.generator.digress.modeling_digress.DigressMolecularGenerator
4280
:exclude-members: fitting_epoch, fitting_loss, model_class, dataset_info, model_name
4381
:members: fit, generate
4482
:undoc-members:
4583
:show-inheritance:
4684

47-
.. rubric:: GDSS for score-based molecular generation
85+
.. rubric:: GDSS for Unconditional Molecular Generation
4886
.. autoclass:: torch_molecule.generator.gdss.modeling_gdss.GDSSMolecularGenerator
4987
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
5088
:members: fit, generate
@@ -58,29 +96,17 @@ Modeling Molecules as Graphs
5896
:undoc-members:
5997
:show-inheritance:
6098

61-
.. rubric:: Graph Genetic Algorithm for Un/Multi-conditional Molecular Generation
62-
.. autoclass:: torch_molecule.generator.graph_ga.modeling_graph_ga.GraphGAMolecularGenerator
63-
:exclude-members: fitting_epoch, fitting_loss, save_to_hf, load_from_hf
64-
:members: fit, generate
65-
:undoc-members:
66-
:show-inheritance:
67-
68-
.. automodule:: torch_molecule.generator.graph_ga.oracle
69-
:members:
70-
:undoc-members:
71-
:show-inheritance:
72-
7399
Modeling Molecules as Sequences
74100
--------------------------------
75101

76-
.. rubric:: MolGPT for Unconditional Molecular Generation
102+
.. rubric:: MolGPT for Unconditional/Multi-Conditional Molecular Generation
77103
.. autoclass:: torch_molecule.generator.molgpt.modeling_molgpt.MolGPTMolecularGenerator
78104
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
79105
:members: fit, generate
80106
:undoc-members:
81107
:show-inheritance:
82108

83-
.. rubric:: LSTM for Unconditional/Conditional Molecular Generation
109+
.. rubric:: LSTM for Unconditional/Multi-Conditional Molecular Generation
84110
.. autoclass:: torch_molecule.generator.lstm.modeling_lstm.LSTMMolecularGenerator
85111
:exclude-members: fitting_epoch, fitting_loss, model_name, model_class
86112
:members: fit, generate

torch_molecule/generator/defog/modeling_defog.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ class DeFoGMolecularGenerator(BaseMolecularGenerator):
3232
----------
3333
num_layer : int, default=6
3434
Number of transformer layers
35-
hidden_mlp_dims : Dict[str, int], default={'X': 256, 'E': 128, 'y': 128}
35+
hidden_mlp_dims : Dict[str, int], default={'X': 256, 'E': 128, 'y': 128} if None
3636
Hidden dimensions for MLP layers in X (node dim), E (edge dim), and y (property dim) components
37-
hidden_dims : Dict[str, Any], default={'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128}
37+
hidden_dims : Dict[str, Any], default={'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128} if None
3838
Hidden dimensions for transformer components including attention heads and feed-forward layers
3939
Keys: 'dx' (node dim), 'de' (edge dim), 'dy' (property dim), 'n_head' (number of attention heads), 'dim_ffX' (feed-forward dim for node features), 'dim_ffE' (feed-forward dim for edge features), 'dim_ffy' (feed-forward dim for property features)
4040
transition : str, default='marginal'
@@ -43,7 +43,7 @@ class DeFoGMolecularGenerator(BaseMolecularGenerator):
4343
time_distortion : str, default="polydec"
4444
Time distortion schedule used during training/sampling.
4545
Options: 'identity', 'cos', 'revcos', 'polyinc', 'polydec'
46-
lambda_train : List[float], default=[5.0, 1.0]
46+
lambda_train : List[float], default=[5.0, 1.0] if None
4747
Loss weights: [edge_loss_weight, property_loss_weight]
4848
extra_features_type : str, default='rrwp'
4949
Extra feature type.

torch_molecule/generator/graph_ga/oracle.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import numpy as np
22
from sklearn.ensemble import RandomForestRegressor
33
from rdkit import Chem
4-
from typing import List, Any
54
from ...utils.graph.features import getmorganfingerprint
65

7-
86
class Oracle:
97
"""The default Oracle class for scoring molecules in GraphGA.
108

torch_molecule/generator/molgpt/modeling_molgpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class MolGPTMolecularGenerator(BaseMolecularGenerator):
3838
max_len : int, default=128
3939
Maximum length of SMILES strings.
4040
num_task : int, default=0
41-
Number of property prediction tasks for conditional generation. O for unconditional generation.
41+
Number of property prediction tasks for conditional generation. 0 for unconditional generation.
4242
use_scaffold : bool, default=False
4343
Whether to use scaffold conditioning.
4444
use_lstm : bool, default=False

0 commit comments

Comments
 (0)