
1. **Predictive Models**: Done: GREA, SGIR, IRM, GIN/GCN with virtual nodes, DIR. TODO: SMILES-based LSTM/Transformers, more
2. **Generative Models**: Done: Graph DiT, GraphGA, DiGress. TODO: GDSS, more
3. **Representation Models**: Done: MoAMa, AttrMasking, ContextPred, EdgePred. Many pretrained models from Hugging Face. TODO: checkpoints, more

> **Note**: This project is in active development, and features may change.

```bash
pip install -i https://test.pypi.org/simple/ torch-molecule
```

### Additional Packages

| Model | Required Packages |
| ----- | ----------------- |
| HFPretrainedMolecularEncoder | transformers |

## Usage
5865
Refer to the `tests` folder for more use cases.
The following example demonstrates how to use the `GREAMolecularPredictor` class. More examples can be found in the `examples` and `tests` folders.
6673
```python
from torch_molecule import GREAMolecularPredictor

# Train GREA model
grea_model = GREAMolecularPredictor(
    num_task=num_task,
    task_type="regression",
    model_name="GREA_multitask",
    evaluate_criterion='r2',
    evaluate_higher_better=True,
    verbose=True
)

# SMILES strings and per-task labels for training and validation
X_train = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
y_train = [[0.5], [1.5]]
X_val = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
y_val = [[0.5], [1.5]]
N_trial = 10

grea_model.autofit(
    X_train=X_train,  # List of SMILES strings
    y_train=y_train,  # Labels of shape [n_samples, n_tasks]
    X_val=X_val,
    y_val=y_val,
    n_trials=N_trial,
)
```
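The `evaluate_criterion='r2'` setting above selects hyperparameter trials by the coefficient of determination, R² = 1 − SS_res / SS_tot. As a quick reference, here is a minimal pure-Python sketch of that formula (toy numbers, independent of `torch-molecule`; the helper name is ours, not part of the package API):

```python
# Coefficient of determination: R^2 = 1 - SS_res / SS_tot, where SS_res is
# the residual sum of squares and SS_tot the total sum of squares around
# the mean of the true targets. Higher is better, with 1.0 a perfect fit.
def r2_score(y_true, y_pred):
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

# Toy single-task example
score = r2_score([0.5, 1.5, 2.5], [0.6, 1.4, 2.4])
```

This is why `evaluate_higher_better=True` is passed alongside it: R² improves as it increases, unlike error metrics such as MAE.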
117102
### Checkpoints

`torch-molecule` provides checkpoint functions for pushing trained models to and loading pretrained models from Hugging Face.
```python
from torch_molecule import GREAMolecularPredictor
from sklearn.metrics import mean_absolute_error

# Define the repository ID for Hugging Face
repo_id = "user/repo_id"

# Initialize the GREAMolecularPredictor model
model = GREAMolecularPredictor()

# Train the model using autofit
model.autofit(
    X_train=X.tolist(),    # List of SMILES strings for training
    y_train=y_train,       # numpy array [n_samples, n_tasks] of training labels
    X_val=X_val.tolist(),  # List of SMILES strings for validation
    y_val=y_val,           # numpy array [n_samples, n_tasks] of validation labels
)

# Make predictions on the test set
output = model.predict(X_test.tolist())  # (n_samples, n_tasks)

# Calculate the mean absolute error
mae = mean_absolute_error(y_test, output['prediction'])
metrics = {'MAE': mae}

# Save the trained model to Hugging Face
model.save_to_hf(
    repo_id=repo_id,
    task_id=f"{task_name}",
    metrics=metrics,
    commit_message=f"Upload GREA_{task_name} model with metrics: {metrics}",
    private=False
)

# Load a pretrained checkpoint from Hugging Face
model = GREAMolecularPredictor()
model.load_from_hf(repo_id=repo_id, local_cache=f"{model_dir}/GREA_{task_name}.pt")

# Set model parameters
model.set_params(verbose=True)

# Make predictions using the loaded model
predictions = model.predict(smiles_list)
```
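The MAE logged to the checkpoint above averages absolute errors over every sample and task in the `(n_samples, n_tasks)` prediction array. A minimal pure-Python sketch of that computation, using toy numbers of our own rather than real model output:

```python
# Toy targets and predictions shaped (n_samples, n_tasks), i.e. two
# samples with two tasks each, mirroring the prediction array format.
y_true = [[0.5, 1.0], [1.5, 2.0]]
y_pred = [[0.6, 0.8], [1.4, 2.3]]

# Mean absolute error: average |true - pred| over all samples and tasks.
errors = [abs(t - p)
          for row_t, row_p in zip(y_true, y_pred)
          for t, p in zip(row_t, row_p)]
mae = sum(errors) / len(errors)
```

This is the same quantity `sklearn.metrics.mean_absolute_error` returns by default for multi-output targets (`multioutput='uniform_average'`).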
153151
<!-- ### Using Checkpoints for Benchmarking

_(Coming soon)_ -->

## Project Structure

The structure of `torch_molecule` is as follows: