7 changes: 6 additions & 1 deletion .vscode/settings.json
@@ -8,5 +8,10 @@
"*.yaml",
"!*/.github/*/*.yaml"
]
}
},
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
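Note: the added keys switch VS Code's Python test integration over to pytest, with test discovery rooted at the workspace folder ("."), and explicitly disable the unittest runner so only one discovery mechanism is active.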
3 changes: 3 additions & 0 deletions src/autointent/_dump_tools/main.py
@@ -4,6 +4,7 @@

import numpy as np
import numpy.typing as npt
import torch

from autointent.configs import CrossEncoderConfig, EmbedderConfig
from autointent.context.optimization_info import Artifact
@@ -108,6 +109,8 @@ def dump(
simple_attrs[key] = val
elif isinstance(val, np.ndarray):
arrays[key] = val
elif isinstance(val, torch.Tensor):
arrays[key] = val.cpu().numpy()
else:
# Use the appropriate dumper for complex objects
Dumper._dump_single_object(key, val, path, exists_ok, raise_errors)
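The new branch lets GPU-resident tensors serialize through the same path as plain arrays. A minimal sketch of the dispatch, mirroring the Dumper.dump logic above (the attribute names are illustrative):

import numpy as np
import torch

attrs = {"temperature": 0.7, "logits": torch.randn(2, 3)}  # illustrative attributes
simple_attrs: dict = {}
arrays: dict = {}
for key, val in attrs.items():
    if isinstance(val, np.ndarray):
        arrays[key] = val
    elif isinstance(val, torch.Tensor):
        # New: move to CPU so CUDA/MPS tensors convert cleanly, then store as
        # numpy so they persist alongside plain ndarrays. Tensors that track
        # gradients would additionally need .detach() before .numpy().
        arrays[key] = val.cpu().numpy()
    else:
        simple_attrs[key] = val  # scalars etc.; complex objects go to dedicated dumpers in the real code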
12 changes: 6 additions & 6 deletions src/autointent/_dump_tools/unit_dumpers.py
@@ -21,7 +21,7 @@
)

from autointent import Embedder, Ranker, VectorIndex
from autointent._wrappers import BaseTorchModuleWithVocab
from autointent._wrappers import BaseTorchModule
from autointent.schemas import TagsList

from .base import BaseObjectDumper, ModuleSimpleAttributes
@@ -276,11 +276,11 @@ def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
return isinstance(obj, PreTrainedTokenizer | PreTrainedTokenizerFast)


class TorchModelDumper(BaseObjectDumper[BaseTorchModuleWithVocab]):
class TorchModelDumper(BaseObjectDumper[BaseTorchModule]):
dir_or_file_name = "torch_models"

@staticmethod
def dump(obj: BaseTorchModuleWithVocab, path: Path, exists_ok: bool) -> None:
def dump(obj: BaseTorchModule, path: Path, exists_ok: bool) -> None:
path.mkdir(parents=True, exist_ok=exists_ok)
class_info = {
"module": obj.__class__.__module__,
@@ -291,16 +291,16 @@ def dump(obj: BaseTorchModuleWithVocab, path: Path, exists_ok: bool) -> None:
obj.dump(path)

@staticmethod
def load(path: Path, **kwargs: Any) -> BaseTorchModuleWithVocab: # noqa: ANN401, ARG004
def load(path: Path, **kwargs: Any) -> BaseTorchModule: # noqa: ANN401, ARG004
with (path / "class_info.json").open("r") as f:
class_info = json.load(f)
module = importlib.import_module(class_info["module"])
model_class: BaseTorchModuleWithVocab = getattr(module, class_info["name"])
model_class: BaseTorchModule = getattr(module, class_info["name"])
return model_class.load(path)

@classmethod
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
return isinstance(obj, BaseTorchModuleWithVocab)
return isinstance(obj, BaseTorchModule)


class CatBoostDumper(BaseObjectDumper[CatBoostClassifier]):
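Grounded in the signatures above, the round-trip now accepts any BaseTorchModule subclass, not only vocab-based modules. A hypothetical call site (the path and the my_module instance are assumptions):

from pathlib import Path

path = Path("artifacts/torch_models")
TorchModelDumper.dump(my_module, path, exists_ok=True)  # writes class_info.json plus the module's own dump
restored = TorchModelDumper.load(path)  # re-imports the class via importlib, then delegates to its .load()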
3 changes: 2 additions & 1 deletion src/autointent/_wrappers/__init__.py
@@ -2,5 +2,6 @@
from .embedder import Embedder
from .vector_index import VectorIndex
from .base_torch_module import BaseTorchModuleWithVocab
from .base_torch_module import BaseTorchModule

__all__ = ["BaseTorchModuleWithVocab", "Embedder", "Ranker", "VectorIndex"]
__all__ = ["BaseTorchModule", "BaseTorchModuleWithVocab", "Embedder", "Ranker", "VectorIndex"]
93 changes: 51 additions & 42 deletions src/autointent/_wrappers/base_torch_module.py
@@ -13,10 +13,52 @@
from autointent.configs import VocabConfig


class BaseTorchModuleWithVocab(nn.Module, ABC):
class BaseTorchModule(nn.Module, ABC):
@abstractmethod
def forward(self, text: torch.Tensor) -> torch.Tensor:
"""Compute sentence embeddings for given text.

Args:
text: torch tensor of shape (B, T), token ids

Returns:
embeddings of shape (B, H)
"""

@abstractmethod
def dump(self, path: Path) -> None:
"""Dump torch module to disk.

This method encapsulates all the logic of dumping the module's weights and
the hyperparameters needed to re-initialize it from disk for inference.

Args:
path: path in file system
"""

@classmethod
@abstractmethod
def load(cls, path: Path, device: str | None = None) -> Self:
"""Load torch module from disk.

This method loads all weights and hyperparameters required for
initialization from disk and inference.

Args:
path: path in file system
device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.
"""

@property
def device(self) -> torch.device:
"""Torch device object where this module resides."""
return next(self.parameters()).device


class BaseTorchModuleWithVocab(BaseTorchModule, ABC):
def __init__(
self,
embed_dim: int,
embed_dim: int | None = None,
vocab_config: VocabConfig | None = None,
) -> None:
super().__init__()
@@ -34,6 +76,9 @@ def __init__(

def set_vocab(self, vocab: dict[str, Any]) -> None:
"""Save vocabulary into module's attributes and initialize embeddings matrix."""
if self.embed_dim is None:
msg = "embed_dim must be set to initialize embeddings"
raise ValueError(msg)
self.vocab_config.vocab = vocab
self.embedding = nn.Embedding(
num_embeddings=len(self.vocab_config.vocab),
@@ -43,6 +88,10 @@ def set_vocab(self, vocab: dict[str, Any]) -> None:

def build_vocab(self, utterances: list[str]) -> None:
"""Build vocabulary from training utterances."""
if self.embed_dim is None:
msg = "embed_dim must be set to initialize embeddings"
raise ValueError(msg)

if self.vocab_config.vocab is not None:
msg = "Vocab is already built."
raise RuntimeError(msg)
@@ -80,43 +129,3 @@ def text_to_indices(self, utterances: list[str]) -> list[list[int]]:
seq = seq + [self.vocab_config.padding_idx] * (self.vocab_config.max_seq_length - len(seq))
sequences.append(seq)
return sequences

@abstractmethod
def forward(self, text: torch.Tensor) -> torch.Tensor:
"""Compute sentence embeddings for given text.

Args:
text: torch tensor of shape (B, T), token ids

Returns:
embeddings of shape (B, H)
"""

@abstractmethod
def dump(self, path: Path) -> None:
"""Dump torch module to disk.

This method encapsulates all the logic of dumping the module's weights and
the hyperparameters needed to re-initialize it from disk for inference.

Args:
path: path in file system
"""

@classmethod
@abstractmethod
def load(cls, path: Path, device: str | None = None) -> Self:
"""Load torch module from disk.

This method loads all weights and hyperparameters required for
initialization from disk and inference.

Args:
path: path in file system
device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.
"""

@property
def device(self) -> torch.device:
"""Torch device object where this module resides."""
return next(self.parameters()).device
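With the contract split out, a module no longer needs a vocabulary to be dump/load-compatible. A minimal sketch of a direct BaseTorchModule subclass (the class name, checkpoint layout, and Self import location are illustrative assumptions, not part of this PR):

from pathlib import Path

import torch
from torch import nn
from typing_extensions import Self  # typing.Self on Python >= 3.11

from autointent._wrappers import BaseTorchModule


class MeanPoolEncoder(BaseTorchModule):
    """Illustrative encoder: embeds token ids and mean-pools over the sequence."""

    def __init__(self, vocab_size: int = 1000, hidden: int = 64) -> None:
        super().__init__()
        self.hidden = hidden
        self.embedding = nn.Embedding(vocab_size, hidden)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        # (B, T) token ids -> (B, H) sentence embeddings
        return self.embedding(text).mean(dim=1)

    def dump(self, path: Path) -> None:
        path.mkdir(parents=True, exist_ok=True)
        torch.save(
            {"state_dict": self.state_dict(), "vocab_size": self.embedding.num_embeddings, "hidden": self.hidden},
            path / "model.pt",
        )

    @classmethod
    def load(cls, path: Path, device: str | None = None) -> Self:
        checkpoint = torch.load(path / "model.pt", map_location=device or "cpu")
        model = cls(vocab_size=checkpoint["vocab_size"], hidden=checkpoint["hidden"])
        model.load_state_dict(checkpoint["state_dict"])
        return model if device is None else model.to(device)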
55 changes: 44 additions & 11 deletions src/autointent/_wrappers/embedder.py
@@ -10,6 +10,7 @@
import tempfile
from functools import lru_cache
from pathlib import Path
from typing import Literal, cast, overload
from uuid import uuid4

import huggingface_hub
@@ -235,15 +236,28 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -

return cls(EmbedderConfig(**kwargs))

def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
@overload
def embed(
self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[True]
) -> torch.Tensor: ...

@overload
def embed(
self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[False] = False
) -> npt.NDArray[np.float32]: ...

def embed(
self, utterances: list[str], task_type: TaskTypeEnum | None = None, return_tensors: bool = False
) -> npt.NDArray[np.float32] | torch.Tensor:
"""Calculate embeddings for a list of utterances.

Args:
utterances: List of input texts to calculate embeddings for.
task_type: Type of task for which embeddings are calculated.
return_tensors: If True, return a PyTorch tensor; otherwise, return a numpy array.

Returns:
A numpy array of embeddings.
A numpy array or PyTorch tensor of embeddings.
"""
if len(utterances) == 0:
msg = "Empty input"
@@ -263,7 +277,10 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
embeddings_path = _get_embeddings_path(hasher.hexdigest())
if embeddings_path.exists():
logger.debug("loading embeddings from %s", str(embeddings_path))
return np.load(embeddings_path) # type: ignore[no-any-return]
embeddings_np = cast(npt.NDArray[np.float32], np.load(embeddings_path))
if return_tensors:
return torch.from_numpy(embeddings_np).to(self.config.device)
return embeddings_np

self._model = self._load_model()

@@ -279,17 +296,33 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
if self.config.tokenizer_config.max_length is not None:
self._model.max_seq_length = self.config.tokenizer_config.max_length

embeddings = self._model.encode(
utterances,
convert_to_numpy=True,
batch_size=self.config.batch_size,
normalize_embeddings=True,
prompt=prompt,
)
embeddings: npt.NDArray[np.float32] | torch.Tensor
if return_tensors:
embeddings = self._model.encode(
utterances,
convert_to_tensor=True,
batch_size=self.config.batch_size,
normalize_embeddings=True,
prompt=prompt,
)
else:
embeddings = cast(
npt.NDArray[np.float32],
self._model.encode(
utterances,
convert_to_numpy=True,
batch_size=self.config.batch_size,
normalize_embeddings=True,
prompt=prompt,
),
)

if self.config.use_cache:
embeddings_path.parent.mkdir(parents=True, exist_ok=True)
np.save(embeddings_path, embeddings)
if isinstance(embeddings, torch.Tensor):
np.save(embeddings_path, embeddings.cpu().numpy())
else:
np.save(embeddings_path, embeddings)

return embeddings

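A hypothetical call site for the new flag (the EmbedderConfig field name is an assumption); the overloads above let a type checker narrow the return type from the literal value of return_tensors:

embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))  # field name assumed
vectors_np = embedder.embed(["hello", "world"])                       # npt.NDArray[np.float32]
vectors_pt = embedder.embed(["hello", "world"], return_tensors=True)  # torch.Tensor on config.device

Note that cache hits are stored as numpy and converted back with torch.from_numpy(...).to(self.config.device), so tensor callers receive a device-resident copy whether or not the embeddings came from cache.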
2 changes: 2 additions & 0 deletions src/autointent/modules/__init__.py
@@ -20,6 +20,7 @@
CNNScorer,
CrossEncoderDescriptionScorer,
DNNCScorer,
GCNScorer,
KNNScorer,
LinearScorer,
LLMDescriptionScorer,
@@ -47,6 +48,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
[
CatBoostScorer,
DNNCScorer,
GCNScorer,
KNNScorer,
LinearScorer,
BiEncoderDescriptionScorer,
2 changes: 2 additions & 0 deletions src/autointent/modules/scoring/__init__.py
@@ -2,6 +2,7 @@
from ._catboost import CatBoostScorer
from ._description import BiEncoderDescriptionScorer, CrossEncoderDescriptionScorer, LLMDescriptionScorer
from ._dnnc import DNNCScorer
from ._gcn import GCNScorer
from ._knn import KNNScorer, RerankScorer
from ._linear import LinearScorer
from ._lora import BERTLoRAScorer
@@ -18,6 +19,7 @@
"CatBoostScorer",
"CrossEncoderDescriptionScorer",
"DNNCScorer",
"GCNScorer",
"KNNScorer",
"LLMDescriptionScorer",
"LinearScorer",
3 changes: 3 additions & 0 deletions src/autointent/modules/scoring/_gcn/__init__.py
@@ -0,0 +1,3 @@
from .gcn_scorer import GCNScorer

__all__ = ["GCNScorer"]