diff --git a/Makefile b/Makefile
index 1f466d59..37defcda 100644
--- a/Makefile
+++ b/Makefile
@@ -8,7 +8,7 @@ sh = uv run --no-sync --frozen
 .PHONY: install
 install:
 	rm -rf uv.lock
-	uv sync --all-groups --extra catboost --extra peft --extra sentence-transformers --extra transformers
+	uv sync --all-groups --extra catboost --extra peft --extra sentence-transformers --extra transformers --extra openai
 
 .PHONY: test
 test:
diff --git a/src/autointent/_utils.py b/src/autointent/_utils.py
index f4092756..2cf709ce 100644
--- a/src/autointent/_utils.py
+++ b/src/autointent/_utils.py
@@ -1,10 +1,15 @@
 """Utils."""
 
+from __future__ import annotations
+
 import importlib
-from typing import Any, TypeVar
+from typing import TYPE_CHECKING, TypeVar
 
 import torch
 
+if TYPE_CHECKING:
+    from types import ModuleType
+
 T = TypeVar("T")
 
 
@@ -28,7 +33,7 @@ def detect_device() -> str:
     return "cpu"
 
 
-def require(dependency: str, extra: str | None = None) -> Any:  # noqa: ANN401
+def require(dependency: str, extra: str | None = None) -> ModuleType:
     """Try to import dependency, raise informative ImportError if missing.
 
     Args:
diff --git a/src/autointent/_wrappers/embedder/openai.py b/src/autointent/_wrappers/embedder/openai.py
index 6640a36d..845c232f 100644
--- a/src/autointent/_wrappers/embedder/openai.py
+++ b/src/autointent/_wrappers/embedder/openai.py
@@ -59,7 +59,7 @@ def __init__(self, config: OpenaiEmbeddingConfig) -> None:
 
     def _get_client(self) -> openai.OpenAI:
         """Get or create OpenAI client instance."""
-        import openai
+        openai = require("openai", "openai")
 
         if self._client is None:
             self._client = openai.OpenAI(
@@ -71,7 +71,7 @@ def __init__(self, config: OpenaiEmbeddingConfig) -> None:
 
     def _get_async_client(self) -> openai.AsyncOpenAI:
         """Get or create async OpenAI client instance."""
-        import openai
+        openai = require("openai", "openai")
 
         if self._async_client is None:
             self._async_client = openai.AsyncOpenAI(
diff --git a/src/autointent/generation/_generator.py b/src/autointent/generation/_generator.py
index 0a63a549..bf0d6879 100644
--- a/src/autointent/generation/_generator.py
+++ b/src/autointent/generation/_generator.py
@@ -220,7 +220,7 @@ async def _get_structured_output_openai_async(
         Returns:
             Tuple of (parsed_result, error_message, raw_response).
         """
-        from openai import LengthFinishReasonError
+        openai = require("openai", "openai")
 
         res: T | None = None
         msg: str | None = None
@@ -235,7 +235,7 @@ async def _get_structured_output_openai_async(
             )
             raw = response.choices[0].message.content
             res = response.choices[0].message.parsed
-        except (ValidationError, ValueError, LengthFinishReasonError) as e:
+        except (ValidationError, ValueError, openai.LengthFinishReasonError) as e:
             msg = f"Failed to obtain structured output for model {self.model_name} and messages {messages}: {e!s}"
             logger.warning(msg)
         else:
