-import numpy as np
+import warnings
 from tqdm import tqdm
 from typing import Optional, Union, Dict, Any, Tuple, List, Literal
-
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import torch
+import numpy as np
 
 from ...base import BaseMolecularEncoder
+
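+# Repository IDs this encoder has been tested with (see the class docstring for details);
+# other repo_ids are still attempted, but __post_init__ emits a warning for them.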
+known_repos = [
+    "entropy/gpt2_zinc_87m",
+    "entropy/roberta_zinc_480m",
+    "ncfrey/ChemGPT-1.2B",
+    "ncfrey/ChemGPT-19M",
+    "ncfrey/ChemGPT-4.7M",
+    "DeepChem/ChemBERTa-77M-MTR",
+    "DeepChem/ChemBERTa-77M-MLM",
+    "DeepChem/ChemBERTa-10M-MTR",
+    "DeepChem/ChemBERTa-10M-MLM",
+    "DeepChem/ChemBERTa-5M-MLM",
+    "DeepChem/ChemBERTa-5M-MTR",
+    "unikei/bert-base-smiles",
+    "seyonec/ChemBERTa-zinc-base-v1",
+]
+
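+# Models whose tokenizers do not add BOS/EOS tokens on their own (decoder-style
+# tokenizers such as GPT-2); for these, encode() wraps each input string with
+# bos_token/eos_token manually (see fit() and encode() below).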
+known_add_bos_eos_list = ["entropy/gpt2_zinc_87m"]
+
+@dataclass(init=False)
+class HFPretrainedMolecularEncoder(BaseMolecularEncoder):
31+ """Implements Hugging Face pretrained transformer models as molecular encoders.
32+
33+ This class provides an interface to use pretrained transformer models from Hugging Face
34+ as molecular encoders. It handles tokenization, encoding, and pooling of molecular representations.
35+
36+ Tested models include:
37+
38+ - ChemGPT series (1.2B/19M/4.7M): GPT-Neo based models pretrained on PubChem10M dataset with SELFIES strings.
39+ Output dimension: 2048.
40+ repo_id: "ncfrey/ChemGPT-1.2B" (https://huggingface.co/ncfrey/ChemGPT-1.2B)
41+ repo_id: "ncfrey/ChemGPT-19M" (https://huggingface.co/ncfrey/ChemGPT-19M)
42+ repo_id: "ncfrey/ChemGPT-4.7M" (https://huggingface.co/ncfrey/ChemGPT-4.7M)
43+
44+ - GPT2-ZINC-87M: GPT-2 based model (87M parameters) pretrained on ZINC dataset with ~480M SMILES strings.
45+ Output dimension: 768.
46+ repo_id: "entropy/gpt2_zinc_87m" (https://huggingface.co/entropy/gpt2_zinc_87m)
47+
48+ - RoBERTa-ZINC-480M: RoBERTa based model (102M parameters) pretrained on ZINC dataset with ~480M SMILES strings.
49+ Output dimension: 768.
50+ repo_id: "entropy/roberta_zinc_480m" (https://huggingface.co/entropy/roberta_zinc_480m)
51+
52+ - ChemBERTa series: Available in multiple sizes (77M/10M/5M) and training objectives (MTR/MLM).
53+ Output dimension: 384.
54+ repo_id: "DeepChem/ChemBERTa-77M-MTR" (https://huggingface.co/DeepChem/ChemBERTa-77M-MTR)
55+ repo_id: "DeepChem/ChemBERTa-77M-MLM" (https://huggingface.co/DeepChem/ChemBERTa-77M-MLM)
56+ repo_id: "DeepChem/ChemBERTa-10M-MTR" (https://huggingface.co/DeepChem/ChemBERTa-10M-MTR)
57+ repo_id: "DeepChem/ChemBERTa-10M-MLM" (https://huggingface.co/DeepChem/ChemBERTa-10M-MLM)
58+ repo_id: "DeepChem/ChemBERTa-5M-MLM" (https://huggingface.co/DeepChem/ChemBERTa-5M-MLM)
59+ repo_id: "DeepChem/ChemBERTa-5M-MTR" (https://huggingface.co/DeepChem/ChemBERTa-5M-MTR)
60+
61+ - UniKi/bert-base-smiles: UniKi's BERT model pretrained on SMILES strings.
62+ Output dimension: 768.
63+ repo_id: "unikei/bert-base-smiles" (https://huggingface.co/unikei/bert-base-smiles)
64+
65+ - ChemBERTa-zinc-base-v1: RoBERTa model pretrained on ZINC dataset with ~100k SMILES strings.
66+ Output dimension: 384.
67+ repo_id: "seyonec/ChemBERTa-zinc-base-v1" (https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1)
68+
69+ Parameters
70+ ----------
71+ repo_id : str
72+ The Hugging Face repository ID of the pretrained model.
73+ max_length : int, default=128
74+ Maximum sequence length for tokenization. Longer sequences will be truncated.
75+ batch_size : int, default=128
76+ Batch size used when encoding multiple molecules.
77+ add_bos_eos : Optional[bool], default=None
78+ Whether to add beginning/end of sequence tokens. If None, determined automatically based on model type.
79+ model_name : str, default="PretrainedMolecularEncoder"
80+ Name identifier for the model instance.
81+ verbose : bool, default=False
82+ Whether to display progress information during encoding.
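+
+    Examples
+    --------
+    A minimal usage sketch (illustrative; any of the tested repo_ids listed above can
+    be substituted):
+
+    >>> encoder = HFPretrainedMolecularEncoder(repo_id="DeepChem/ChemBERTa-77M-MTR").fit()
+    >>> embeddings = encoder.encode(["CCO", "c1ccccc1"], return_type="np")
+
+    ``embeddings`` holds one pooled vector per input molecule.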
83+ """
 
+    repo_id: str
 
-@dataclass
-class PretrainedMolecularEncoder(BaseMolecularEncoder):
-    """This encoder uses a pretrained transformer model from HuggingFace."""
-    # Task-related parameters
-    # repo_id: str = "huggingface/PretrainedMolecularEncoder"
+    # Default arguments
+    max_length: int = 128
+    batch_size: int = 128
+    add_bos_eos: Optional[bool] = None
     model_name: str = "PretrainedMolecularEncoder"
+    verbose: bool = False
+
+    def __init__(self, repo_id: str, max_length: int = 128, batch_size: int = 128, add_bos_eos: Optional[bool] = None,
+                 model_name: str = "PretrainedMolecularEncoder", verbose: bool = False, **kwargs):
+        self.repo_id = repo_id
+        self.max_length = max_length
+        self.batch_size = batch_size
+        self.add_bos_eos = add_bos_eos
+        self.model_name = model_name
+        self.verbose = verbose
+        super().__init__(**kwargs)
 
     def __post_init__(self):
-        """Initialize the model after dataclass initialization."""
         super().__post_init__()
         self._require_transformers()
-        self.is_fitted_ = True
         self.fitting_epoch = -1
         self.fitting_loss = -1
 
+        if self.repo_id not in known_repos:
+            warnings.warn(f"Unknown repo_id: {self.repo_id}. The class will try to load the model from HuggingFace, but it might fail.")
+
     @staticmethod
     def _get_param_names() -> List[str]:
         """Get parameter names for the estimator.
@@ -33,40 +119,69 @@ def _get_param_names() -> List[str]:
         List[str]
             List of parameter names that can be used for model configuration.
         """
-        return []
+        return ["repo_id", "max_length", "model_name", "add_bos_eos"]
 
-    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
-        params = ["model_name"]
-        if checkpoint is not None:
-            if "hyperparameters" not in checkpoint:
-                raise ValueError("Checkpoint missing 'hyperparameters' key")
-            return {k: checkpoint["hyperparameters"][k] for k in params}
-        return {k: getattr(self, k) for k in params}
+    def _get_model_params(self) -> Dict[str, Any]:
+        raise NotImplementedError("PretrainedMolecularEncoder does not support model parameters.")
+
+    def _setup_optimizers(self) -> None:
+        raise NotImplementedError("PretrainedMolecularEncoder does not support training.")
 
-    def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
+    def _train_epoch(self) -> None:
         raise NotImplementedError("PretrainedMolecularEncoder does not support training.")
 
-    def save_to_local(self, path: str) -> None:
+    def save_to_local(self) -> None:
         raise NotImplementedError("PretrainedMolecularEncoder does not support saving to local.")
 
-    def load_from_local(self, path: str) -> None:
+    def load_from_local(self) -> None:
         raise NotImplementedError("PretrainedMolecularEncoder does not support loading from local.")
 
     def save_to_hf(self) -> None:
         raise NotImplementedError("PretrainedMolecularEncoder does not support saving to huggingface.")
 
-    def load_from_hf(self, repo_id: str) -> None:
-        # TODO: Implement this
-        raise NotImplementedError("Implements this.")
+    def load_from_hf(self) -> None:
+        self.fit()
 
-    def load(self, repo_id: str) -> None:
-        self.load_from_hf(repo_id)
+    def load(self) -> None:
+        self.fit()
+
+    def fit(self) -> "HFPretrainedMolecularEncoder":
+        """Load the pretrained model and tokenizer from Hugging Face."""
+        assert self.repo_id is not None, "repo_id is not set"
+        self._require_transformers()
+        import transformers
+
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.repo_id, max_length=self.max_length)
+        self.model = transformers.AutoModel.from_pretrained(self.repo_id)
+        self.model.to(self.device)
+
+        model_config = self.model.config
+        model_type = model_config.model_type
 
-    def fit(self, repo_id: str) -> "PretrainedMolecularEncoder":
-        self.load_from_hf(repo_id)
+        if self.add_bos_eos is None:
+            self.add_bos_eos = self.repo_id in known_add_bos_eos_list
+
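+        # Some tokenizers (e.g. GPT-2 style) ship without a pad token; reuse the EOS
+        # token or register a new [PAD] token so that batched padding works.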
+        if self.tokenizer.pad_token is None:
+            if self.tokenizer.eos_token is not None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            else:
+                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+                self.model.resize_token_embeddings(len(self.tokenizer))
+
+        if self.add_bos_eos:
+            if self.tokenizer.bos_token is None:
+                self.tokenizer.add_special_tokens({'bos_token': '[BOS]'})
+            if self.tokenizer.eos_token is None:
+                self.tokenizer.add_special_tokens({'eos_token': '[EOS]'})
+            self.model.resize_token_embeddings(len(self.tokenizer))
+
+            warnings.warn("BOS and EOS tokens were not found in the tokenizer; they have been added because add_bos_eos is set to True.")
+
+        self.collator = transformers.DataCollatorWithPadding(self.tokenizer, padding=True, return_tensors='pt')
+        self.is_fitted_ = True
         return self
 
-    def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Union[np.ndarray, torch.Tensor]:
+    def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt", add_bos_eos: Optional[bool] = None) -> Union[np.ndarray, torch.Tensor]:
         """Encode molecules into vector representations.
 
         Parameters
@@ -75,22 +190,62 @@ def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Union
             List of SMILES strings
         return_type : Literal["np", "pt"], default="pt"
             Return type of the representations
+        add_bos_eos : Optional[bool], default=None
+            Whether to add BOS and EOS tokens. If None, will be determined based on model type.
 
         Returns
         -------
         representations : ndarray or torch.Tensor
             Molecular representations
         """
+        self._require_transformers()
         self._check_is_fitted()
-        X, _ = self._validate_inputs(X, return_rdkit_mol=True)
-        raise NotImplementedError("Implements this.")
-
-        # Placeholder for transformer-based encodings
-        # Replace with actual encoding logic when integrating the transformer model
-        encodings = [X]  # dummy list to allow concat
-
-        encodings = torch.cat(encodings, dim=0)
-        return encodings if return_type == "pt" else encodings.numpy()
+        X, _ = self._validate_inputs(X, return_rdkit_mol=False)
+        if add_bos_eos is None:
+            # Fall back to the value resolved in fit() from the model type
+            add_bos_eos = self.add_bos_eos
+
+        # Process in batches
+        all_embeddings = []
+        iterator = tqdm(range(0, len(X), self.batch_size), desc="Encoding molecules",
+                        total=(len(X) + self.batch_size - 1) // self.batch_size, disable=not self.verbose)
+        for i in iterator:
+            batch_X = X[i:i + self.batch_size]
+
+            if add_bos_eos:
+                # For decoder-style models (e.g. GPT-2), manually add BOS and EOS tokens
+                processed_batch = [self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in batch_X]
+                inputs = self.collator(self.tokenizer(processed_batch))
+            else:
+                inputs = self.collator(self.tokenizer(batch_X))
+
+            # Move inputs to the same device as the model
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            # Get model outputs
+            with torch.no_grad():
+                outputs = self.model(**inputs, output_hidden_states=True)
+
+            # Extract embeddings based on the model's output type
+            if hasattr(outputs, 'hidden_states'):
+                # For models that return a named tuple with hidden states
+                full_embeddings = outputs.hidden_states[-1]
+            elif isinstance(outputs, tuple) and len(outputs) > 1:
+                # For models that return a tuple with hidden states
+                full_embeddings = outputs[-1][-1]
+            else:
+                # For models that return last_hidden_state directly
+                full_embeddings = outputs.last_hidden_state
+
+            # Masked mean pooling: average token embeddings, ignoring padding positions
+            mask = inputs['attention_mask']
+            batch_embeddings = ((full_embeddings * mask.unsqueeze(-1)).sum(1) /
+                                mask.sum(-1).unsqueeze(-1))
+
+            all_embeddings.append(batch_embeddings)
+
+        # Concatenate all batch embeddings
+        embeddings = torch.cat(all_embeddings, dim=0)
+
+        return embeddings if return_type == "pt" else embeddings.cpu().numpy()
 
     @staticmethod
     def _require_transformers():