Skip to content

Commit a08a0a3

Browse files
committed
add get_default_hfmodel_config
1 parent 3059bda commit a08a0a3

4 files changed

Lines changed: 12 additions & 5 deletions

File tree

src/autointent/_pipeline/_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
LoggingConfig,
2323
VectorIndexConfig,
2424
get_default_embedder_config,
25+
get_default_hfmodel_config,
2526
get_default_vector_index_config,
2627
)
2728
from autointent.custom_types import NodeType
@@ -64,7 +65,7 @@ def __init__(
6465
self.embedder_config = get_default_embedder_config()
6566
self.cross_encoder_config = CrossEncoderConfig()
6667
self.data_config = DataConfig()
67-
self.transformer_config = HFModelConfig()
68+
self.transformer_config = get_default_hfmodel_config()
6869
self.hpo_config = HPOConfig()
6970
self.vector_index_config = get_default_vector_index_config()
7071
elif not isinstance(nodes[0], InferenceNode):

src/autointent/configs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
EmbedderFineTuningConfig,
1818
HFModelConfig,
1919
TokenizerConfig,
20+
get_default_hfmodel_config,
2021
)
2122
from ._vector_index import FaissConfig, OpenSearchConfig, VectorIndexConfig, get_default_vector_index_config
2223

@@ -40,6 +41,7 @@
4041
"VectorIndexConfig",
4142
"VocabConfig",
4243
"get_default_embedder_config",
44+
"get_default_hfmodel_config",
4345
"get_default_vector_index_config",
4446
"initialize_embedder_config",
4547
]

src/autointent/configs/_transformers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class HFModelConfig(BaseModel):
5353
fp16: bool = Field(False, description="Whether to use mixed precision training (not all devices support this).")
5454
tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)
5555
trust_remote_code: bool = Field(False, description="Whether to trust the remote code when loading the model.")
56+
revision: str | None = Field(None, description="Revision from HF repo")
5657

5758
@classmethod
5859
def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
@@ -75,6 +76,10 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) ->
7576
return cls(**values)
7677

7778

79+
def get_default_hfmodel_config() -> HFModelConfig:
80+
return HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16")
81+
82+
7883
class CrossEncoderConfig(HFModelConfig):
7984
model_name: str = Field("cross-encoder/ms-marco-MiniLM-L6-v2", description="Name of the hugging face model.")
8085
train_head: bool = Field(

src/autointent/context/_context.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
HPOConfig,
1717
LoggingConfig,
1818
VectorIndexConfig,
19+
get_default_hfmodel_config,
1920
)
2021

2122
from .data_handler import DataHandler
@@ -25,9 +26,7 @@
2526
from pathlib import Path
2627

2728
from autointent import Dataset
28-
from autointent.configs import (
29-
DataConfig,
30-
)
29+
from autointent.configs import DataConfig
3130

3231

3332
class Context:
@@ -202,4 +201,4 @@ def resolve_transformer(self) -> HFModelConfig:
202201
"""
203202
if hasattr(self, "transformer_config"):
204203
return self.transformer_config
205-
return HFModelConfig()
204+
return get_default_hfmodel_config()

0 commit comments

Comments
 (0)