Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4caced7
add train
k0lenk4 Jul 3, 2025
c5b5b2c
fixed env
k0lenk4 Jul 3, 2025
e13f171
deleted kwargs and local savings, added config
k0lenk4 Jul 28, 2025
d26eda0
added test for train method
k0lenk4 Jul 28, 2025
ee1b4a1
add EmbedderFineTuningConfig to __init__
k0lenk4 Aug 2, 2025
3052628
correct __init__ in config, remove pytest in test file
k0lenk4 Aug 2, 2025
a96c27d
correct some syntax issues
k0lenk4 Aug 2, 2025
bdc1161
move batch_size to EmbedderFineTuningConfig
k0lenk4 Aug 11, 2025
d71de34
add __init__.py to /test/embedder
k0lenk4 Aug 11, 2025
941c13a
Remove whitespace from blank line
k0lenk4 Aug 11, 2025
3ecdc60
Merge remote-tracking branch 'origin/dev' into feat/train-embeddings
k0lenk4 Aug 11, 2025
0739413
correct errors
k0lenk4 Aug 11, 2025
1e161c6
the number of epochs and train objects have been increased
k0lenk4 Aug 11, 2025
e67f1bc
made lint
k0lenk4 Aug 11, 2025
3c38ec8
add early stopping
k0lenk4 Aug 12, 2025
c743c0b
remake train args
k0lenk4 Aug 12, 2025
71bf957
make a list of callbacks
k0lenk4 Aug 12, 2025
2963a4c
inline type annotation of variable "callback"
k0lenk4 Aug 12, 2025
03e4c59
change save_strategy to "epoch"
k0lenk4 Aug 12, 2025
714f910
default value of fp16 changed to False
k0lenk4 Aug 15, 2025
cb9b2ea
pull dev
voorhs Aug 18, 2025
6aa7abc
integrate embeddings fine-tuning into Embedding modules
voorhs Aug 19, 2025
b88a810
pull dev
voorhs Aug 19, 2025
2970737
Update optimizer_config.schema.json
github-actions[bot] Aug 19, 2025
fcf1f31
clean up `freeze` throughout tests and tutorials
voorhs Aug 19, 2025
0b0c1fa
add comprehensive tests
voorhs Aug 19, 2025
19a74f4
embedder_model -> _model
voorhs Aug 19, 2025
1d73af6
fix early stopping
voorhs Aug 25, 2025
714c8c2
fix tests
voorhs Aug 25, 2025
ebf066b
clear ram bug fix
voorhs Aug 25, 2025
51a9b1a
try to fix windows cleanup issue
voorhs Aug 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions autointent/_wrappers/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,22 @@
from functools import lru_cache
from pathlib import Path
from typing import TypedDict
import tempfile

import huggingface_hub
import numpy as np
import numpy.typing as npt
import torch
from appdirs import user_cache_dir
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.losses import BatchAllTripletLoss
from sentence_transformers.training_args import BatchSamplers
from datasets import Dataset


from autointent._hash import Hasher
from autointent.configs import EmbedderConfig, TaskTypeEnum
from autointent.configs import EmbedderConfig, TaskTypeEnum, EmbedderFineTuningConfig

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -122,7 +127,41 @@ def _load_model(self) -> None:
similarity_fn_name=self.config.similarity_fn_name,
trust_remote_code=self.config.trust_remote_code,
)
def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
    """Fine-tune the underlying sentence-transformer on labeled utterances.

    Training uses :class:`BatchAllTripletLoss`, so a batch must contain at
    least two utterances sharing a label for any triplet to be mined.

    :param utterances: Training texts.
    :param labels: Integer class label for each utterance, parallel to ``utterances``.
    :param config: Fine-tuning hyperparameters (epochs, margin, LR, precision).
    :raises ValueError: If ``utterances`` and ``labels`` differ in length.
    """
    self._load_model()

    # Fail early with a clear message instead of an opaque Arrow error
    # from Dataset.from_dict on mismatched column lengths.
    if len(utterances) != len(labels):
        msg = f"utterances ({len(utterances)}) and labels ({len(labels)}) must have the same length"
        raise ValueError(msg)

    train_dataset = Dataset.from_dict({"text": utterances, "label": labels})

    loss = BatchAllTripletLoss(model=self.embedding_model, margin=config.margin)

    # Checkpointing is disabled (save_strategy="no"); the temporary directory
    # only absorbs incidental trainer output and is removed on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        args = SentenceTransformerTrainingArguments(
            output_dir=tmp_dir,
            save_strategy="no",
            num_train_epochs=config.epoch_num,
            per_device_train_batch_size=self.config.batch_size,
            learning_rate=config.learning_rate,
            warmup_ratio=config.warmup_ratio,
            fp16=config.fp16,
            bf16=config.bf16,
            # NO_DUPLICATES keeps repeated samples out of a batch, which the
            # triplet loss needs for meaningful in-batch mining.
            batch_sampler=BatchSamplers.NO_DUPLICATES,
        )

        trainer = SentenceTransformerTrainer(
            model=self.embedding_model,
            args=args,
            train_dataset=train_dataset,
            loss=loss,
        )
        trainer.train()

