Commit c754736

add hf pretrained encoders
1 parent de781a6 commit c754736

2 files changed

Lines changed: 15 additions & 2 deletions

docs/source/api/encoder.rst

Lines changed: 10 additions & 1 deletion
@@ -74,4 +74,13 @@ Supervised Pretraining for Molecules
     :members: fit, encode
     :exclude-members: fitting_epoch, fitting_loss, model_name, model_class
     :undoc-members:
-    :show-inheritance:
+    :show-inheritance:
+
+Pretrained Molecular Encoders
+----------------------------------------------
+
+.. rubric:: Sequence-based Pretrained Transformers from Hugging Face
+
+.. autoclass:: torch_molecule.encoder.pretrained.modeling_pretrained.HFPretrainedMolecularEncoder
+    :members: fit, encode, load, load_from_hf
+    :exclude-members: fitting_epoch, fitting_loss, model_name, model_class

torch_molecule/encoder/pretrained/modeling_pretrained.py

Lines changed: 5 additions & 1 deletion
@@ -31,7 +31,7 @@ class HFPretrainedMolecularEncoder(BaseMolecularEncoder):
     """Implements Hugging Face pretrained transformer models as molecular encoders.
 
     This class provides an interface to use pretrained transformer models from Hugging Face
-    as molecular encoders. It handles tokenization, encoding, and pooling of molecular representations.
+    as molecular encoders. It handles tokenization and encoding of molecular representations.
 
     Tested models include:
 
@@ -66,6 +66,8 @@ class HFPretrainedMolecularEncoder(BaseMolecularEncoder):
     Output dimension: 384.
     repo_id: "seyonec/ChemBERTa-zinc-base-v1" (https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1)
 
+    Other models accessible through the transformers library have not been explicitly tested but may still be compatible with this interface.
+
     Parameters
     ----------
     repo_id : str
@@ -140,9 +142,11 @@ def save_to_hf(self) -> None:
         raise NotImplementedError("PretrainedMolecularEncoder does not support saving to huggingface.")
 
     def load_from_hf(self) -> None:
+        """The same as fit()"""
         self.fit()
 
     def load(self) -> None:
+        """The same as fit()"""
         self.fit()
 
     def fit(self) -> "HFPretrainedMolecularEncoder":
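
The last hunk documents `load()` and `load_from_hf()` as thin aliases for `fit()`. The sketch below illustrates that delegation pattern with a hypothetical stand-in class. `PretrainedEncoderSketch`, `is_fitted`, and `fit_calls` are assumptions made for illustration and are not part of torch_molecule. The body of the real `fit()` is not shown in this diff, so here it only records that it was called.

```python
class PretrainedEncoderSketch:
    """Hypothetical stand-in; NOT the torch_molecule class."""

    def __init__(self, repo_id: str) -> None:
        self.repo_id = repo_id
        self.is_fitted = False  # illustrative flag, assumed for this sketch
        self.fit_calls = 0      # illustrative counter, assumed for this sketch

    def fit(self) -> "PretrainedEncoderSketch":
        # The real fit() presumably loads the pretrained model for repo_id;
        # this sketch only records the call.
        self.is_fitted = True
        self.fit_calls += 1
        return self

    def load(self) -> None:
        """The same as fit()"""
        self.fit()

    def load_from_hf(self) -> None:
        """The same as fit()"""
        self.fit()


encoder = PretrainedEncoderSketch("seyonec/ChemBERTa-zinc-base-v1")
encoder.load_from_hf()  # same effect as encoder.fit(), but returns None
```

Going by the signatures in the diff, `fit()` returns the encoder while both aliases return `None`, so only `fit()` can be used in a chained call.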
