-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathbase.py
More file actions
49 lines (36 loc) · 1.15 KB
/
base.py
File metadata and controls
49 lines (36 loc) · 1.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar
import numpy as np
import numpy.typing as npt
from sklearn.base import BaseEstimator
from autointent import Embedder, Ranker, VectorIndex
from autointent._wrappers import BaseTorchModuleWithVocab
from autointent.schemas import TagsList
if TYPE_CHECKING:
from pathlib import Path
# Union of plain scalar/container value types. It is declared as an explicit
# TypeAlias for consistency with ``ModuleAttributes`` below, so type checkers
# do not have to infer alias-ness. The bare ``list`` has no element type,
# which mypy reports as ``type-arg``; hence the ignore.
ModuleSimpleAttributes: TypeAlias = None | str | int | float | bool | list  # type: ignore[type-arg]

# Full union of attribute value types: the simple values above plus tag lists,
# numpy floating-point arrays, and the project wrapper/estimator classes
# imported at the top of this file.
# NOTE(review): presumably this is the set of attribute kinds a module can
# expose/persist (going by the name); confirm against usages elsewhere.
ModuleAttributes: TypeAlias = (
    ModuleSimpleAttributes
    | TagsList
    | npt.NDArray[np.floating]
    | Embedder
    | VectorIndex
    | BaseEstimator
    | Ranker
    | BaseTorchModuleWithVocab
)
# Type variable that parameterises BaseObjectDumper (``Generic[T]``) over the
# type of object it handles.
T = TypeVar("T")
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class BaseObjectDumper(ABC, Generic[T]):
    """Abstract interface pairing a save routine with a load routine for objects of type ``T``.

    Concrete subclasses must implement ``dump``, ``load`` and
    ``check_isinstance``; the class cannot be instantiated until all three
    are overridden.
    """

    # Declared (annotation only, no value) on the base class.
    # NOTE(review): presumably the on-disk file/directory name used by each
    # concrete dumper; subclasses appear expected to assign it — confirm.
    dir_or_file_name: str

    @staticmethod
    @abstractmethod
    def dump(obj: T, path: Path, exists_ok: bool) -> None:
        """Persist ``obj`` at ``path``.

        The meaning of ``exists_ok`` is left to implementations; presumably it
        governs whether an already-existing target is acceptable — confirm in
        subclasses.
        """

    @staticmethod
    @abstractmethod
    def load(path: Path, **kwargs: Any) -> T:  # noqa: ANN401
        """Restore and return an object of type ``T`` from ``path``.

        Extra keyword arguments are implementation-specific.
        """

    @classmethod
    @abstractmethod
    def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
        """Return whether ``obj`` is handled by this dumper.

        NOTE(review): judging by the name, implementations likely perform an
        ``isinstance`` check — confirm in subclasses.
        """