
Commit 5751df4

polish hf pretrained
1 parent c754736 commit 5751df4

2 files changed: 17 additions & 4 deletions

docs/requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,7 @@ torch==2.2.0+cu118
 -f https://data.pyg.org/whl/torch-2.2.0+cu118.html
 torch_geometric==2.6.1
 # torch_cluster
-# torch_scatter
+torch_scatter
 
 # Other dependencies
 huggingface_hub
@@ -19,6 +19,8 @@ scikit_learn==1.4.1.post1
 scipy==1.14.1
 tqdm==4.66.2
 
+transformers
+
 optuna
 # ogb
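For context, the -f line above points pip at the prebuilt PyTorch Geometric wheel index, which is what lets the now-uncommented torch_scatter resolve against the pinned torch==2.2.0+cu118 build; transformers is added because the HF pretrained encoder polished in this commit depends on it. A minimal sanity check (a sketch, not part of the commit) that both dependencies import:

# Sketch: verify the two dependencies touched by this diff are importable.
import torch_scatter   # uncommented above; resolved via the PyG wheel index
import transformers    # newly added for the HF pretrained encoder

print(torch_scatter.__version__, transformers.__version__)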

torch_molecule/encoder/pretrained/modeling_pretrained.py

Lines changed: 14 additions & 3 deletions
@@ -37,33 +37,46 @@ class HFPretrainedMolecularEncoder(BaseMolecularEncoder):
 
     - ChemGPT series (1.2B/19M/4.7M): GPT-Neo based models pretrained on PubChem10M dataset with SELFIES strings.
       Output dimension: 2048.
+
       repo_id: "ncfrey/ChemGPT-1.2B" (https://huggingface.co/ncfrey/ChemGPT-1.2B)
+
       repo_id: "ncfrey/ChemGPT-19M" (https://huggingface.co/ncfrey/ChemGPT-19M)
+
       repo_id: "ncfrey/ChemGPT-4.7M" (https://huggingface.co/ncfrey/ChemGPT-4.7M)
 
     - GPT2-ZINC-87M: GPT-2 based model (87M parameters) pretrained on ZINC dataset with ~480M SMILES strings.
       Output dimension: 768.
+
       repo_id: "entropy/gpt2_zinc_87m" (https://huggingface.co/entropy/gpt2_zinc_87m)
 
     - RoBERTa-ZINC-480M: RoBERTa based model (102M parameters) pretrained on ZINC dataset with ~480M SMILES strings.
       Output dimension: 768.
+
       repo_id: "entropy/roberta_zinc_480m" (https://huggingface.co/entropy/roberta_zinc_480m)
 
     - ChemBERTa series: Available in multiple sizes (77M/10M/5M) and training objectives (MTR/MLM).
       Output dimension: 384.
+
       repo_id: "DeepChem/ChemBERTa-77M-MTR" (https://huggingface.co/DeepChem/ChemBERTa-77M-MTR)
+
       repo_id: "DeepChem/ChemBERTa-77M-MLM" (https://huggingface.co/DeepChem/ChemBERTa-77M-MLM)
+
       repo_id: "DeepChem/ChemBERTa-10M-MTR" (https://huggingface.co/DeepChem/ChemBERTa-10M-MTR)
+
       repo_id: "DeepChem/ChemBERTa-10M-MLM" (https://huggingface.co/DeepChem/ChemBERTa-10M-MLM)
+
       repo_id: "DeepChem/ChemBERTa-5M-MLM" (https://huggingface.co/DeepChem/ChemBERTa-5M-MLM)
+
       repo_id: "DeepChem/ChemBERTa-5M-MTR" (https://huggingface.co/DeepChem/ChemBERTa-5M-MTR)
 
     - UniKi/bert-base-smiles: UniKi's BERT model pretrained on SMILES strings.
       Output dimension: 768.
+
       repo_id: "unikei/bert-base-smiles" (https://huggingface.co/unikei/bert-base-smiles)
 
     - ChemBERTa-zinc-base-v1: RoBERTa model pretrained on ZINC dataset with ~100k SMILES strings.
       Output dimension: 384.
+
       repo_id: "seyonec/ChemBERTa-zinc-base-v1" (https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1)
 
     Other models accessible through the transformers library have not been explicitly tested but may still be compatible with this interface.
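To make the checkpoint list above concrete, here is a hypothetical usage sketch; it assumes HFPretrainedMolecularEncoder accepts the repo_id shown in the docstring and exposes the fit()/encode() methods this commit touches in the hunks below.

# Hypothetical sketch (constructor argument assumed from the docstring above).
from torch_molecule.encoder.pretrained.modeling_pretrained import (
    HFPretrainedMolecularEncoder,
)

# ChemBERTa-77M-MTR is one of the tested checkpoints; per the docstring its
# output dimension is 384.
encoder = HFPretrainedMolecularEncoder(repo_id="DeepChem/ChemBERTa-77M-MTR")
encoder.fit()  # per the diff below, fit() sets is_fitted_ and returns self

embeddings = encoder.encode(["CCO", "c1ccccc1"])  # default return_type="pt"
print(embeddings.shape)  # expected (2, 384) if the documented dimension holds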
@@ -185,7 +198,7 @@ def fit(self) -> "HFPretrainedMolecularEncoder":
         self.is_fitted_ = True
         return self
 
-    def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt", add_bos_eos: Optional[bool] = None) -> Union[np.ndarray, torch.Tensor]:
+    def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Union[np.ndarray, torch.Tensor]:
         """Encode molecules into vector representations.
 
         Parameters
@@ -194,8 +207,6 @@ def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt", add_bos_eos: Optional[bool] = None) -> Union[np.ndarray, torch.Tensor]:
             List of SMILES strings
         return_type : Literal["np", "pt"], default="pt"
             Return type of the representations
-        add_bos_eos : Optional[bool], default=None
-            Whether to add BOS and EOS tokens. If None, will be determined based on model type.
 
         Returns
         -------
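The net effect of the last two hunks is an API simplification: add_bos_eos is removed from encode, so BOS/EOS token handling is no longer a caller-side decision. A caller-side sketch of the change (same hypothetical setup as above; repo_id taken from the docstring list):

import numpy as np
import torch
from torch_molecule.encoder.pretrained.modeling_pretrained import (
    HFPretrainedMolecularEncoder,
)

encoder = HFPretrainedMolecularEncoder(repo_id="entropy/gpt2_zinc_87m")
encoder.fit()

smiles = ["CCO", "c1ccccc1"]

# Before this commit, callers could steer special-token handling themselves:
#   encoder.encode(smiles, return_type="pt", add_bos_eos=True)
# After it, that choice is internal; only the return type remains configurable:
reps_pt = encoder.encode(smiles, return_type="pt")  # torch.Tensor
reps_np = encoder.encode(smiles, return_type="np")  # numpy.ndarray
assert isinstance(reps_pt, torch.Tensor)
assert isinstance(reps_np, np.ndarray)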
