diff --git a/gptcache/manager/vector_data/pgvector.py b/gptcache/manager/vector_data/pgvector.py index 55ab0c99..61e7f590 100644 --- a/gptcache/manager/vector_data/pgvector.py +++ b/gptcache/manager/vector_data/pgvector.py @@ -25,10 +25,13 @@ class _VectorType(UserDefinedType): """ cache_ok = True - def __init__(self, precision=8): + def __init__(self, precision=8, use_halfvec=False): self.precision = precision + self.use_halfvec = use_halfvec def get_col_spec(self, **_): + if self.use_halfvec + return f"halfvec({self.precision})" return f"vector({self.precision})" # pylint: disable=unused-argument @@ -40,7 +43,7 @@ def result_processor(self, dialect, coltype): return lambda value: value -def _get_model_and_index(table_prefix, vector_dimension, index_type, lists): +def _get_model_and_index(table_prefix, vector_dimension, index_type, lists, use_halfvec=False): class VectorStoreTable(Base): """ vector store table @@ -49,7 +52,7 @@ class VectorStoreTable(Base): __tablename__ = table_prefix + "_pg_vector_store" __table_args__ = {"extend_existing": True} id = Column(Integer, primary_key=True, autoincrement=False) - embedding = Column(_VectorType(vector_dimension), nullable=False) + embedding = Column(_VectorType(vector_dimension, use_halfvec), nullable=False) vector_store_index = Index( f"idx_{table_prefix}_pg_vector_store_embedding", @@ -76,12 +79,16 @@ class PGVector(VectorBase): :param index_params: the index parameters for pgvector, defaults to 'vector_l2_ops' index: {"index_type": "L2", "params": {"lists": 100, "probes": 10}. :type index_params: dict + :param use_halfvec: whether to use half-precision vector, defaults to False + :type use_halfvec: bool """ INDEX_PARAM = { "L2": {"operator": "<->", "name": "vector_l2_ops"}, # The only one supported now "cosine": {"operator": "<=>", "name": "vector_cosine_ops"}, "inner_product": {"operator": "<->", "name": "vector_ip_ops"}, + "halfvec_l2": {"operator": "<->", "name": "halfvec_l2_ops"}, + "halfvec_cosine": {"operator": "<=>", "name": "halfvec_cosine_ops"}, } def __init__( @@ -91,6 +98,7 @@ def __init__( collection_name: str = "gptcache", dimension: int = 0, top_k: int = 1, + use_halfvec: bool = False; ): if dimension <= 0: raise ValueError( @@ -100,11 +108,18 @@ def __init__( self.top_k = top_k self.index_params = index_params self._url = url + self.use_halfvec = use_halfvec + + #correcting the index type passed by user + if use_halfvec and "halfvec" not in index_params["index_type"]: + index_params["index_type"] = f"halfvec_{index_params['index_type'].lower()}" + self._store, self._index = _get_model_and_index( collection_name, dimension, index_type=self.INDEX_PARAM[index_params["index_type"]]["name"], - lists=index_params["params"]["lists"] + lists=index_params["params"]["lists"], + use_halfvec=use_halfvec ) self._connect(url) self._create_collection() @@ -116,7 +131,7 @@ def _connect(self, url): def _create_collection(self): with self._engine.connect() as con: con.execution_options(isolation_level="AUTOCOMMIT").execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) - + self._store.__table__.create(bind=self._engine, checkfirst=True) self._index.create(bind=self._engine, checkfirst=True) @@ -124,11 +139,15 @@ def _query(self, session): return session.query(self._store) def _format_data_for_search(self, data): + return f"[{','.join(map(str, data))}]" def mul_add(self, datas: List[VectorData]): data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas))) - np_data = np.array(data_array).astype("float32") + if self.use_halfvec: + np_data = np.array(data_array).astype("float16") + else: + np_data = np.array(data_array).astype("float32") entities = [{"id": id, "embedding": embedding.tolist()} for id, embedding in zip(id_array, np_data)] with self._session() as session: @@ -139,6 +158,9 @@ def search(self, data: np.ndarray, top_k: int = -1): if top_k == -1: top_k = self.top_k + if self.use_halfvec: + data = data.astype(np.float16) + formatted_data = self._format_data_for_search(data.reshape(1, -1)[0].tolist()) index_config = self.INDEX_PARAM[self.index_params["index_type"]] similarity = self._store.embedding.op(index_config["operator"])(formatted_data)