55from collections import namedtuple
66from functools import partial , partialmethod
77from urllib import parse
8-
8+ import warnings
99import nest_asyncio
1010import pandas as pd
1111import requests
2323
# allow run at jupyter and asyncio env (an event loop may already be running there)
nest_asyncio.apply()

# connection fields every db_settings mapping must provide; also the field order of `node`
node_parameters = ('host', 'port', 'user', 'password', 'database')
node = namedtuple('clickhouse', node_parameters)

# lower-case statement prefixes, usable directly with str.startswith
available_queries_select = ('select', 'show', 'desc')
available_queries_insert = ('insert', 'optimize', 'create')

PRINT_CHECK_RESULT = True  # print the server reply after a successful heartbeat check
GLOBAL_RAISE_ERROR = True  # when False, failed queries only emit a warning instead of raising
SEMAPHORE = 10  # control async number for whole query list
3233
3334
class ParameterKeyError(Exception):
    """Raised when a settings mapping lacks one or more required keys."""


class ParameterTypeError(Exception):
    """Raised when an argument is not of the type the caller must supply."""


class DatabaseTypeError(Exception):
    """Raised when the database name in the settings is not an accepted one."""


class DatabaseError(Exception):
    """Raised when the server answers a query with a status other than 200."""


class HeartbeatCheckFailure(Exception):
    """Raised when the heartbeat request does not come back with status 200."""
48+
49+
50+ # class SmartResult(object):
51+ #
52+ # def __init__(self, result, status_code: int):
53+ # self._obj = result
54+ # self.status_code = status_code
55+
56+
class SmartBytes(bytes):
    """
    a bytes payload that also carries the status code of the response it came from

    Built as ``SmartBytes(result, status_code)``. The object compares and behaves like the
    plain ``bytes`` it wraps and exposes ``status_code`` as a read-only property. All
    instances share this one class, so ``isinstance(x, SmartBytes)`` works.

    :param result: raw response body
    :param status_code: status code reported together with that body
    """

    def __new__(cls, result: bytes, status_code: int):
        instance = super().__new__(cls, result)
        # stored under a private name so the public attribute stays read-only
        instance._status_code = status_code
        return instance

    @property
    def status_code(self) -> int:
        """status code attached at construction time"""
        return self._status_code

    def __reduce__(self):
        # the reduction inherited from bytes hands only the payload back to __new__,
        # which would fail here because status_code is mandatory
        return type(self), (bytes(self), self._status_code)
