11"""Wrapper around the Clickhouse vector database over VectorDB"""
22
3- import io
43import logging
54from contextlib import contextmanager
65from typing import Any
6+
77import clickhouse_connect
8- import numpy as np
98
10- from ..api import VectorDB , DBCaseConfig
9+ from ..api import DBCaseConfig , VectorDB
1110
1211log = logging .getLogger (__name__ )
1312
13+
1414class Clickhouse (VectorDB ):
1515 """Use SQLAlchemy instructions"""
16+
1617 def __init__ (
1718 self ,
1819 dim : int ,
@@ -32,12 +33,13 @@ def __init__(
3233 self ._vector_field = "embedding"
3334
3435 # construct basic units
35- self .conn = clickhouse_connect .get_client (
36- host = self .db_config ["host" ],
37- port = self .db_config ["port" ],
38- username = self .db_config ["user" ],
39- password = self .db_config ["password" ],
40- database = self .db_config ["dbname" ])
36+ self .conn = clickhouse_connect .get_client (
37+ host = self .db_config ["host" ],
38+ port = self .db_config ["port" ],
39+ username = self .db_config ["user" ],
40+ password = self .db_config ["password" ],
41+ database = self .db_config ["dbname" ],
42+ )
4143
4244 if drop_old :
4345 log .info (f"Clickhouse client drop table : { self .table_name } " )
@@ -48,20 +50,21 @@ def __init__(
4850 self .conn = None
4951
5052 @contextmanager
51- def init (self ) -> None :
53+ def init (self ):
5254 """
5355 Examples:
5456 >>> with self.init():
5557 >>> self.insert_embeddings()
5658 >>> self.search_embedding()
5759 """
5860
59- self .conn = clickhouse_connect .get_client (
60- host = self .db_config ["host" ],
61- port = self .db_config ["port" ],
62- username = self .db_config ["user" ],
63- password = self .db_config ["password" ],
64- database = self .db_config ["dbname" ])
61+ self .conn = clickhouse_connect .get_client (
62+ host = self .db_config ["host" ],
63+ port = self .db_config ["port" ],
64+ username = self .db_config ["user" ],
65+ password = self .db_config ["password" ],
66+ database = self .db_config ["dbname" ],
67+ )
6568
6669 try :
6770 yield
@@ -85,9 +88,7 @@ def _create_table(self, dim: int):
8588 )
8689
8790 except Exception as e :
88- log .warning (
89- f"Failed to create Clickhouse table: { self .table_name } error: { e } "
90- )
91+ log .warning (f"Failed to create Clickhouse table: { self .table_name } error: { e } " )
9192 raise e from None
9293
9394 def ready_to_load (self ):
@@ -104,16 +105,20 @@ def insert_embeddings(
104105 embeddings : list [list [float ]],
105106 metadata : list [int ],
106107 ** kwargs : Any ,
107- ) -> ( int , Exception ) :
108+ ) -> tuple [ int , Exception ] :
108109 assert self .conn is not None , "Connection is not initialized"
109110
110111 try :
111112 # do not iterate for bulk insert
112113 items = [metadata , embeddings ]
113114
114- self .conn .insert (table = self .table_name , data = items ,
115- column_names = ['id' , 'embedding' ], column_type_names = ['UInt32' , 'Array(Float64)' ],
116- column_oriented = True )
115+ self .conn .insert (
116+ table = self .table_name ,
117+ data = items ,
118+ column_names = ["id" , "embedding" ],
119+ column_type_names = ["UInt32" , "Array(Float64)" ],
120+ column_oriented = True ,
121+ )
117122 return len (metadata ), None
118123 except Exception as e :
119124 log .warning (f"Failed to insert data into Clickhouse table ({ self .table_name } ), error: { e } " )
@@ -128,22 +133,24 @@ def search_embedding(
128133 ) -> list [int ]:
129134 assert self .conn is not None , "Connection is not initialized"
130135
131- index_param = self .case_config .index_param ()
136+ index_param = self .case_config .index_param () # noqa: F841
132137 search_param = self .case_config .search_param ()
133138
134139 if filters :
135140 gt = filters .get ("id" )
136- filterSql = (f'SELECT id, { search_param ["metric_type" ]} (embedding,{ query } ) AS score '
137- f'FROM { self .db_config ["dbname" ]} .{ self .table_name } '
138- f'WHERE id > { gt } '
139- f'ORDER BY score LIMIT { k } ;'
140- )
141- result = self .conn .query (filterSql ).result_rows
141+ filter_sql = (
142+ f'SELECT id, { search_param ["metric_type" ]} (embedding,{ query } ) AS score ' # noqa: S608
143+ f'FROM { self .db_config ["dbname" ]} .{ self .table_name } '
144+ f"WHERE id > { gt } "
145+ f"ORDER BY score LIMIT { k } ;"
146+ )
147+ result = self .conn .query (filter_sql ).result_rows
142148 return [int (row [0 ]) for row in result ]
143- else :
144- selectSql = (f'SELECT id, { search_param ["metric_type" ]} (embedding,{ query } ) AS score '
145- f'FROM { self .db_config ["dbname" ]} .{ self .table_name } '
146- f'ORDER BY score LIMIT { k } ;'
147- )
148- result = self .conn .query (selectSql ).result_rows
149+ else : # noqa: RET505
150+ select_sql = (
151+ f'SELECT id, { search_param ["metric_type" ]} (embedding,{ query } ) AS score ' # noqa: S608
152+ f'FROM { self .db_config ["dbname" ]} .{ self .table_name } '
153+ f"ORDER BY score LIMIT { k } ;"
154+ )
155+ result = self .conn .query (select_sql ).result_rows
149156 return [int (row [0 ]) for row in result ]
0 commit comments