Skip to content

Commit 4dfa9ec

Browse files
committed
fix typing
1 parent 92e46d1 commit 4dfa9ec

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

src/autointent/configs/_embedder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class TaskTypeEnum(Enum):
2020
sts = "sts"
2121

2222

23-
class BaseEmbedderConfig(ABC, BaseModel, extra="forbid"):
23+
class BaseEmbedderConfig(BaseModel, extra="forbid"):
2424
"""Base class for embedder configurations."""
2525

2626
default_prompt: str | None = Field(
@@ -124,15 +124,15 @@ class HashingVectorizerEmbeddingConfig(BaseEmbedderConfig):
124124

125125

126126
EmbedderConfig: TypeAlias = (
127-
SentenceTransformerEmbeddingConfig | OpenaiEmbeddingConfig | HashingVectorizerEmbeddingConfig
127+
SentenceTransformerEmbeddingConfig | OpenaiEmbeddingConfig | HashingVectorizerEmbeddingConfig | BaseEmbedderConfig
128128
)
129129

130130

131-
def get_default_embedder_config(**kwargs: Any) -> SentenceTransformerEmbeddingConfig: # noqa: ANN401
131+
def get_default_embedder_config(**kwargs: Any) -> EmbedderConfig: # noqa: ANN401
132132
return SentenceTransformerEmbeddingConfig.model_validate(kwargs)
133133

134134

135-
def initialize_embedder_config(values: dict[str, Any] | str | BaseEmbedderConfig | None) -> BaseEmbedderConfig:
135+
def initialize_embedder_config(values: dict[str, Any] | str | BaseEmbedderConfig | None) -> EmbedderConfig:
136136
if values is None:
137137
return get_default_embedder_config()
138138
if isinstance(values, BaseEmbedderConfig):

src/autointent/context/_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from .data_handler import DataHandler
2323
from .optimization_info import OptimizationInfo
24+
from ..configs._embedder import BaseEmbedderConfig
2425

2526
if TYPE_CHECKING:
2627
from pathlib import Path
@@ -178,7 +179,7 @@ def resolve_embedder(self) -> EmbedderConfig:
178179
except ValueError:
179180
if hasattr(self, "embedder_config"):
180181
return self.embedder_config
181-
return EmbedderConfig()
182+
return BaseEmbedderConfig()
182183

183184
def resolve_ranker(self) -> CrossEncoderConfig:
184185
"""Resolve the cross-encoder configuration.

0 commit comments

Comments
 (0)