Commit 951935d: update readme
1 parent 5bdd24f

1 file changed: README.md
Lines changed: 34 additions & 38 deletions
@@ -21,7 +21,7 @@
 
 1. **Predictive Models**: Done: GREA, SGIR, IRM, GIN/GCN w/ virtual, DIR. TODO: SMILES-based LSTM/Transformers, more
 2. **Generative Models**: Done: Graph DiT, GraphGA, DiGress. TODO: GDSS, more
-3. **Representation Models**: Done: MoAMa, AttrMasking, ContextPred, EdgePred. TODO: checkpoints, more
+3. **Representation Models**: Done: MoAMa, AttrMasking, ContextPred, EdgePred. Many pretrained models are available from Hugging Face. TODO: checkpoints, more
 
 > **Note**: This project is in active development, and features may change.
@@ -54,6 +54,13 @@
 ```bash
 pip install -i https://test.pypi.org/simple/ torch-molecule
 ```
+
+### Additional Packages
+
+| Model | Required Packages |
+|-------|-------------------|
+| HFPretrainedMolecularEncoder | transformers |
+
 ## Usage
 
 Refer to the `tests` folder for more use cases.
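The Additional Packages table above lists per-model optional dependencies. Before instantiating a model such as `HFPretrainedMolecularEncoder`, it can help to verify that the extra package is importable; the sketch below is a generic, library-agnostic pattern (the helper name is illustrative, not part of the `torch-molecule` API):

```python
import importlib.util

def has_optional_dep(package: str) -> bool:
    """Return True if an optional dependency (e.g. 'transformers') is installed."""
    return importlib.util.find_spec(package) is not None

# Check before importing a model that requires the extra package
if not has_optional_dep("transformers"):
    print("Missing extra; install it first: pip install transformers")
```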
@@ -65,34 +72,13 @@ The following example demonstrates how to use the `GREAMolecularPredictor` class
 More examples can be found in the `examples` and `tests` folders.
 
 ```python
-from torch_molecule import GREAMolecularPredictor, GNNMolecularPredictor
-from torch_molecule.utils.search import ParameterType, ParameterSpec
-
-# Define search parameters
-search_GNN = {
-    "gnn_type": ParameterSpec(ParameterType.CATEGORICAL, ["gin-virtual", "gcn-virtual", "gin", "gcn"]),
-    "norm_layer": ParameterSpec(ParameterType.CATEGORICAL, ["batch_norm", "layer_norm"]),
-    "graph_pooling": ParameterSpec(ParameterType.CATEGORICAL, ["mean", "sum", "max"]),
-    "augmented_feature": ParameterSpec(ParameterType.CATEGORICAL, ["maccs,morgan", "maccs", "morgan", None]),
-    "num_layer": ParameterSpec(ParameterType.INTEGER, (2, 5)),
-    "hidden_size": ParameterSpec(ParameterType.INTEGER, (64, 512)),
-    "drop_ratio": ParameterSpec(ParameterType.FLOAT, (0.0, 0.5)),
-    "learning_rate": ParameterSpec(ParameterType.LOG_FLOAT, (1e-5, 1e-2)),
-    "weight_decay": ParameterSpec(ParameterType.LOG_FLOAT, (1e-10, 1e-3)),
-}
-
-search_GREA = {
-    "gamma": ParameterSpec(ParameterType.FLOAT, (0.25, 0.75)),
-    **search_GNN
-}
+from torch_molecule import GREAMolecularPredictor
 
 # Train GREA model
 grea_model = GREAMolecularPredictor(
     num_task=num_task,
     task_type="regression",
     model_name="GREA_multitask",
-    batch_size=BATCH_SIZE,
-    epochs=N_epoch,
     evaluate_criterion='r2',
     evaluate_higher_better=True,
     verbose=True
@@ -103,59 +89,69 @@ X_train = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
 y_train = [[0.5], [1.5]]
 X_val = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
 y_val = [[0.5], [1.5]]
-N_trial = 100
+N_trial = 10
 
 grea_model.autofit(
     X_train=X_train,  # already plain lists here, so no .tolist() is needed
     y_train=y_train,
     X_val=X_val,
     y_val=y_val,
     n_trials=N_trial,
-    search_parameters=search_GREA
 )
 ```
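Other snippets in this README call `.tolist()` on the inputs; that applies when the data starts out as NumPy arrays, since `autofit` consumes plain Python lists. A toy sketch of that conversion (assuming NumPy; the array names are illustrative):

```python
import numpy as np

X_arr = np.array(['C1=CC=CC=C1', 'C1=CC=CC=C1'])  # SMILES strings as a NumPy array
y_arr = np.array([[0.5], [1.5]])                  # labels, shape [n_samples, n_tasks]

X_list = X_arr.tolist()  # list of str, the form autofit expects
y_list = y_arr.tolist()  # nested list of floats
```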
 
-### Using Checkpoints for Deployment
+### Checkpoints
 
-`torch-molecule` provides checkpoints hosted on Hugging Face, which can save computational resources by starting from a pretrained state. For example, a checkpoint for gas permeability predictions (in log10 space) can be used as follows:
+`torch-molecule` provides functions for saving model checkpoints to, and loading them from, Hugging Face.
 
 ```python
 from torch_molecule import GREAMolecularPredictor
+from sklearn.metrics import mean_absolute_error
 
+# Define the repository ID on Hugging Face
 repo_id = "user/repo_id"
-# Push a trained model to Hugging Face
+
+# Initialize the GREAMolecularPredictor model
 model = GREAMolecularPredictor()
+
+# Train the model using autofit
 model.autofit(
-    X_train=X.tolist(),   # List of SMILES strings
-    y_train=y_train,      # numpy array [n_samples, n_tasks]
-    X_val=X_val.tolist(),
-    y_val=y_val,
-    n_trials=100          # Number of trials for hyperparameter optimization
+    X_train=X.tolist(),    # List of SMILES strings for training
+    y_train=y_train,       # numpy array [n_samples, n_tasks] with training labels
+    X_val=X_val.tolist(),  # List of SMILES strings for validation
+    y_val=y_val,           # numpy array [n_samples, n_tasks] with validation labels
 )
+
+# Make predictions on the test set
 output = model.predict(X_test.tolist())  # (n_sample, n_task)
+
+# Calculate the mean absolute error
 mae = mean_absolute_error(y_test, output['prediction'])
 metrics = {'MAE': mae}
-model.push_to_huggingface(
+
+# Save the trained model to Hugging Face
+model.save_to_hf(
     repo_id=repo_id,
     task_id=f"{task_name}",
     metrics=metrics,
     commit_message=f"Upload GREA_{task_name} model with metrics: {metrics}",
     private=False
 )
+
 # Load a pretrained checkpoint from Hugging Face
 model = GREAMolecularPredictor()
-model.load_model(f"{model_dir}/GREA_{task_name}.pt", repo_id=repo_id)
+model.load_from_hf(repo_id=repo_id, local_cache=f"{model_dir}/GREA_{task_name}.pt")
+
+# Set model parameters
 model.set_params(verbose=True)
 
-# Make predictions
+# Make predictions using the loaded model
 predictions = model.predict(smiles_list)
 ```
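The `MAE` value stored in `metrics` above comes from scikit-learn's `mean_absolute_error`, which for multi-task targets averages the absolute errors over all samples and tasks. A dependency-free sketch of the same computation, with toy numbers in place of real predictions:

```python
def mean_abs_error(y_true, y_pred):
    """Mean absolute error over nested [n_samples, n_tasks] lists."""
    flat_true = [v for row in y_true for v in row]
    flat_pred = [v for row in y_pred for v in row]
    return sum(abs(t - p) for t, p in zip(flat_true, flat_pred)) / len(flat_true)

# Two samples, one task each; per-sample errors are 0.25 and 0.25
mae = mean_abs_error([[0.5], [1.5]], [[0.75], [1.25]])
print(mae)  # prints 0.25
```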
 
 <!-- ### Using Checkpoints for Benchmarking
-
 _(Coming soon)_ -->
 
-
 ## Project Structure
 
 The structure of `torch_molecule` is as follows:
