Commit 795deff — "version 0.1.3 with datasets"
1 parent d9b9016

12 files changed: 987 additions & 332 deletions

README.md

Lines changed: 63 additions & 75 deletions
````diff
@@ -26,15 +26,6 @@
 
 See the [List of Supported Models](#list-of-supported-models) section for all available models.
 
-<!-- ### API Comparison
-
-| Functionality | scikit-learn | torch-molecule |
-|---------------|-------------|----------------|
-| Property Prediction | `predictor.fit/predict(...)` | `predictor.fit/autofit/predict(...)` |
-| Representation Learning | Not supported | `encoder.fit/encode(...)` |
-| Molecular Generation | Not supported | `generator.fit/generate(...)` | -->
-
-
 ## Installation
 
 1. **Create a Conda environment**:
````
````diff
@@ -44,22 +35,19 @@ See the [List of Supported Models](#list-of-supported-models) section for all av
    ```
 
 2. **Install using pip (0.1.2)**:
-
    ```bash
    pip install torch-molecule
    ```
 
 3. **Install from source for the latest version**:
 
    Clone the repository:
-
    ```bash
    git clone https://github.com/liugangcode/torch-molecule
    cd torch-molecule
    ```
 
    Install:
-
    ```bash
    pip install .
    ```
````
````diff
@@ -80,92 +68,97 @@ See the [List of Supported Models](#list-of-supported-models) section for all av
 
 ## Usage
 
-Refer to the `tests` folder for more use cases.
+> More examples can be found in the `examples` and `tests` folders.
+
+`torch-molecule` supports applications across chemistry, biology, and materials science. To get started, you can load prepared datasets from `torch_molecule.dataset` (updated after v0.1.3):
+
+| Dataset | Description | Function |
+|---------|-------------|----------|
+| qm9 | Quantum chemical properties (DFT level) | `load_qm9` |
+| chembl2k | Bioactive molecules with drug-like properties | `load_chembl2k` |
+| broad6k | Bioactive molecules with drug-like properties | `load_broad6k` |
+| toxcast | Toxicity of chemical compounds | `load_toxcast` |
+| admet | Chemical absorption, distribution, metabolism, excretion, and toxicity | `load_admet` |
+| gasperm | Six gas permeability properties for polymeric materials | `load_gasperm` |
+
+
+```python
+from torch_molecule.dataset import load_qm9
+
+# local_dir is the local path where the dataset will be saved
+smiles_list, property_np_array = load_qm9(local_dir='torchmol_data')
+
+# len(smiles_list): 133885
+# Property array shape: (133885, 1)
+
+# load_qm9 returns the target "gap" by default; adjust it by passing new target_cols
+target_cols = ['homo', 'lumo', 'gap']
+smiles_list, property_np_array = load_qm9(local_dir='torchmol_data', target_cols=target_cols)
+```
 
-### Python API Example
+(We welcome suggestions and contributions of new datasets!)
 
-The following example demonstrates how to use the `GREAMolecularPredictor` class from `torch_molecule`:
+### Fit a Model
 
-More examples could be found in the folders `examples` and `tests`.
+After preparing the dataset, we can fit a model much as in scikit-learn (in fact with less code, since scikit-learn still requires feature engineering to turn SMILES strings into vectors):
 
 ```python
 from torch_molecule import GREAMolecularPredictor
 
-# Train GREA model
-grea_model = GREAMolecularPredictor(
+split = int(0.8 * len(smiles_list))
+
+grea = GREAMolecularPredictor(
     num_task=num_task,
     task_type="regression",
-    model_name="GREA_multitask",
-    evaluate_criterion='r2',
-    evaluate_higher_better=True,
+    evaluate_higher_better=False,
     verbose=True
 )
 
-# Fit the model
-X_train = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
-y_train = [[0.5], [1.5]]
-X_val = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
-y_val = [[0.5], [1.5]]
-N_trial = 10
-
-grea_model.autofit(
-    X_train=X_train.tolist(),
-    y_train=y_train,
-    X_val=X_val.tolist(),
-    y_val=y_val,
-    n_trials=N_trial,
+# Fit with automatic hyperparameter tuning (10 trials), or call .fit() with default/manual hyperparameters
+grea.autofit(
+    X_train=smiles_list[:split],
+    y_train=property_np_array[:split],
+    X_val=smiles_list[split:],
+    y_val=property_np_array[split:],
+    n_trials=10,
 )
 ```
 
 ### Checkpoints
 
-`torch-molecule` provides checkpoint functions that can be interacted with on Hugging Face.
+`torch-molecule` provides checkpoint functions that can be interacted with on Hugging Face:
 
 ```python
 from torch_molecule import GREAMolecularPredictor
-from sklearn.metrics import mean_absolute_error
-
-# Define the repository ID for Hugging Face
-repo_id = "user/repo_id"
-
-# Initialize the GREAMolecularPredictor model
-model = GREAMolecularPredictor()
-
-# Train the model using autofit
-model.autofit(
-    X_train=X.tolist(),    # List of SMILES strings for training
-    y_train=y_train,       # numpy array [n_samples, n_tasks] for training labels
-    X_val=X_val.tolist(),  # List of SMILES strings for validation
-    y_val=y_val,           # numpy array [n_samples, n_tasks] for validation labels
-)
 
-# Make predictions on the test set
-output = model.predict(X_test.tolist())  # (n_sample, n_task)
-
-# Calculate the mean absolute error
-mae = mean_absolute_error(y_test, output['prediction'])
-metrics = {'MAE': mae}
+repo_id = "user/repo_id"  # replace with your own Hugging Face username and repo_id
 
 # Save the trained model to Hugging Face
-model.save_to_hf(
+grea.save_to_hf(
     repo_id=repo_id,
-    task_id=f"{task_name}",
-    metrics=metrics,
-    commit_message=f"Upload GREA_{task_name} model with metrics: {metrics}",
+    task_id="qm9_grea",
+    commit_message="Upload qm9_grea",
     private=False
 )
 
 # Load a pretrained checkpoint from Hugging Face
 model = GREAMolecularPredictor()
 model.load_from_hf(repo_id=repo_id, local_cache=f"{model_dir}/GREA_{task_name}.pt")
 
-# Set model parameters
-model.set_params(verbose=True)
-
-# Make predictions using the loaded model
+# Adjust model parameters and make predictions
+model.set_params(verbose=False)
 predictions = model.predict(smiles_list)
 ```
 
+Or you can save the model to a local path:
+
+```python
+grea.save_to_local("qm9_grea.pt")
+
+new_model = GREAMolecularPredictor()
+new_model.load_from_local("qm9_grea.pt")
+```
+
 ## List of Supported Models
 
 ### Predictive Models
````
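The removed checkpoint snippet computed a validation MAE with `sklearn.metrics.mean_absolute_error` before uploading; the same check can be sketched with NumPy alone. The arrays below are hypothetical stand-ins for `y_val` and `model.predict(...)['prediction']`:

```python
import numpy as np

# Hypothetical labels and predictions, shape (n_samples, n_tasks),
# standing in for y_val and output['prediction'] from the removed snippet
y_val = np.array([[0.5], [1.5], [2.5]])
y_pred = np.array([[0.4], [1.7], [2.5]])

# Mean absolute error, equivalent to sklearn's mean_absolute_error
mae = float(np.abs(y_val - y_pred).mean())
metrics = {'MAE': mae}
print(metrics)
```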
````diff
@@ -207,24 +200,19 @@ predictions = model.predict(smiles_list)
 | EdgePred | [Strategies for Pre-training Graph Neural Networks. ICLR 2020](https://arxiv.org/abs/1905.12265) |
 | InfoGraph | [InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization. ICLR 2020](https://arxiv.org/abs/1908.01000) |
 | Supervised | Supervised pretraining |
-| Pretrained | [GPT2-ZINC-87M](https://huggingface.co/entropy/gpt2_zinc_87m): GPT-2 based model (87M parameters) pretrained on ZINC dataset with ~480M SMILES strings. |
-| | [RoBERTa-ZINC-480M](https://huggingface.co/entropy/roberta_zinc_480m): RoBERTa based model (102M parameters) pretrained on ZINC dataset with ~480M SMILES strings. |
-| | [UniKi/bert-base-smiles](https://huggingface.co/unikei/bert-base-smiles): BERT model pretrained on SMILES strings. |
-| | [ChemBERTa-zinc-base-v1](https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1): RoBERTa model pretrained on ZINC dataset with ~100k SMILES strings.|
-| | ChemBERTa series: Available in multiple sizes and training objectives (MLM/MTR). <br> - [ChemBERTa-5M-MLM](https://huggingface.co/DeepChem/ChemBERTa-5M-MLM)<br> - [ChemBERTa-5M-MTR](https://huggingface.co/DeepChem/ChemBERTa-5M-MTR)<br> - [ChemBERTa-10M-MLM](https://huggingface.co/DeepChem/ChemBERTa-10M-MLM)<br> - [ChemBERTa-10M-MTR](https://huggingface.co/DeepChem/ChemBERTa-10M-MTR)<br> - [ChemBERTa-77M-MLM](https://huggingface.co/DeepChem/ChemBERTa-77M-MLM)<br> - [ChemBERTa-77M-MTR](https://huggingface.co/DeepChem/ChemBERTa-77M-MTR)|
-| | ChemGPT series: GPT-Neo based models pretrained on PubChem10M dataset with SELFIES strings. <br> - [ChemGPT-1.2B](https://huggingface.co/ncfrey/ChemGPT-1.2B)<br> - [ChemGPT-4.7B](https://huggingface.co/ncfrey/ChemGPT-4.7M)<br> - [ChemGPT-19B](https://huggingface.co/ncfrey/ChemGPT-19M)|
+| Pretrained | [GPT2-ZINC-87M](https://huggingface.co/entropy/gpt2_zinc_87m): GPT-2 based model (87M parameters) pretrained on ZINC dataset with ~480M SMILES strings. <br> [RoBERTa-ZINC-480M](https://huggingface.co/entropy/roberta_zinc_480m): RoBERTa based model (102M parameters) pretrained on ZINC dataset with ~480M SMILES strings. <br> [UniKi/bert-base-smiles](https://huggingface.co/unikei/bert-base-smiles): BERT model pretrained on SMILES strings. <br> [ChemBERTa-zinc-base-v1](https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1): RoBERTa model pretrained on ZINC dataset with ~100k SMILES strings. <br> ChemBERTa series: Available in multiple sizes and training objectives (MLM/MTR). [ChemBERTa-5M-MLM](https://huggingface.co/DeepChem/ChemBERTa-5M-MLM), [ChemBERTa-5M-MTR](https://huggingface.co/DeepChem/ChemBERTa-5M-MTR), [ChemBERTa-10M-MLM](https://huggingface.co/DeepChem/ChemBERTa-10M-MLM), [ChemBERTa-10M-MTR](https://huggingface.co/DeepChem/ChemBERTa-10M-MTR), [ChemBERTa-77M-MLM](https://huggingface.co/DeepChem/ChemBERTa-77M-MLM), [ChemBERTa-77M-MTR](https://huggingface.co/DeepChem/ChemBERTa-77M-MTR). <br> ChemGPT series: GPT-Neo based models pretrained on PubChem10M dataset with SELFIES strings. [ChemGPT-1.2B](https://huggingface.co/ncfrey/ChemGPT-1.2B), [ChemGPT-4.7M](https://huggingface.co/ncfrey/ChemGPT-4.7M), [ChemGPT-19M](https://huggingface.co/ncfrey/ChemGPT-19M). |
 
-## Project Structure
+<!-- ## Project Structure
 
-See the structure of `torch_molecule` with the command `tree -L 2 torch_molecule -I '__pycache__|*.pyc|*.pyo|.git|old*'`
+See the structure of `torch_molecule` with the command `tree -L 2 torch_molecule -I '__pycache__|*.pyc|*.pyo|.git|old*'` -->
 
-## Plan
+<!-- ## Plan
 
 1. **Predictive Models**: Done: GREA, SGIR, IRM, GIN/GCN w/ virtual, DIR. SMILES-based LSTM/Transformers. TODO more
 2. **Generative Models**: Done: Graph DiT, GraphGA, DiGress, GDS, MolGPT TODO: more
-3. **Representation Models**: Done: MoAMa, AttrMasking, ContextPred, EdgePred. Many pretrained models from HF. TODO: checkpoints, more
+3. **Representation Models**: Done: MoAMa, AttrMasking, ContextPred, EdgePred. Many pretrained models from HF. TODO: checkpoints, more -->
 
-> **Note**: This project is in active development, and features may change.
+<!-- > **Note**: This project is in active development, and features may change. -->
 
 ## Acknowledgements
 
````
pyproject.toml

Lines changed: 9 additions & 4 deletions
````diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "torch_molecule"
-version = "0.1.2"
+version = "0.1.3"
 description = "Deep learning packages for molecular discovery with a simple sklearn-style interface"
 authors = [{name = "Gang Liu", email = "gliu7@nd.edu"}]
 readme = "README.md"
````
````diff
@@ -43,11 +43,16 @@ exclude = [
 "**/old_*.py"
 ]
 
-[tool.setuptools.package-data]
-"torch_molecule" = ["**/*.yaml", "**/*.json"]
-
 [tool.setuptools]
 zip-safe = false
+include-package-data = true
+
+[tool.setuptools.package-data]
+"torch_molecule" = [
+    "**/*.yaml",
+    "**/*.json",
+    "datasets/data/*",
+]
 
 [tool.pytest.ini_options]
 addopts = "--verbose"
````

tests/datasets/gasperm.py

Lines changed: 100 additions & 0 deletions
New file:

```python
from torch_molecule.datasets import load_gasperm
import numpy as np


def test_gasperm_download_and_cleanup():
    """
    Test gas permeability dataset loading, print results, and cleanup local files.
    """
    print("=" * 60)
    print("Testing Gas Permeability Dataset Loading")
    print("=" * 60)

    try:
        print(f"\n1. Testing loading with default target columns")
        print("-" * 40)

        # Test with default target columns
        smiles_list, property_numpy = load_gasperm()

        # Print results
        print(f"\nResults:")
        print(f"- Number of molecules: {len(smiles_list)}")
        print(f"- Property array shape: {property_numpy.shape}")
        print(f"- Target columns: ['CH4', 'CO2', 'H2', 'N2', 'O2']")

        print(f"\nFirst 5 SMILES:")
        for i, smiles in enumerate(smiles_list[:5]):
            print(f"  {i+1}. {smiles}")

        print(f"\nFirst 5 property values (all gases):")
        for i, prop in enumerate(property_numpy[:5]):
            print(f"  {i+1}. {prop}")

        print(f"\nProperty statistics for each gas:")
        gas_names = ['CH4', 'CO2', 'H2', 'N2', 'O2']
        for j, gas in enumerate(gas_names):
            gas_values = property_numpy[:, j]
            # Filter out NaN values for statistics
            valid_values = gas_values[~np.isnan(gas_values)]
            if len(valid_values) > 0:
                print(f"  {gas}:")
                print(f"    Min: {valid_values.min():.6f}")
                print(f"    Max: {valid_values.max():.6f}")
                print(f"    Mean: {valid_values.mean():.6f}")
                print(f"    Std: {valid_values.std():.6f}")
                print(f"    Valid values: {len(valid_values)}/{len(gas_values)}")
            else:
                print(f"  {gas}: No valid values found")

        # Test with custom target columns
        print(f"\n2. Testing with custom target columns")
        print("-" * 40)

        custom_targets = ["CH4", "CO2"]
        smiles_list2, property_numpy2 = load_gasperm(target_cols=custom_targets)

        print(f"Custom target results:")
        print(f"- Target columns: {custom_targets}")
        print(f"- Property array shape: {property_numpy2.shape}")
        print(f"- Same number of molecules: {len(smiles_list2) == len(smiles_list)}")
        print(f"- First molecule properties: {property_numpy2[0]}")

        # Test with single target column
        print(f"\n3. Testing with single target column")
        print("-" * 40)

        single_target = ["H2"]
        smiles_list3, property_numpy3 = load_gasperm(target_cols=single_target)

        print(f"Single target results:")
        print(f"- Target columns: {single_target}")
        print(f"- Property array shape: {property_numpy3.shape}")
        print(f"- First 5 H2 permeability values:")
        for i, prop in enumerate(property_numpy3[:5]):
            print(f"  {i+1}. {prop[0]:.6f}" if not np.isnan(prop[0]) else f"  {i+1}. NaN")

        # Test error handling with invalid target column
        print(f"\n4. Testing error handling with invalid target column")
        print("-" * 40)

        try:
            invalid_targets = ["INVALID_GAS"]
            smiles_list4, property_numpy4 = load_gasperm(target_cols=invalid_targets)
            print("ERROR: Should have raised ValueError for invalid target column")
        except ValueError as e:
            print(f"Successfully caught expected error: {e}")
        except Exception as e:
            print(f"Unexpected error type: {type(e).__name__}: {e}")

    except Exception as e:
        print(f"Error during testing: {e}")
        raise

    print("\n" + "=" * 60)
    print("Gas Permeability Dataset Test Completed Successfully!")
    print("=" * 60)


if __name__ == "__main__":
    test_gasperm_download_and_cleanup()
```
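The NaN-aware per-gas statistics that the test prints can be distilled into a small helper. A sketch with toy data (the function name and array here are illustrative, not part of torch-molecule):

```python
import numpy as np

def column_stats(props, names):
    """Per-target count and mean over non-NaN entries, as the test above reports."""
    stats = {}
    for j, name in enumerate(names):
        col = props[:, j]
        valid = col[~np.isnan(col)]  # drop missing measurements for this gas
        stats[name] = {
            "n_valid": int(valid.size),
            "mean": float(valid.mean()) if valid.size else None,
        }
    return stats

# Toy property array with missing entries, shape (n_molecules, n_gases)
props = np.array([[1.0, np.nan],
                  [3.0, 2.0],
                  [np.nan, 4.0]])
print(column_stats(props, ["CH4", "CO2"]))
```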