4461
4562
4663# TODO: switch to a queue-based mode and remove the aiohttp dependency
4764class ClickHouseTools (object ):
4865 @staticmethod
49- def _transfer_sql_format (sql , convert_to , transfer_sql_format = True ):
66+ def _transfer_sql_format (sql : str , convert_to : str , transfer_sql_format : bool = True ):
5067 """
5168 provide a method which will translate a standard sql into clickhouse sql with might use format as suffix
5269 :param sql:
@@ -65,7 +82,7 @@ def _transfer_sql_format(sql, convert_to, transfer_sql_format=True):
6582 return sql
6683
6784 @staticmethod
68- def _load_into_pd (ret_value , convert_to : str = 'dataframe' , errors = 'ignore' ):
85+ def _load_into_pd (ret_value : ( str , bytes ) , convert_to : str = 'dataframe' , errors = 'ignore' ):
6986 """
7087 will provide a approach to load data from clickhouse into pd.DataFrame format which may be easy to use
7188
@@ -76,7 +93,6 @@ def _load_into_pd(ret_value, convert_to: str = 'dataframe', errors='ignore'):
7693 """
7794
7895 if convert_to .lower () == 'dataframe' :
79-
8096 result_dict = json .loads (ret_value , strict = False )
8197 meta = result_dict ['meta' ]
8298 name = map (lambda x : x ['name' ], meta )
@@ -108,15 +124,12 @@ def _merge_settings(cls, settings: (None, dict), updated_settings: (None, dict)
108124 raise ParameterTypeError (f'updated_settings must be dict type, but get { type (updated_settings )} ' )
109125 else :
110126 pass
111-
112127 if settings is not None and isinstance (settings , dict ):
113128 invalid_setting_keys = set (settings .keys ()) - set (updated_settings .keys ())
114129 if len (invalid_setting_keys ) > 0 :
115130 raise ValueError ('setting "{0}" are invalid, valid settings are: {1}' .format (
116131 ',' .join (invalid_setting_keys ), ', ' .join (updated_settings .keys ())))
117-
118132 updated_settings .update (settings )
119-
120133 if extra_settings is not None and isinstance (extra_settings , dict ):
121134 updated_settings .update (extra_settings )
122135
@@ -139,31 +152,28 @@ def __init__(self, **db_settings):
139152 'JSONCompact', 'JSONEachRow', 'TSKV', 'Pretty', 'PrettyCompact',
140153 'PrettyCompactMonoBlock', 'PrettyNoEscapes', 'PrettySpace', 'XML')
141154
142-
143-
144155 :param db_settings:
145156 """
146- self ._check_db_settings (db_settings , available_db_type = [node .__name__ ])
147-
148- self ._db = db_settings ['database' ]
157+ self ._check_db_settings_ (db_settings , available_db_type = [node .__name__ ])
149158 self ._para = node (db_settings ['host' ], db_settings ['port' ], db_settings ['user' ],
150159 db_settings ['password' ], db_settings ['database' ]) # store connection information
160+ self ._db = self ._para .database
151161 self ._connect_url = 'http://{user}:{passwd}@{host}:{port}' .format (user = self ._para .user ,
152162 passwd = self ._para .password ,
153163 host = self ._para .host ,
154164 port = self ._para .port )
155165 self .http_settings = self ._merge_settings (None , updated_settings = self ._default_settings ,
156166 extra_settings = {'user' : self ._para .user ,
157167 'password' : self ._para .password })
158- # self._session = ClientSession() # the reason of unclose session client
159168 # self.max_async_query_once = 5
160169 # self.is_closed = False
161-
162- self ._test_connection_ ("http://{host}:{port}/?" .format (host = db_settings ['host' ], port = int (db_settings ['port' ])))
163- self .cache_query = partialmethod (self .query , enable_cache = True )
170+ _base_url = "http://{host}:{port}/?" .format (host = self ._para .host , port = int (self ._para .port ))
171+ self .__heartbeat_test__ (_base_url )
172+ self ._heartbeat_test_ = partialmethod (self .__heartbeat_test__ , _base_url = _base_url )
173+ self .cache_query = partialmethod (self .execute , enable_cache = True , exploit_func = True )
164174
@staticmethod
def _check_db_settings_(db_settings: dict, available_db_type=(node.__name__,)):  # node.__name__ : clickhouse
    """
    validate a db settings mapping before it is used to build a connection

    :param db_settings: mapping holding 'name' plus every key listed in node_parameters
    :param available_db_type: accepted values for db_settings['name'] (compared lower-cased)
    :return: None
    :raises ParameterTypeError: db_settings is not a dict
    :raises DatabaseTypeError: db_settings['name'] is not one of available_db_type
    :raises ParameterKeyError: one or more keys of node_parameters are missing
    """
    if not isinstance(db_settings, dict):
        raise ParameterTypeError(f'db_setting must be dict type! but get {type(db_settings)}')

    if db_settings['name'].lower() not in available_db_type:
        raise DatabaseTypeError(
            f'database symbol is not accepted, now only accept: {",".join(available_db_type)}')

    # build a real list: a lazy filter() would be drained by the emptiness test
    # and leave nothing to name in the error message below
    missing_keys = [key for key in node_parameters if key not in db_settings]
    if missing_keys:
        raise ParameterKeyError(f"the following keys are not at settings: {','.join(missing_keys)}")
188195
@staticmethod
def __heartbeat_test__(_base_url: str, timeout: float = 10):
    """
    check that the server answers a plain GET on its base url

    :param _base_url: url to probe
    :param timeout: seconds to wait for the server before requests gives up; without a
                    timeout an unresponsive host would block this call indefinitely
    :return: None
    :raises HeartbeatCheckFailure: the server answered with a status other than 200
    """
    response = requests.get(_base_url, timeout=timeout)
    if response.status_code != 200:
        raise HeartbeatCheckFailure(f'heartbeat check failure at {_base_url}')
    if PRINT_CHECK_RESULT:
        print('connection test: ', response.text.strip())