@@ -307,7 +307,7 @@ def _get_structured_output_openai_sync(
         Returns:
             Tuple of (parsed_result, error_message, raw_response).
""" - from openai import LengthFinishReasonError + openai = require("openai", "openai") res: T | None = None msg: str | None = None @@ -322,7 +322,7 @@ def _get_structured_output_openai_sync( ) raw = response.choices[0].message.content res = response.choices[0].message.parsed - except (ValidationError, ValueError, LengthFinishReasonError) as e: + except (ValidationError, ValueError, openai.LengthFinishReasonError) as e: msg = f"Failed to obtain structured output for model {self.model_name} and messages {messages}: {e!s}" logger.warning(msg) else: diff --git a/src/autointent/modules/scoring/_bert.py b/src/autointent/modules/scoring/_bert.py index 399cfe42..fa8e4bff 100644 --- a/src/autointent/modules/scoring/_bert.py +++ b/src/autointent/modules/scoring/_bert.py @@ -128,12 +128,12 @@ def get_implicit_initialization_params(self) -> dict[str, Any]: } def _initialize_model(self) -> Any: # noqa: ANN401 - from transformers import AutoModelForSequenceClassification + transformers = require("transformers", "transformers") label2id = {i: i for i in range(self._n_classes)} id2label = {i: i for i in range(self._n_classes)} - return AutoModelForSequenceClassification.from_pretrained( + return transformers.AutoModelForSequenceClassification.from_pretrained( self.classification_model_config.model_name, trust_remote_code=self.classification_model_config.trust_remote_code, num_labels=self._n_classes, @@ -147,11 +147,11 @@ def fit( utterances: list[str], labels: ListOfLabels, ) -> None: - from transformers import AutoTokenizer + transformers = require("transformers", "transformers") self._validate_task(labels) - self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name) # type: ignore[no-untyped-call] + self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.classification_model_config.model_name) self._model = self._initialize_model() tokenized_dataset = self._get_tokenized_dataset(utterances, labels) self._train(tokenized_dataset) @@ -164,10 +164,10 @@ def _train(self, tokenized_dataset: DatasetDict) -> None: Args: tokenized_dataset: output from :py:meth:`BertScorer._get_tokenized_dataset` """ - from transformers import DataCollatorWithPadding, PrinterCallback, ProgressCallback, Trainer, TrainingArguments + transformers = require("transformers", "transformers") with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = transformers.TrainingArguments( output_dir=tmp_dir, num_train_epochs=self.num_train_epochs, per_device_train_batch_size=self.batch_size, @@ -186,29 +186,29 @@ def _train(self, tokenized_dataset: DatasetDict) -> None: load_best_model_at_end=self.early_stopping_config.metric is not None, ) - trainer = Trainer( + trainer = transformers.Trainer( model=self._model, args=training_args, train_dataset=tokenized_dataset["train"], eval_dataset=tokenized_dataset["validation"], processing_class=self._tokenizer, - data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer), + data_collator=transformers.DataCollatorWithPadding(tokenizer=self._tokenizer), compute_metrics=self._get_compute_metrics(), callbacks=self._get_trainer_callbacks(), ) if not self.print_progress: - trainer.remove_callback(PrinterCallback) - trainer.remove_callback(ProgressCallback) + trainer.remove_callback(transformers.PrinterCallback) + trainer.remove_callback(transformers.ProgressCallback) trainer.train() def _get_trainer_callbacks(self) -> list[TrainerCallback]: - from transformers import EarlyStoppingCallback + transformers = 
require("transformers", "transformers") res: list[TrainerCallback] = [] if self.early_stopping_config.metric is not None: res.append( - EarlyStoppingCallback( + transformers.EarlyStoppingCallback( early_stopping_patience=self.early_stopping_config.patience, early_stopping_threshold=self.early_stopping_config.threshold, ) diff --git a/tests/modules/test_dumper.py b/tests/modules/test_dumper.py index 96552c5a..48b37336 100644 --- a/tests/modules/test_dumper.py +++ b/tests/modules/test_dumper.py @@ -62,8 +62,6 @@ def check_attributes(self): class TestVectorIndex: def init_attributes(self): - pytest.importorskip("sentence_transformers", reason="Sentence Transformers library is required for these tests") - self.vector_index = VectorIndex( embedder_config=initialize_embedder_config("bert-base-uncased"), config=FaissConfig(), @@ -178,6 +176,14 @@ def _transformers_is_installed() -> bool: id="transformer", ), TestVectorIndex, + pytest.param( + TestVectorIndex, + marks=pytest.mark.skipif( + not _st_is_installed(), + reason="need sentence-transformers dependency", + ), + id="vector_index", + ), pytest.param( TestEmbedder, marks=pytest.mark.skipif(