
Commit de781a6

add hf pretrained encoders
1 parent 417ba11 commit de781a6

5 files changed

Lines changed: 279 additions & 40 deletions

tests/encoder/run_hfpretrained.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+import numpy as np
+import torch
+from torch_molecule import HFPretrainedMolecularEncoder
+
+def test_hf_pretrained_encoder():
+    # Test molecules (simple examples)
+    molecules = [
+        "CC(=O)O",   # Acetic acid
+        "CCO",       # Ethanol
+        "CCCC",      # Butane
+        "c1ccccc1",  # Benzene
+        "CCN",       # Ethylamine
+    ]
+
+    # Test different HuggingFace models
+    models_to_test = [
+        {"repo_id": "entropy/gpt2_zinc_87m", "model_name": "GPT-2_ZINC_87M"},
+        {"repo_id": "entropy/roberta_zinc_480m", "model_name": "RoBERTa_ZINC_480M"},
+        {"repo_id": "ncfrey/ChemGPT-1.2B", "model_name": "ChemGPT_1.2B"},
+        {"repo_id": "ncfrey/ChemGPT-19M", "model_name": "ChemGPT_19M"},
+        {"repo_id": "ncfrey/ChemGPT-4.7M", "model_name": "ChemGPT_4.7M"},
+        {"repo_id": "DeepChem/ChemBERTa-77M-MTR", "model_name": "ChemBERTa_77M_MTR"},
+        {"repo_id": "DeepChem/ChemBERTa-77M-MLM", "model_name": "ChemBERTa_77M_MLM"},
+        {"repo_id": "DeepChem/ChemBERTa-10M-MTR", "model_name": "ChemBERTa_10M_MTR"},
+        {"repo_id": "DeepChem/ChemBERTa-10M-MLM", "model_name": "ChemBERTa_10M_MLM"},
+        {"repo_id": "DeepChem/ChemBERTa-5M-MLM", "model_name": "ChemBERTa_5M_MLM"},
+        {"repo_id": "DeepChem/ChemBERTa-5M-MTR", "model_name": "ChemBERTa_5M_MTR"},
+        {"repo_id": "seyonec/ChemBERTa-zinc-base-v1", "model_name": "ChemBERTa_zinc_base_v1"},
+        {"repo_id": "unikei/bert-base-smiles", "model_name": "bert-base-smiles"}
+    ]
+
+    for model_config in models_to_test:
+        print(f"\n=== Testing {model_config['model_name']} ===")
+
+        # Initialize model
+        model = HFPretrainedMolecularEncoder(repo_id=model_config["repo_id"], model_name=model_config["model_name"])
+        print(f"Model initialized successfully: {model_config['model_name']}")
+
+        # Load the model
+        print("Loading model from HuggingFace...")
+        model.fit()
+        print("Model loaded successfully")
+
+        # Encoding test
+        print("Testing molecule encoding...")
+        encodings_pt = model.encode(molecules, return_type="pt")
+        encodings_np = model.encode(molecules, return_type="np")
+
+        print('model_config', model_config)
+        print(f"Encoded {len(molecules)} molecules")
+        print(f"PyTorch tensor shape: {encodings_pt.shape}")
+        print(f"NumPy array shape: {encodings_np.shape}")
+
+        # Verify PyTorch and NumPy outputs match
+        if np.allclose(encodings_pt.cpu().numpy(), encodings_np):
+            print("PyTorch and NumPy encodings match!")
+        else:
+            print("Warning: PyTorch and NumPy encodings differ")
+
+        # Print some stats about the embeddings
+        print(f"Embedding dimensionality: {encodings_pt.shape[1]}")
+        print(f"Mean embedding value: {encodings_pt.mean().item():.4f}")
+        print(f"Std of embedding values: {encodings_pt.std().item():.4f}")
+
+        # Check if embeddings are different for different molecules
+        distances = []
+        for i in range(len(molecules)):
+            for j in range(i+1, len(molecules)):
+                dist = torch.norm(encodings_pt[i] - encodings_pt[j]).item()
+                distances.append(dist)
+
+        print(f"Average L2 distance between embeddings: {np.mean(distances):.4f}")
+        print(f"Min L2 distance between embeddings: {np.min(distances):.4f}")
+        print(f"Max L2 distance between embeddings: {np.max(distances):.4f}")
+
+if __name__ == "__main__":
+    test_hf_pretrained_encoder()
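
For reference, a minimal standalone usage sketch with a single checkpoint (assuming torch_molecule is installed and the HuggingFace weights can be downloaded; the repo_id shown is one of the tested models from the list above):

# Hypothetical usage sketch, not part of this commit.
from torch_molecule import HFPretrainedMolecularEncoder

encoder = HFPretrainedMolecularEncoder(repo_id="DeepChem/ChemBERTa-10M-MTR")
encoder.fit()  # downloads the tokenizer and model weights from HuggingFace

embeddings = encoder.encode(["CCO", "c1ccccc1"], return_type="pt")
print(embeddings.shape)  # (2, hidden_dim); e.g. 384 for the ChemBERTa models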

torch_molecule/__init__.py

Lines changed: 3 additions & 2 deletions

@@ -21,7 +21,7 @@
 from .encoder.edgepred import EdgePredMolecularEncoder
 from .encoder.moama import MoamaMolecularEncoder
 from .encoder.infograph import InfoGraphMolecularEncoder
-
+from .encoder.pretrained import HFPretrainedMolecularEncoder
 """
 generator module
 """
@@ -48,7 +48,8 @@
     'ContextPredMolecularEncoder',
     'EdgePredMolecularEncoder',
     'MoamaMolecularEncoder',
-    'InfographMolecularEncoder',
+    'InfoGraphMolecularEncoder',
+    'HFPretrainedMolecularEncoder',
     # generators
     'GraphDITMolecularGenerator',
     'GraphGAMolecularGenerator',
torch_molecule/encoder/pretrained/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from .modeling_pretrained import HFPretrainedMolecularEncoder
+
+__all__ = ['HFPretrainedMolecularEncoder']

torch_molecule/encoder/pretrained/modeling_pretrained.py

Lines changed: 193 additions & 38 deletions

@@ -1,29 +1,115 @@
-import numpy as np
+import warnings
 from tqdm import tqdm
 from typing import Optional, Union, Dict, Any, Tuple, List, Literal
-
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import torch
+import numpy as np
 
 from ...base import BaseMolecularEncoder
+
+known_repos = [
+    "entropy/gpt2_zinc_87m",
+    "entropy/roberta_zinc_480m",
+    "ncfrey/ChemGPT-1.2B",
+    "ncfrey/ChemGPT-19M",
+    "ncfrey/ChemGPT-4.7M",
+    "DeepChem/ChemBERTa-77M-MTR",
+    "DeepChem/ChemBERTa-77M-MLM",
+    "DeepChem/ChemBERTa-10M-MTR",
+    "DeepChem/ChemBERTa-10M-MLM",
+    "DeepChem/ChemBERTa-5M-MLM",
+    "DeepChem/ChemBERTa-5M-MTR",
+    "unikei/bert-base-smiles",
+    'seyonec/ChemBERTa-zinc-base-v1'
+]
+
+known_add_bos_eos_list = ["entropy/gpt2_zinc_87m"]
+
+@dataclass(init=False)
+class HFPretrainedMolecularEncoder(BaseMolecularEncoder):
+    """Implements Hugging Face pretrained transformer models as molecular encoders.
+
+    This class provides an interface for using pretrained transformer models from Hugging Face
+    as molecular encoders. It handles tokenization, encoding, and pooling of molecular representations.
+
+    Tested models include:
+
+    - ChemGPT series (1.2B/19M/4.7M): GPT-Neo based models pretrained on the PubChem10M dataset with SELFIES strings.
+      Output dimension: 2048.
+      repo_id: "ncfrey/ChemGPT-1.2B" (https://huggingface.co/ncfrey/ChemGPT-1.2B)
+      repo_id: "ncfrey/ChemGPT-19M" (https://huggingface.co/ncfrey/ChemGPT-19M)
+      repo_id: "ncfrey/ChemGPT-4.7M" (https://huggingface.co/ncfrey/ChemGPT-4.7M)
+
+    - GPT2-ZINC-87M: GPT-2 based model (87M parameters) pretrained on the ZINC dataset with ~480M SMILES strings.
+      Output dimension: 768.
+      repo_id: "entropy/gpt2_zinc_87m" (https://huggingface.co/entropy/gpt2_zinc_87m)
+
+    - RoBERTa-ZINC-480M: RoBERTa based model (102M parameters) pretrained on the ZINC dataset with ~480M SMILES strings.
+      Output dimension: 768.
+      repo_id: "entropy/roberta_zinc_480m" (https://huggingface.co/entropy/roberta_zinc_480m)
+
+    - ChemBERTa series: Available in multiple sizes (77M/10M/5M) and training objectives (MTR/MLM).
+      Output dimension: 384.
+      repo_id: "DeepChem/ChemBERTa-77M-MTR" (https://huggingface.co/DeepChem/ChemBERTa-77M-MTR)
+      repo_id: "DeepChem/ChemBERTa-77M-MLM" (https://huggingface.co/DeepChem/ChemBERTa-77M-MLM)
+      repo_id: "DeepChem/ChemBERTa-10M-MTR" (https://huggingface.co/DeepChem/ChemBERTa-10M-MTR)
+      repo_id: "DeepChem/ChemBERTa-10M-MLM" (https://huggingface.co/DeepChem/ChemBERTa-10M-MLM)
+      repo_id: "DeepChem/ChemBERTa-5M-MLM" (https://huggingface.co/DeepChem/ChemBERTa-5M-MLM)
+      repo_id: "DeepChem/ChemBERTa-5M-MTR" (https://huggingface.co/DeepChem/ChemBERTa-5M-MTR)
+
+    - bert-base-smiles: BERT model pretrained on SMILES strings.
+      Output dimension: 768.
+      repo_id: "unikei/bert-base-smiles" (https://huggingface.co/unikei/bert-base-smiles)
+
+    - ChemBERTa-zinc-base-v1: RoBERTa model pretrained on the ZINC dataset with ~100k SMILES strings.
+      Output dimension: 384.
+      repo_id: "seyonec/ChemBERTa-zinc-base-v1" (https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1)
+
+    Parameters
+    ----------
+    repo_id : str
+        The Hugging Face repository ID of the pretrained model.
+    max_length : int, default=128
+        Maximum sequence length for tokenization. Longer sequences will be truncated.
+    batch_size : int, default=128
+        Batch size used when encoding multiple molecules.
+    add_bos_eos : Optional[bool], default=None
+        Whether to add beginning/end of sequence tokens. If None, determined automatically based on model type.
+    model_name : str, default="PretrainedMolecularEncoder"
+        Name identifier for the model instance.
+    verbose : bool, default=False
+        Whether to display progress information during encoding.
+    """
 
+    repo_id: str
 
-@dataclass
-class PretrainedMolecularEncoder(BaseMolecularEncoder):
-    """This encoder uses a pretrained transformer model from HuggingFace."""
-    # Task-related parameters
-    # repo_id: str = "huggingface/PretrainedMolecularEncoder"
+    # Default arguments
+    max_length: int = 128
+    batch_size: int = 128
+    add_bos_eos: Optional[bool] = None
     model_name: str = "PretrainedMolecularEncoder"
+    verbose: bool = False
+
+    def __init__(self, repo_id: str, max_length: int = 128, batch_size: int = 128, add_bos_eos: Optional[bool] = None,
+                 model_name: str = "PretrainedMolecularEncoder", verbose: bool = False, **kwargs):
+        self.repo_id = repo_id
+        self.max_length = max_length
+        self.batch_size = batch_size
+        self.add_bos_eos = add_bos_eos
+        self.model_name = model_name
+        self.verbose = verbose
+        super().__init__(**kwargs)
 
     def __post_init__(self):
-        """Initialize the model after dataclass initialization."""
         super().__post_init__()
         self._require_transformers()
-        self.is_fitted_ = True
         self.fitting_epoch = -1
         self.fitting_loss = -1
 
+        if self.repo_id not in known_repos:
+            warnings.warn(f"Unknown repo_id: {self.repo_id}. The class will try to load the model from HuggingFace, but it might fail.")
+
     @staticmethod
     def _get_param_names() -> List[str]:
         """Get parameter names for the estimator.
@@ -33,40 +119,69 @@ def _get_param_names() -> List[str]:
         List[str]
             List of parameter names that can be used for model configuration.
         """
-        return []
+        return ["repo_id", "max_length", "model_name", "add_bos_eos"]
 
-    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
-        params = ["model_name"]
-        if checkpoint is not None:
-            if "hyperparameters" not in checkpoint:
-                raise ValueError("Checkpoint missing 'hyperparameters' key")
-            return {k: checkpoint["hyperparameters"][k] for k in params}
-        return {k: getattr(self, k) for k in params}
+    def _get_model_params(self) -> Dict[str, Any]:
+        raise NotImplementedError("PretrainedMolecularEncoder does not support model parameters.")
+
+    def _setup_optimizers(self) -> None:
+        raise NotImplementedError("PretrainedMolecularEncoder does not support training.")
 
-    def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
+    def _train_epoch(self) -> None:
         raise NotImplementedError("PretrainedMolecularEncoder does not support training.")
 
-    def save_to_local(self, path: str) -> None:
+    def save_to_local(self) -> None:
         raise NotImplementedError("PretrainedMolecularEncoder does not support saving to local.")
 
-    def load_from_local(self, path: str) -> None:
+    def load_from_local(self) -> None:
        raise NotImplementedError("PretrainedMolecularEncoder does not support loading from local.")
 
     def save_to_hf(self) -> None:
         raise NotImplementedError("PretrainedMolecularEncoder does not support saving to huggingface.")
 
-    def load_from_hf(self, repo_id: str) -> None:
-        # TODO: Implement this
-        raise NotImplementedError("Implements this.")
+    def load_from_hf(self) -> None:
+        self.fit()
 
-    def load(self, repo_id: str) -> None:
-        self.load_from_hf(repo_id)
+    def load(self) -> None:
+        self.fit()
+
+    def fit(self) -> "HFPretrainedMolecularEncoder":
+        """Load the pretrained model from HuggingFace."""
+        assert self.repo_id is not None, "repo_id is not set"
+        self._require_transformers()
+        import transformers
+
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.repo_id, max_length=self.max_length)
+        self.model = transformers.AutoModel.from_pretrained(self.repo_id)
+        self.model.to(self.device)
+
+        model_config = self.model.config
+        model_type = model_config.model_type
 
-    def fit(self, repo_id: str) -> "PretrainedMolecularEncoder":
-        self.load_from_hf(repo_id)
+        if self.add_bos_eos is None:
+            self.add_bos_eos = self.repo_id in known_add_bos_eos_list
+
+        if self.tokenizer.pad_token is None:
+            if self.tokenizer.eos_token is not None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            else:
+                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+                self.model.resize_token_embeddings(len(self.tokenizer))
+
+        if self.add_bos_eos:
+            if self.tokenizer.bos_token is None:
+                self.tokenizer.add_special_tokens({'bos_token': '[BOS]'})
+            if self.tokenizer.eos_token is None:
+                self.tokenizer.add_special_tokens({'eos_token': '[EOS]'})
+            self.model.resize_token_embeddings(len(self.tokenizer))
+
+            warnings.warn("BOS and EOS tokens are not found in the tokenizer. They are added to the tokenizer since add_bos_eos is set to True.")
+
+        self.collator = transformers.DataCollatorWithPadding(self.tokenizer, padding=True, return_tensors='pt')
+        self.is_fitted_ = True
         return self
 
-    def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Union[np.ndarray, torch.Tensor]:
+    def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt", add_bos_eos: Optional[bool] = None) -> Union[np.ndarray, torch.Tensor]:
         """Encode molecules into vector representations.
 
         Parameters
@@ -75,22 +190,62 @@ def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Union
            List of SMILES strings
        return_type : Literal["np", "pt"], default="pt"
            Return type of the representations
+        add_bos_eos : Optional[bool], default=None
+            Whether to add BOS and EOS tokens. If None, will be determined based on model type.
 
         Returns
         -------
         representations : ndarray or torch.Tensor
             Molecular representations
         """
+        self._require_transformers()
         self._check_is_fitted()
-        X, _ = self._validate_inputs(X, return_rdkit_mol=True)
-        raise NotImplementedError("Implements this.")
-
-        # Placeholder for transformer-based encodings
-        # Replace with actual encoding logic when integrating the transformer model
-        encodings = [X]  # dummy list to allow concat
-
-        encodings = torch.cat(encodings, dim=0)
-        return encodings if return_type == "pt" else encodings.numpy()
+        X, _ = self._validate_inputs(X, return_rdkit_mol=False)
+
+        # Process in batches
+        all_embeddings = []
+        iterator = tqdm(range(0, len(X), self.batch_size), desc="Encoding molecules", total=len(X) // self.batch_size, disable=not self.verbose)
+        for i in iterator:
+            batch_X = X[i:i + self.batch_size]
+
+            if self.add_bos_eos:
+                # For decoder models (e.g. GPT2), manually add BOS and EOS tokens
+                processed_batch = [self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in batch_X]
+                inputs = self.collator(self.tokenizer(processed_batch))
+            else:
+                inputs = self.collator(self.tokenizer(batch_X))
+
+            # Move inputs to the same device as the model
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            # Get model outputs
+            with torch.no_grad():
+                outputs = self.model(**inputs, output_hidden_states=True)
+
+            # get all attributes of outputs
+            print('outputs', outputs.keys())
+            # Extract embeddings based on model type
+            if hasattr(outputs, 'hidden_states'):
+                # For models that return a named tuple
+                full_embeddings = outputs.hidden_states[-1]
+            elif isinstance(outputs, tuple) and len(outputs) > 1:
+                # For models that return a tuple with hidden states
+                full_embeddings = outputs[-1][-1]
+            else:
+                # For models that return last_hidden_state directly
+                full_embeddings = outputs.last_hidden_state
+
+            # Apply attention mask to get meaningful embeddings
+            mask = inputs['attention_mask']
+            batch_embeddings = ((full_embeddings * mask.unsqueeze(-1)).sum(1) /
+                                mask.sum(-1).unsqueeze(-1))
+
+            all_embeddings.append(batch_embeddings)
+
+        # Concatenate all batch embeddings
+        embeddings = torch.cat(all_embeddings, dim=0)
+
+        return embeddings if return_type == "pt" else embeddings.cpu().numpy()
 
     @staticmethod
     def _require_transformers():
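
The pooling step in encode() above averages token embeddings while masking out padding positions. A small self-contained sketch of that masked mean pooling, with illustrative shapes (the tensors here are hypothetical, not taken from the commit):

# Masked mean pooling sketch: average only over real (non-padding) tokens.
import torch

full_embeddings = torch.randn(2, 4, 8)          # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])   # 1 = real token, 0 = padding

mask = attention_mask.unsqueeze(-1)             # (batch, seq_len, 1)
pooled = (full_embeddings * mask).sum(1) / mask.sum(1)   # (batch, hidden)
print(pooled.shape)  # torch.Size([2, 8])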
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+# TODO
+
+# safe-100m: https://huggingface.co/anrilombard/safe-100m