def clear_ram(self) -> None:
"""Move the embedding model to CPU and delete it from memory."""
if hasattr(self, "embedding_model"):
Expand Down
7 changes: 7 additions & 0 deletions autointent/configs/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ class TokenizerConfig(BaseModel):
truncation: bool = True
max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")

class EmbedderFineTuningConfig(BaseModel):
    """Hyperparameters for fine-tuning an embedding model with a triplet loss."""

    # Number of full passes over the training data.
    epoch_num: int
    # Margin used by BatchAllTripletLoss.
    margin: float = Field(default=0.5, description="Triplet loss margin.")
    learning_rate: float = Field(default=2e-5, description="Peak learning rate for the trainer.")
    warmup_ratio: float = Field(default=0.1, description="Fraction of training steps spent on LR warmup.")
    # Defaults to False: fp16 requires CUDA, and a True default makes the
    # HF trainer raise at startup on CPU-only machines.
    fp16: bool = Field(default=False, description="Use fp16 mixed-precision training (requires CUDA).")
    bf16: bool = Field(default=False, description="Use bf16 mixed-precision training (requires supporting hardware).")

class HFModelConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers=[
'Framework :: Sphinx',
'Typing :: Typed',
]
requires-python = ">=3.10,<4.0"
requires-python = ">=3.10,<3.13"
dependencies = [
"sentence-transformers (>=3,<4)",
"scikit-learn (>=1.5,<2.0)",
Expand Down
55 changes: 55 additions & 0 deletions tests/embedder/test_fine_tuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from autointent.context.data_handler import DataHandler
from autointent._wrappers.embedder import Embedder
from autointent.configs._transformers import HFModelConfig, EmbedderConfig, EmbedderFineTuningConfig
import numpy as np
import pytest

def test_model_updates_after_training(dataset):
    """Trainable parameters must change after one fine-tuning epoch."""
    data_handler = DataHandler(dataset)

    hf_config = HFModelConfig(
        model_name="intfloat/multilingual-e5-small",
        batch_size=8,
        trust_remote_code=True,
    )

    embedder_config = EmbedderConfig(
        **hf_config.model_dump(),
        default_prompt="Represent this text for retrieval:",
        query_prompt="Search query:",
        passage_prompt="Document:",
        similarity_fn_name="cosine",
        use_cache=False,
        freeze=False,
    )

    train_config = EmbedderFineTuningConfig(epoch_num=1)

    embedder = Embedder(embedder_config)
    embedder._load_model()

    # Snapshot trainable weights before training; .copy() detaches the
    # snapshot from the live parameter storage.
    original_weights = [
        param.data.detach().cpu().numpy().copy()
        for param in embedder.embedding_model.parameters()
        if param.requires_grad
    ]

    embedder.train(
        utterances=data_handler.train_utterances(0)[:10],
        labels=data_handler.train_labels(0)[:10],
        config=train_config,
    )

    trained_weights = [
        param.data.detach().cpu().numpy().copy()
        for param in embedder.embedding_model.parameters()
        if param.requires_grad
    ]

    # strict=True guards against a silent mismatch if the set of trainable
    # parameters changes between the two snapshots.
    weights_changed = any(
        not np.allclose(orig, trained, atol=1e-6)
        for orig, trained in zip(original_weights, trained_weights, strict=True)
    )

    assert weights_changed, "Model weights should change after training"
Loading