203212
204- async def _post (self , url : str , sql : str , session ):
213+ async def _post (self , url : str , sql : str , session , raise_error : bool = True ):
205214 """
206215 the async way to send post request to the server
207216 :param url:
@@ -218,14 +227,17 @@ async def _post(self, url: str, sql: str, session):
218227 async with session .post (url , body = sql .encode (), ) as resp :
219228 result = await resp .read ()
220229
221- status = resp .status
230+ result = SmartBytes ( result , resp .status )
222231 # reason = resp.reason
223- if status != 200 :
224- raise ValueError (result )
232+ if result .status_code != 200 :
233+ if raise_error and GLOBAL_RAISE_ERROR :
234+ raise DatabaseError (result )
235+ else :
236+ warnings .warn (str (result ))
225237 return result
226238
async def _compression_switched_request(self, query_with_format: (tuple, list, str), convert_to: str = 'dataframe',
                                        transfer_sql_format: bool = True, sem=None, raise_error=True):
    """
    post one statement, or several statements one after another, over a single client session

    :param query_with_format: a single SQL string, or a list/tuple of SQL strings
    :param convert_to: forwarded to _transfer_sql_format for every statement
    :param transfer_sql_format: forwarded to _transfer_sql_format for every statement
    :param sem: semaphore held while the requests run; a new one sized by SEMAPHORE is
                created when None
    :param raise_error: forwarded to _post
    :return: the _post result for a str input, or a list of _post results in input order
    :raises ValueError: query_with_format is neither str, list nor tuple
    """
    endpoint = self._connect_url + '/?' + parse.urlencode(self.http_settings)
    limiter = asyncio.Semaphore(SEMAPHORE) if sem is None else sem  # limit async num

    def prepare(statement):
        return self._transfer_sql_format(statement, convert_to=convert_to,
                                         transfer_sql_format=transfer_sql_format)

    async with limiter, ClientSession() as session:

        async def send(statement):
            return await self._post(endpoint, prepare(statement), session, raise_error=raise_error)

        if isinstance(query_with_format, str):
            return await send(query_with_format)
        if isinstance(query_with_format, (tuple, list)):
            # awaited one by one so the statements reach the server in the given order
            responses = []
            for statement in query_with_format:
                responses.append(await send(statement))
            return responses
        raise ValueError('query_with_format must be str , list or tuple')
260265
@classmethod
def _load_into_pd_ext(cls, sql: (str, list, tuple), ret_value: (bytes, list, tuple), convert_to: str,
                      to_df: bool = True):
    """
    parse raw response(s) with _load_into_pd, passing unusable responses through untouched

    A response is parsed only when it is non-empty and its status code is 200. A response
    without a ``status_code`` attribute (plain bytes) is treated as successful instead of
    failing on the attribute lookup.

    :param sql: the statement (str) or statements (list/tuple) that produced ret_value
    :param ret_value: one raw response when sql is a str, otherwise an iterable of raw responses
    :param convert_to: output format forwarded to _load_into_pd
    :param to_df: when False, ret_value is returned unchanged without any parsing
    :return: the parsed result, or a list of results, with unparsed responses passed through as-is
    :raises ValueError: parsing is requested and sql is neither str, list nor tuple
    """
    if not to_df:
        return ret_value

    def parse_one(response):
        # only an explicit non-200 status (or an empty body) blocks parsing
        if response != b'' and getattr(response, 'status_code', 200) == 200:
            return cls._load_into_pd(response, convert_to)
        return response

    if isinstance(sql, str):
        return parse_one(ret_value)
    if isinstance(sql, (list, tuple)):
        return [parse_one(response) for response in ret_value]
    raise ValueError(f'sql must be str or list or tuple,but get {type(sql)}')
284290
def __execute__(self, sql: (str, list, tuple), convert_to: str = 'dataframe', transfer_sql_format: bool = True,
                loop=None, to_df: bool = True, raise_error=True):
    """
    run a SQL statement, or a list of them, to completion and post-process the responses

    :param sql: a single SQL string, or a list/tuple of SQL strings
    :param convert_to: forwarded to the request coroutine and to _load_into_pd_ext
    :param transfer_sql_format: forwarded to the request coroutine
    :param loop: event loop that runs the request; asyncio.get_event_loop() is used when None
    :param to_df: forwarded to _load_into_pd_ext
    :param raise_error: forwarded to the request coroutine
    :return: the value produced by _load_into_pd_ext
    """
    limiter = asyncio.Semaphore(SEMAPHORE)  # limit async num
    pending = self._compression_switched_request(sql, convert_to=convert_to,
                                                 transfer_sql_format=transfer_sql_format, sem=limiter,
                                                 raise_error=raise_error)
    event_loop = asyncio.get_event_loop() if loop is None else loop
    responses = event_loop.run_until_complete(pending)
    return self._load_into_pd_ext(sql, responses, convert_to, to_df=to_df)
306312
def execute(self, *sql, convert_to: str = 'dataframe', loop=None, output_df: bool = True,
            enable_cache: bool = False, exploit_func: bool = True, raise_error: bool = True):
    """
    execute sql or multi sql through __execute__, wrapped by file_cache

    :param sql: SQL statements; they are handed to __execute__ together as one tuple
    :param convert_to: output format forwarded to __execute__
    :param loop: event loop forwarded to __execute__
    :param output_df: when falsy, the raw responses are returned instead of parsed results
    :param enable_cache: forwarded to file_cache
    :param exploit_func: forwarded to file_cache
    :param raise_error: forwarded to __execute__
    :return: whatever the wrapped __execute__ returns for the tuple of statements
    """
    # TODO change to smart mode, can receive any kind sql combination and handle them
    # NOTE(review): transfer_sql_format is always True here, also for insert-style
    # statements - confirm _transfer_sql_format leaves those untouched
    cached_execute = file_cache(enable_cache=enable_cache, exploit_func=exploit_func)(self.__execute__)
    # bool() keeps to_df a real boolean instead of the int that `True * output_df` produced
    return cached_execute(sql, convert_to=convert_to, transfer_sql_format=True, loop=loop,
                          to_df=bool(output_df), raise_error=raise_error)
336346
def query(self, *sql: str, loop=None, output_df: bool = True, raise_error=True):
    """
    run SQL statements through execute with the 'dataframe' format and caching switched off

    ## TODO require to upgrade
    :param sql: SQL statements, forwarded to execute unchanged
    :param loop: forwarded to execute
    :param output_df: forwarded to execute
    :param raise_error: forwarded to execute
    :return: the value returned by execute
    """
    fixed_options = {'convert_to': 'dataframe', 'enable_cache': False, 'exploit_func': False}
    return self.execute(*sql, loop=loop, output_df=output_df, raise_error=raise_error, **fixed_options)
357365
358366
359367class ClickHouseTableNode (ClickHouseBaseNode ):
0 commit comments