Merged
Changes from all commits
31 commits
4caced7
add train
k0lenk4 Jul 3, 2025
c5b5b2c
fixed env
k0lenk4 Jul 3, 2025
e13f171
deleted kwargs and local savings, added config
k0lenk4 Jul 28, 2025
d26eda0
added test for train method
k0lenk4 Jul 28, 2025
ee1b4a1
add EmbedderFineTuningConfig to __init__
k0lenk4 Aug 2, 2025
3052628
correct __init__ in config, remov pytest in test file
k0lenk4 Aug 2, 2025
a96c27d
correct some syntax isues
k0lenk4 Aug 2, 2025
bdc1161
move batch_size to EmbedderFineTuningConfig
k0lenk4 Aug 11, 2025
d71de34
add __init__.py to /test/embedder
k0lenk4 Aug 11, 2025
941c13a
Remove whitespace from blank line
k0lenk4 Aug 11, 2025
3ecdc60
Merge remote-tracking branch 'origin/dev' into feat/train-embeddings
k0lenk4 Aug 11, 2025
0739413
correct errors
k0lenk4 Aug 11, 2025
1e161c6
the number of epochs and train objects have been increased
k0lenk4 Aug 11, 2025
e67f1bc
made lint
k0lenk4 Aug 11, 2025
3c38ec8
add early stopping
k0lenk4 Aug 12, 2025
c743c0b
remake train args
k0lenk4 Aug 12, 2025
71bf957
make a list of callbacks
k0lenk4 Aug 12, 2025
2963a4c
inline type annotation of variable "callback"
k0lenk4 Aug 12, 2025
03e4c59
change save_strategy to "epoch"
k0lenk4 Aug 12, 2025
714f910
default value of fp16 changed to False
k0lenk4 Aug 15, 2025
cb9b2ea
pull dev
voorhs Aug 18, 2025
6aa7abc
integrate embeddings fine-tuning into Embedding modules
voorhs Aug 19, 2025
b88a810
pull dev
voorhs Aug 19, 2025
2970737
Update optimizer_config.schema.json
github-actions[bot] Aug 19, 2025
fcf1f31
clean up `freeze` throughout tests and tutorials
voorhs Aug 19, 2025
0b0c1fa
add comprehensive tests
voorhs Aug 19, 2025
19a74f4
embedder_model -> _model
voorhs Aug 19, 2025
1d73af6
fix early stopping
voorhs Aug 25, 2025
714c8c2
fix tests
voorhs Aug 25, 2025
ebf066b
clear ram bug fix
voorhs Aug 25, 2025
51a9b1a
try to fix windows cleanup issue
voorhs Aug 25, 2025
2 changes: 2 additions & 0 deletions .gitignore
@@ -182,3 +182,5 @@ vector_db*
*.db
*.sqlite
/wandb
model_output/
my.py
9 changes: 1 addition & 8 deletions docs/optimizer_config.schema.json
@@ -266,12 +266,6 @@
"description": "Whether to use embeddings caching.",
"title": "Use Cache",
"type": "boolean"
},
"freeze": {
"default": true,
"description": "Whether to freeze the model parameters.",
"title": "Freeze",
"type": "boolean"
}
},
"title": "EmbedderConfig",
@@ -578,8 +572,7 @@
"query_prompt": null,
"passage_prompt": null,
"similarity_fn_name": "cosine",
"use_cache": true,
"freeze": true
"use_cache": true
}
},
"cross_encoder_config": {
114 changes: 100 additions & 14 deletions src/autointent/_wrappers/embedder.py
@@ -7,19 +7,27 @@
import json
import logging
import shutil
import tempfile
from functools import lru_cache
from pathlib import Path
from uuid import uuid4

import huggingface_hub
import numpy as np
import numpy.typing as npt
import torch
from appdirs import user_cache_dir
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import BatchAllTripletLoss
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.training_args import BatchSamplers
from sklearn.model_selection import train_test_split
from transformers import EarlyStoppingCallback, TrainerCallback

from autointent._hash import Hasher
from autointent.configs import EmbedderConfig, TaskTypeEnum
from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig, TaskTypeEnum
from autointent.custom_types import ListOfLabels

logger = logging.getLogger(__name__)

@@ -66,15 +74,18 @@ class Embedder:
"""

_metadata_dict_name: str = "metadata.json"
_weights_dir_name: str = "sentence_transformer"
_dump_dir: Path | None = None
_trained: bool = False
_model: SentenceTransformer

def __init__(self, embedder_config: EmbedderConfig) -> None:
"""Initialize the Embedder.

Args:
embedder_config: Config of embedder.
"""
self.config = embedder_config
self.config = embedder_config.model_copy(deep=True)

def _get_hash(self) -> int:
"""Compute a hash value for the Embedder.
@@ -83,19 +94,19 @@ def _get_hash(self) -> int:
The hash value of the Embedder.
"""
hasher = Hasher()
if self.config.freeze:
if not Path(self.config.model_name).exists():
commit_hash = _get_latest_commit_hash(self.config.model_name)
hasher.update(commit_hash)
else:
self.embedding_model = self._load_model()
for parameter in self.embedding_model.parameters():
self._model = self._load_model()
for parameter in self._model.parameters():
hasher.update(parameter.detach().cpu().numpy())
hasher.update(self.config.tokenizer_config.max_length)
return hasher.intdigest()

def _load_model(self) -> SentenceTransformer:
"""Load sentence transformers model to device."""
if not hasattr(self, "embedding_model"):
if not hasattr(self, "_model"):
res = SentenceTransformer(
self.config.model_name,
device=self.config.device,
@@ -104,15 +115,80 @@ def _load_model(self) -> SentenceTransformer:
trust_remote_code=self.config.trust_remote_code,
)
else:
res = self.embedding_model
res = self._model
return res

def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
"""Train the embedding model."""
if len(utterances) != len(labels):
msg = f"Utterances and labels lists lengths mismatch: {len(utterances)=} != {len(labels)=}"
raise ValueError(msg)

if len(labels) == 0:
msg = "Empty data"
raise ValueError(msg)

# TODO support multi-label data
if isinstance(labels[0], list):
msg = "Multi-label data is not supported for embeddings fine-tuning for now"
logger.warning(msg)
return

self._model = self._load_model()

x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=config.val_fraction)
tr_ds = Dataset.from_dict({"text": x_train, "label": y_train})
val_ds = Dataset.from_dict({"text": x_val, "label": y_val})

loss = BatchAllTripletLoss(model=self._model, margin=config.margin)
with tempfile.TemporaryDirectory() as tmp_dir:
args = SentenceTransformerTrainingArguments(
save_strategy="epoch",
save_total_limit=1,
output_dir=tmp_dir,
num_train_epochs=config.epoch_num,
per_device_train_batch_size=config.batch_size,
per_device_eval_batch_size=config.batch_size,
learning_rate=config.learning_rate,
warmup_ratio=config.warmup_ratio,
fp16=config.fp16,
bf16=config.bf16,
batch_sampler=BatchSamplers.NO_DUPLICATES,
metric_for_best_model="eval_loss",
load_best_model_at_end=True,
eval_strategy="epoch",
greater_is_better=False,
)
callbacks: list[TrainerCallback] = [
EarlyStoppingCallback(
early_stopping_patience=config.early_stopping_patience,
early_stopping_threshold=config.early_stopping_threshold,
)
]
trainer = SentenceTransformerTrainer(
model=self._model,
args=args,
train_dataset=tr_ds,
eval_dataset=val_ds,
loss=loss,
callbacks=callbacks,
)

trainer.train()

# use temporary path for re-usage
model_path = str(Path(tempfile.mkdtemp("autointent_embedders")) / str(uuid4()))
self._model.save(model_path)
self.config.model_name = model_path

self._trained = True

def clear_ram(self) -> None:
"""Move the embedding model to CPU and delete it from memory."""
if hasattr(self, "embedding_model"):
if hasattr(self, "_model"):
logger.debug("Clearing embedder %s from memory", self.config.model_name)
self.embedding_model.cpu()
del self.embedding_model
self._model.cpu()
del self._model
torch.cuda.empty_cache()

def delete(self) -> None:
@@ -127,6 +203,11 @@ def dump(self, path: Path) -> None:
Args:
path: Path to the directory where the model will be saved.
"""
if self._trained:
model_path = str((path / self._weights_dir_name).resolve())
self._model.save(model_path, create_model_card=False)
self.config.model_name = model_path

self._dump_dir = path
path.mkdir(parents=True, exist_ok=True)
with (path / self._metadata_dict_name).open("w") as file:
@@ -164,6 +245,11 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
Returns:
A numpy array of embeddings.
"""
if len(utterances) == 0:
msg = "Empty input"
logger.error(msg)
raise ValueError(msg)

prompt = self.config.get_prompt(task_type)

if self.config.use_cache:
@@ -179,7 +265,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
logger.debug("loading embeddings from %s", str(embeddings_path))
return np.load(embeddings_path) # type: ignore[no-any-return]

self.embedding_model = self._load_model()
self._model = self._load_model()

logger.debug(
"Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",
@@ -191,9 +277,9 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
)

if self.config.tokenizer_config.max_length is not None:
self.embedding_model.max_seq_length = self.config.tokenizer_config.max_length
self._model.max_seq_length = self.config.tokenizer_config.max_length

embeddings = self.embedding_model.encode(
embeddings = self._model.encode(
utterances,
convert_to_numpy=True,
batch_size=self.config.batch_size,
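A minimal usage sketch of the fine-tuning path added above (toy single-label data; the model name and hyperparameter values are illustrative, and the `autointent._wrappers.embedder` import path is taken from this diff's file layout rather than the public API):

```python
from autointent._wrappers.embedder import Embedder
from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig

# toy single-label data; multi-label inputs are currently skipped with a warning
utterances = ["book a flight", "cancel my flight", "what's the weather", "will it rain tomorrow"]
labels = [0, 0, 1, 1]

embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))

# epoch_num and batch_size are the only required fields; the rest fall back to defaults
embedder.train(utterances, labels, EmbedderFineTuningConfig(epoch_num=1, batch_size=2))

# after train(), config.model_name points at a temporary checkpoint of the fine-tuned model,
# so later embed() calls use the trained weights and dump() persists them
vectors = embedder.embed(utterances)
```

Note that `_get_hash` re-hashes local weights whenever `model_name` is an existing path, so cached embeddings computed before and after fine-tuning land under different cache keys.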
4 changes: 2 additions & 2 deletions src/autointent/_wrappers/vector_index/vector_index.py
@@ -38,15 +38,15 @@ class VectorIndex:
embedder: Embedder
index: BaseIndexBackend

def __init__(self, embedder_config: EmbedderConfig, config: VectorIndexConfig) -> None:
def __init__(self, embedder_config: EmbedderConfig | Embedder, config: VectorIndexConfig) -> None:
"""Initialize the VectorIndex with an embedding model.

Args:
embedder_config: Configuration for the embedding model.
config: settings for vector index.
backend: vector index backend to use.
"""
self.embedder = Embedder(embedder_config)
self.embedder = embedder_config if isinstance(embedder_config, Embedder) else Embedder(embedder_config)
self.config = config

def _init_index(self, vector_size: int) -> BaseIndexBackend:
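Since `VectorIndex` now accepts either a config or a ready `Embedder`, a fine-tuned embedder can be shared with the index instead of being re-created from its config. A rough sketch, where `vector_index_config` stands in for an existing `VectorIndexConfig` instance (its construction is not shown in this diff):

```python
from autointent._wrappers.embedder import Embedder
from autointent._wrappers.vector_index.vector_index import VectorIndex
from autointent.configs import EmbedderConfig

embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))
# ... embedder.train(...) may run here, updating embedder.config.model_name in place ...

# vector_index_config is a placeholder for a pre-built VectorIndexConfig
index_from_instance = VectorIndex(embedder, vector_index_config)        # reuses the trained model
index_from_config = VectorIndex(EmbedderConfig(), vector_index_config)  # previous behaviour still works
```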
2 changes: 2 additions & 0 deletions src/autointent/configs/__init__.py
@@ -7,6 +7,7 @@
CrossEncoderConfig,
EarlyStoppingConfig,
EmbedderConfig,
EmbedderFineTuningConfig,
HFModelConfig,
TaskTypeEnum,
TokenizerConfig,
Expand All @@ -18,6 +19,7 @@
"DataConfig",
"EarlyStoppingConfig",
"EmbedderConfig",
"EmbedderFineTuningConfig",
"FaissConfig",
"HFModelConfig",
"HPOConfig",
34 changes: 29 additions & 5 deletions src/autointent/configs/_transformers.py
@@ -2,7 +2,7 @@
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from typing_extensions import Self
from typing_extensions import Self, assert_never

from autointent.custom_types import FloatFromZeroToOne
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
Expand All @@ -15,6 +15,29 @@ class TokenizerConfig(BaseModel):
max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")


class EmbedderFineTuningConfig(BaseModel):
epoch_num: int
batch_size: int
margin: float = Field(default=0.5)
learning_rate: float = Field(default=2e-5)
warmup_ratio: float = Field(default=0.1)
early_stopping_patience: int = Field(default=1)
early_stopping_threshold: float = Field(default=0.0)
val_fraction: float = Field(default=0.2)
fp16: bool = Field(default=False)
bf16: bool = Field(default=False)

@classmethod
def from_search_config(cls, values: dict[str, Any] | BaseModel | None) -> Self | None:
if isinstance(values, BaseModel):
return cls(**values.model_dump())
if isinstance(values, dict):
return cls(**values)
if values is None:
return None
assert_never(values)


class HFModelConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
model_name: str = Field(
@@ -42,7 +65,7 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) ->
if values is None:
return cls()
if isinstance(values, BaseModel):
return values # type: ignore[return-value]
return cls(**values.model_dump())
if isinstance(values, str):
return cls(model_name=values)
return cls(**values)
@@ -73,7 +96,6 @@ class EmbedderConfig(HFModelConfig):
"cosine", description="Name of the similarity function to use."
)
use_cache: bool = Field(True, description="Whether to use embeddings caching.")
freeze: bool = Field(True, description="Whether to freeze the model parameters.")

def get_prompt_config(self) -> dict[str, str] | None:
"""Get the prompt config for the given prompt type.
@@ -162,5 +184,7 @@ def from_search_config(cls, values: dict[str, Any] | BaseModel | None) -> Self:
if values is None:
return cls()
if isinstance(values, BaseModel):
return values # type: ignore[return-value]
return cls(**values)
return cls(**values.model_dump())
if isinstance(values, dict):
return cls(**values)
assert_never(values)
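A short sketch of how the new `EmbedderFineTuningConfig.from_search_config` helper behaves for each accepted input type (field values are illustrative):

```python
from autointent.configs import EmbedderFineTuningConfig

# a plain dict, e.g. coming from a search-space definition, is validated into a config
cfg = EmbedderFineTuningConfig.from_search_config({"epoch_num": 5, "batch_size": 16})

# an existing pydantic model is copied field-by-field into a fresh instance
cfg_copy = EmbedderFineTuningConfig.from_search_config(cfg)

# None passes through unchanged, presumably meaning "no fine-tuning configured"
assert EmbedderFineTuningConfig.from_search_config(None) is None
```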
@@ -20,7 +20,7 @@
from autointent.configs import EmbedderConfig, InferenceNodeConfig
from autointent.custom_types import NodeType

from ._data_models import Artifact, Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials
from ._data_models import Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials

if TYPE_CHECKING:
from autointent.modules.base import BaseModule
@@ -95,7 +95,6 @@ def log_module_optimization(
metric_value: float,
metric_name: str,
metrics: dict[str, float],
artifact: Artifact,
module_dump_dir: str | None,
module: "BaseModule",
) -> None:
@@ -108,7 +107,6 @@
metric_value: Metric value achieved by the module.
metric_name: Name of the evaluation metric.
metrics: Dictionary of metric names and their values.
artifact: Artifact generated by the module.
module_dump_dir: Directory where the module is dumped.
module: The module instance, if available.
"""
@@ -117,7 +115,7 @@
self.modules.add_module(node_type, module)
if module_dump_dir is not None:
module.dump(module_dump_dir)
self.artifacts.add_artifact(node_type, artifact)
self.artifacts.add_artifact(node_type, module.get_assets())

if old_best_metric_value_idx is not None:
prev_best_dump = self.trials.get_trials(node_type)[old_best_metric_value_idx].module_dump_dir