diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index 9a92717..5f1b965 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -34,6 +34,7 @@ import shutil import time import uuid +from typing import Literal import numpy as np import polars as pl @@ -69,12 +70,12 @@ class VectorStore: """A class to model and create `VectorStore` objects for building and searching vector databases from CSV text files. Attributes: - file_name (str): the data file contatining the knowledgebase to build the `VectorStore` - data_type (str): the data type of the data file (curently only csv supported) - vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Pacakge module + file_name (str | os.PathLike[str]): the data file containing the knowledgebase to build the `VectorStore` + data_type (Literal["csv"]): the data type of the data file (currently only csv supported) + vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Package module batch_size (int): the batch size to pass to the vectoriser when embedding meta_data (dict): key-value pairs of metadata to extract from the input file and their correpsonding types - output_dir (str): the path to the output directory where the `VectorStore` will be saved + output_dir (str | os.PathLike[str]): the path to the output directory where the `VectorStore` will be saved vectors (np.array): a numpy array of vectors for the vector database vector_shape (int): the dimension of the vectors num_vectors (int): the number of records saved in the `VectorStore` @@ -84,12 +85,12 @@ class VectorStore: def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 self, - file_name: str, - data_type: str, + file_name: str | os.PathLike[str], + data_type: Literal["csv"], vectoriser: VectoriserBase, batch_size: int = 8, meta_data: dict | None = None, - output_dir: str | None = None, + output_dir: str | os.PathLike[str] | None = None, overwrite: bool = 
False, hooks: dict | None = None, ): @@ -97,9 +98,9 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 vector embeddings. Args: - file_name (str): The name of the input CSV file. + file_name (str | os.PathLike): The name of the input CSV file. data_type (str): The type of input data (currently supports only "csv"). - vectoriser (object): The `Vectoriser` object used to transform text into + vectoriser (VectoriserBase): The `Vectoriser` object used to transform text into vector embeddings. batch_size (int): [optional] The batch size for processing the input file and batching to vectoriser. Defaults to 8. @@ -119,8 +120,10 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 `IndexBuildError`: If there are failures during index building or saving outputs. """ # ---- Input validation (caller mistakes) -> DataValidationError / ConfigurationError - if not isinstance(file_name, str) or not file_name.strip(): - raise DataValidationError("file_name must be a non-empty string.", context={"file_name": file_name}) + if not isinstance(file_name, (str, os.PathLike)) or not os.fspath(file_name).strip(): + raise DataValidationError( + "file_name must be a non-empty string or os.PathLike.", context={"file_name": file_name} + ) if not os.path.exists(file_name): raise DataValidationError("Input file does not exist.", context={"file_name": file_name}) @@ -149,17 +152,15 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__}) # ---- Assign fields + ## all these fields are initialised from inputs self.file_name = file_name self.data_type = data_type self.vectoriser = vectoriser self.batch_size = batch_size self.meta_data = meta_data if meta_data is not None else {} self.output_dir = output_dir - self.vectors = None - self.vector_shape = None - self.num_vectors = None - self.vectoriser_class = vectoriser.__class__.__name__ self.hooks = {} if hooks is None else hooks + 
self.vectoriser_class = vectoriser.__class__.__name__ # ---- Output directory handling (filesystem problems) -> ConfigurationError try: @@ -185,7 +186,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 # ---- Build index (wrap every unexpected failure) -> IndexBuildError try: - self._create_vector_store_index() + self._create_vector_store_index(os.fspath(self.file_name)) except ClassifaiError: # preserve already-classified errors (e.g. vectoriser raised DataValidationError) raise @@ -260,7 +261,7 @@ def _save_metadata(self, path: str): context={"path": path, "metadata": metadata, "cause_type": type(e).__name__, "cause_message": str(e)}, ) from e - def _create_vector_store_index(self): # noqa: C901 + def _create_vector_store_index(self, file_name: str): # noqa: C901 """Processes text strings in batches, generates vector embeddings, and creates the `VectorStore`. Called from the constructor once other metadata has been set. @@ -268,6 +269,9 @@ def _create_vector_store_index(self): # noqa: C901 Creates a Polars DataFrame with the captured data and embeddings, and saves it as a Parquet file in the output_dir attribute, and stores in the vectors attribute. + Args: + file_name (str): The filename of the CSV file to read in. + Raises: `DataValidationError`: If there are issues reading or validating the input file. `IndexBuildError`: If there are failures during embedding or building the vectors table. 
@@ -276,9 +280,9 @@ def _create_vector_store_index(self): # noqa: C901 try: if self.data_type == "csv": self.vectors = pl.read_csv( - self.file_name, + file_name, columns=["label", "text", *self.meta_data.keys()], - dtypes=self.meta_data | {"label": str, "text": str}, + schema_overrides=self.meta_data | {"label": str, "text": str}, ) self.vectors = self.vectors.with_columns( pl.Series("uuid", [str(uuid.uuid4()) for _ in range(self.vectors.height)]) @@ -740,7 +744,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V return result_df @classmethod - def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # noqa: C901, PLR0912, PLR0915 + def from_filespace(cls, folder_path: str | os.PathLike[str], vectoriser: VectoriserBase, hooks: dict | None = None): # noqa: C901, PLR0912, PLR0915 """Creates a `VectorStore` instance from stored metadata and Parquet files. This method reads the metadata and vectors from the specified folder, validates the contents, and initializes a `VectorStore` object with the @@ -752,8 +756,8 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # needing to reprocess the original text data. Args: - folder_path (str): The folder path containing the metadata and Parquet files. - vectoriser (object): The `Vectoriser` object used to transform text into vector embeddings. + folder_path (str | os.PathLike): The folder path containing the metadata and Parquet files. + vectoriser (VectoriserBase): The `Vectoriser` object used to transform text into vector embeddings. hooks (dict): [optional] A dictionary of user-defined hooks for preprocessing and postprocessing. Defaults to None. Returns: @@ -765,7 +769,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # `IndexBuildError`: If there are failures during loading or parsing the files. 
""" # ---- Validate arguments (caller mistakes) -> DataValidationError / ConfigurationError - if not isinstance(folder_path, str) or not folder_path.strip(): + if not isinstance(folder_path, (str, os.PathLike)) or not os.fspath(folder_path).strip(): raise DataValidationError("folder_path must be a non-empty string.", context={"folder_path": folder_path}) if not os.path.isdir(folder_path): diff --git a/src/classifai/servers/pydantic_models.py b/src/classifai/servers/pydantic_models.py index e86e412..3b889cb 100644 --- a/src/classifai/servers/pydantic_models.py +++ b/src/classifai/servers/pydantic_models.py @@ -2,7 +2,7 @@ """Pydantic Classes to model request and response data for ClassifAI FastAPI RESTful API.""" import pandas as pd -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, ConfigDict, Field class SearchRequestEntry(BaseModel): @@ -33,8 +33,7 @@ class SearchResponseEntry(BaseModel): rank: int = Field(description="The rank of the result entry for the given query, with 1 being the most relevant.") score: float = Field(description="The similarity score of the result entry for the given query.") - class Config: - extra = Extra.allow # Allow extra keys (e.g., metadata columns)å + model_config = ConfigDict(extra="allow") class SearchResponseSet(BaseModel): @@ -81,8 +80,7 @@ class ReverseSearchResponseEntry(BaseModel): doc_label: str doc_text: str - class Config: - extra = Extra.allow # Allow extra keys (e.g., metadata columns) + model_config = ConfigDict(extra="allow") class ReverseSearchResponseSet(BaseModel): @@ -135,8 +133,7 @@ class EmbedResponseEntry(BaseModel): description="The vector embedding result for the input text string, represented as a list of floats." ) - class Config: - extra = Extra.allow # Allow extra keys (e.g., metadata columns) + model_config = ConfigDict(extra="allow") class EmbedResponseBody(BaseModel):