Skip to content

Commit c3d4d55

Browse files
committed
add result obj with status_code and add smart mode to query some
different type requests and add raise_error parameters
1 parent 363c39b commit c3d4d55

1 file changed

Lines changed: 92 additions & 84 deletions

File tree

ClickSQL/clickhouse/ClickHouse.py

Lines changed: 92 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import namedtuple
66
from functools import partial, partialmethod
77
from urllib import parse
8-
8+
import warnings
99
import nest_asyncio
1010
import pandas as pd
1111
import requests
@@ -23,30 +23,47 @@
2323

2424
nest_asyncio.apply() # allow run at jupyter and asyncio env
2525

26-
node_parameters = ['host', 'port', 'user', 'password', 'database']
26+
node_parameters = ('host', 'port', 'user', 'password', 'database')
2727
node = namedtuple('clickhouse', node_parameters)
2828
available_queries_select = ('select', 'show', 'desc')
2929
available_queries_insert = ('insert', 'optimize', 'create')
30-
PRINT_TEST_RESULT = True
30+
PRINT_CHECK_RESULT = True
31+
GLOBAL_RAISE_ERROR = True
3132
SEMAPHORE = 10 # control async number for whole query list
3233

3334

34-
class ParameterKeyError(Exception):
35-
pass
35+
class ParameterKeyError(Exception): pass
3636

3737

38-
class ParameterTypeError(Exception):
39-
pass
38+
class ParameterTypeError(Exception): pass
4039

4140

42-
class DatabaseTypeError(Exception):
43-
pass
41+
class DatabaseTypeError(Exception): pass
42+
43+
44+
class DatabaseError(Exception): pass
45+
46+
47+
class HeartbeatCheckFailure(Exception): pass
48+
49+
50+
# class SmartResult(object):
51+
#
52+
# def __init__(self, result, status_code: int):
53+
# self._obj = result
54+
# self.status_code = status_code
55+
56+
57+
def SmartBytes(result: bytes, status_code: int):
58+
result_cls = type('SmartBytes', (bytes,), {'status_code': property(lambda x: status_code)})
59+
60+
return result_cls(result)
4461

4562

4663
# TODO change to queue mode change remove aiohttp depends
4764
class ClickHouseTools(object):
4865
@staticmethod
49-
def _transfer_sql_format(sql, convert_to, transfer_sql_format=True):
66+
def _transfer_sql_format(sql: str, convert_to: str, transfer_sql_format: bool = True):
5067
"""
5168
provide a method which will translate a standard sql into clickhouse sql with might use format as suffix
5269
:param sql:
@@ -65,7 +82,7 @@ def _transfer_sql_format(sql, convert_to, transfer_sql_format=True):
6582
return sql
6683

6784
@staticmethod
68-
def _load_into_pd(ret_value, convert_to: str = 'dataframe', errors='ignore'):
85+
def _load_into_pd(ret_value: (str, bytes), convert_to: str = 'dataframe', errors='ignore'):
6986
"""
7087
will provide a approach to load data from clickhouse into pd.DataFrame format which may be easy to use
7188
@@ -76,7 +93,6 @@ def _load_into_pd(ret_value, convert_to: str = 'dataframe', errors='ignore'):
7693
"""
7794

7895
if convert_to.lower() == 'dataframe':
79-
8096
result_dict = json.loads(ret_value, strict=False)
8197
meta = result_dict['meta']
8298
name = map(lambda x: x['name'], meta)
@@ -108,15 +124,12 @@ def _merge_settings(cls, settings: (None, dict), updated_settings: (None, dict)
108124
raise ParameterTypeError(f'updated_settings must be dict type, but get {type(updated_settings)}')
109125
else:
110126
pass
111-
112127
if settings is not None and isinstance(settings, dict):
113128
invalid_setting_keys = set(settings.keys()) - set(updated_settings.keys())
114129
if len(invalid_setting_keys) > 0:
115130
raise ValueError('setting "{0}" are invalid, valid settings are: {1}'.format(
116131
','.join(invalid_setting_keys), ', '.join(updated_settings.keys())))
117-
118132
updated_settings.update(settings)
119-
120133
if extra_settings is not None and isinstance(extra_settings, dict):
121134
updated_settings.update(extra_settings)
122135

@@ -139,31 +152,28 @@ def __init__(self, **db_settings):
139152
'JSONCompact', 'JSONEachRow', 'TSKV', 'Pretty', 'PrettyCompact',
140153
'PrettyCompactMonoBlock', 'PrettyNoEscapes', 'PrettySpace', 'XML')
141154
142-
143-
144155
:param db_settings:
145156
"""
146-
self._check_db_settings(db_settings, available_db_type=[node.__name__])
147-
148-
self._db = db_settings['database']
157+
self._check_db_settings_(db_settings, available_db_type=[node.__name__])
149158
self._para = node(db_settings['host'], db_settings['port'], db_settings['user'],
150159
db_settings['password'], db_settings['database']) # store connection information
160+
self._db = self._para.database
151161
self._connect_url = 'http://{user}:{passwd}@{host}:{port}'.format(user=self._para.user,
152162
passwd=self._para.password,
153163
host=self._para.host,
154164
port=self._para.port)
155165
self.http_settings = self._merge_settings(None, updated_settings=self._default_settings,
156166
extra_settings={'user': self._para.user,
157167
'password': self._para.password})
158-
# self._session = ClientSession() # the reason of unclose session client
159168
# self.max_async_query_once = 5
160169
# self.is_closed = False
161-
162-
self._test_connection_("http://{host}:{port}/?".format(host=db_settings['host'], port=int(db_settings['port'])))
163-
self.cache_query = partialmethod(self.query, enable_cache=True)
170+
_base_url = "http://{host}:{port}/?".format(host=self._para.host, port=int(self._para.port))
171+
self.__heartbeat_test__(_base_url)
172+
self._heartbeat_test_ = partialmethod(self.__heartbeat_test__, _base_url=_base_url)
173+
self.cache_query = partialmethod(self.execute, enable_cache=True, exploit_func=True)
164174

165175
@staticmethod
166-
def _check_db_settings(db_settings: dict, available_db_type=(node.__name__,)): # node.__name__ : clickhouse
176+
def _check_db_settings_(db_settings: dict, available_db_type=(node.__name__,)): # node.__name__ : clickhouse
167177
"""
168178
it is to check db setting whether is correct!
169179
:param db_settings:
@@ -173,22 +183,18 @@ def _check_db_settings(db_settings: dict, available_db_type=(node.__name__,)):
173183
if isinstance(db_settings, dict):
174184
if db_settings['name'].lower() not in available_db_type:
175185
raise DatabaseTypeError(
176-
f'database symbol is not accept, now only accept: {",".join(available_db_type)}')
177-
missing_keys = [key for key in node_parameters if key not in db_settings.keys()]
178-
# :
179-
# missing_keys.append(key)
180-
# else:
181-
# pass
182-
if len(missing_keys) == 0:
186+
f'database symbol is not accepted, now only accept: {",".join(available_db_type)}')
187+
188+
missing_keys = filter(lambda x: x not in db_settings.keys(), node_parameters) # can improve
189+
if len(tuple(missing_keys)) == 0:
183190
pass
184191
else:
185192
raise ParameterKeyError(f"the following keys are not at settings: {','.join(missing_keys)}")
186193
else:
187194
raise ParameterTypeError(f'db_setting must be dict type! but get {type(db_settings)}')
188195

189196
@staticmethod
190-
def _test_connection_(_base_url):
191-
197+
def __heartbeat_test__(_base_url: str):
192198
"""
193199
a function to test connection by normal way!
194200
@@ -197,11 +203,14 @@ def _test_connection_(_base_url):
197203
"""
198204

199205
ret_value = requests.get(_base_url)
200-
if PRINT_TEST_RESULT:
206+
status_code = ret_value.status_code
207+
if status_code != 200:
208+
raise HeartbeatCheckFailure(f'heartbeat check failure at {_base_url}')
209+
if PRINT_CHECK_RESULT:
201210
print('connection test: ', ret_value.text.strip())
202211
del ret_value
203212

204-
async def _post(self, url: str, sql: str, session):
213+
async def _post(self, url: str, sql: str, session, raise_error: bool = True):
205214
"""
206215
the async way to send post request to the server
207216
:param url:
@@ -218,14 +227,17 @@ async def _post(self, url: str, sql: str, session):
218227
async with session.post(url, body=sql.encode(), ) as resp:
219228
result = await resp.read()
220229

221-
status = resp.status
230+
result = SmartBytes(result, resp.status)
222231
# reason = resp.reason
223-
if status != 200:
224-
raise ValueError(result)
232+
if result.status_code != 200:
233+
if raise_error and GLOBAL_RAISE_ERROR:
234+
raise DatabaseError(result)
235+
else:
236+
warnings.warn(str(result))
225237
return result
226238

227239
async def _compression_switched_request(self, query_with_format: (tuple, list, str), convert_to: str = 'dataframe',
228-
transfer_sql_format: bool = True, sem=None):
240+
transfer_sql_format: bool = True, sem=None, raise_error=True):
229241
"""
230242
the core request operator with compression switch adaptor
231243
@@ -236,30 +248,24 @@ async def _compression_switched_request(self, query_with_format: (tuple, list, s
236248
:return:
237249
"""
238250
url = self._connect_url + '/?' + parse.urlencode(self.http_settings)
239-
240251
transfer_sql = partial(self._transfer_sql_format, convert_to=convert_to,
241252
transfer_sql_format=transfer_sql_format)
242253
if sem is None:
243254
sem = asyncio.Semaphore(SEMAPHORE) # limit async num
244255
async with sem: # limit async number
245256
async with ClientSession() as session:
246257
if isinstance(query_with_format, str):
247-
# sql2 =
248-
result = await self._post(url, transfer_sql(query_with_format), session)
258+
result = await self._post(url, transfer_sql(query_with_format), session, raise_error=raise_error)
249259
elif isinstance(query_with_format, (tuple, list)):
250-
result = [await self._post(url, transfer_sql(sql), session) for sql in query_with_format]
251-
252-
# # sql2 = self._transfer_sql_format(sql, convert_to=convert_to,
253-
# # transfer_sql_format=transfer_sql_format)
254-
# res =
255-
# result.append(res)
260+
result = [await self._post(url, transfer_sql(sql), session, raise_error=raise_error) for sql in
261+
query_with_format]
256262
else:
257263
raise ValueError('query_with_format must be str , list or tuple')
258-
259264
return result
260265

261266
@classmethod
262-
def _load_into_pd_ext(cls, sql: (str, list, tuple), ret_value, convert_to: str, to_df: bool):
267+
def _load_into_pd_ext(cls, sql: (str, list, tuple), ret_value: (bytes, list, tuple), convert_to: str,
268+
to_df: bool = True):
263269
"""
264270
a way to parse into dataframe
265271
:param sql:
@@ -268,22 +274,22 @@ def _load_into_pd_ext(cls, sql: (str, list, tuple), ret_value, convert_to: str,
268274
:param to_df:
269275
:return:
270276
"""
271-
if isinstance(sql, str):
272-
if to_df or ret_value != b'':
277+
if not to_df:
278+
result = ret_value
279+
elif isinstance(sql, str):
280+
if ret_value != b'' and ret_value.status_code == 200:
273281
result = cls._load_into_pd(ret_value, convert_to)
274282
else:
275283
result = ret_value
276284
elif isinstance(sql, (list, tuple)):
277-
if to_df:
278-
result = [cls._load_into_pd(s, convert_to) if ret_value != b'' else None for s in ret_value]
279-
else:
280-
result = ret_value
285+
result = [cls._load_into_pd(s, convert_to) if s != b'' and s.status_code == 200 else s for s in
286+
ret_value]
281287
else:
282288
raise ValueError(f'sql must be str or list or tuple,but get {type(sql)}')
283289
return result
284290

285291
def __execute__(self, sql: (str, list, tuple), convert_to: str = 'dataframe', transfer_sql_format: bool = True,
286-
loop=None, to_df=True):
292+
loop=None, to_df: bool = True, raise_error=True):
287293
"""
288294
the core execute function to run the whole requests and SQL or a list of SQL.
289295
:param sql:
@@ -296,18 +302,22 @@ def __execute__(self, sql: (str, list, tuple), convert_to: str = 'dataframe', tr
296302

297303
sem = asyncio.Semaphore(SEMAPHORE) # limit async num
298304
resp_list = self._compression_switched_request(sql, convert_to=convert_to,
299-
transfer_sql_format=transfer_sql_format, sem=sem)
305+
transfer_sql_format=transfer_sql_format, sem=sem,
306+
raise_error=raise_error)
300307
if loop is None:
301308
loop = asyncio.get_event_loop() # init loop
302309
res = loop.run_until_complete(resp_list)
303-
result = self._load_into_pd_ext(sql, res, convert_to, to_df)
304-
310+
result = self._load_into_pd_ext(sql, res, convert_to, to_df=to_df)
305311
return result
306312

307-
def execute(self, *sql, convert_to: str = 'dataframe', loop=None, output_df=True, ):
313+
def execute(self, *sql, convert_to: str = 'dataframe', loop=None, output_df: bool = True,
314+
enable_cache: bool = False, exploit_func: bool = True, raise_error: bool = True):
308315
"""
309316
execute sql or multi sql
310317
318+
:param raise_error:
319+
:param exploit_func:
320+
:param enable_cache:
311321
:param output_df:
312322
:param sql:
313323
:param convert_to:
@@ -316,30 +326,31 @@ def execute(self, *sql, convert_to: str = 'dataframe', loop=None, output_df=True
316326
"""
317327
# TODO change to smart mode, can receive any kind sql combination and handle them
318328
# detect whether all query are insert process
319-
insert_process = list(map(lambda x: x.lower().startswith(available_queries_insert), sql))
320-
# detect whether all query are select process
321-
select_process = list(map(lambda x: x.lower().startswith(available_queries_select), sql))
322-
if all(insert_process) is True:
323-
to_df = transfer_sql_format = False
324-
elif all(select_process) is True:
325-
to_df = transfer_sql_format = True
326-
else:
327-
# TODO change to smart mode, can receive any kind sql combination and handle them
328-
raise ValueError(
329-
'the list of queries must be same type query! currently cannot handle various kind SQL type'
330-
'combination')
331-
332-
result = self.__execute__(sql, convert_to=convert_to, transfer_sql_format=transfer_sql_format, loop=loop,
333-
to_df=to_df * output_df)
329+
# insert_process = map(lambda x: x.lower().startswith(available_queries_insert), sql)
330+
# # detect whether all query are select process
331+
# select_process = map(lambda x: x.lower().startswith(available_queries_select), sql)
332+
# if all(list(select_process)) is True:
333+
# to_df = transfer_sql_format = True
334+
# elif all(list(insert_process)) is True:
335+
# to_df = transfer_sql_format = False
336+
# else:
337+
# # TODO change to smart mode, can receive any kind sql combination and handle them
338+
# raise ValueError(
339+
# 'the list of queries must be same type query! currently cannot handle various kind SQL type'
340+
# 'combination')
341+
func = file_cache(enable_cache=enable_cache, exploit_func=exploit_func)(self.__execute__)
342+
result = func(sql, convert_to=convert_to, transfer_sql_format=True, loop=loop,
343+
to_df=True * output_df, raise_error=raise_error)
334344

335345
return result
336346

337-
def query(self, *sql: str, loop=None, output_df=True, enable_cache=False, exploit_func=True):
347+
def query(self, *sql: str, loop=None, output_df: bool = True, raise_error=True):
338348

339349
"""
340350
add enable_cache and exploit_func
341351
342352
## TODO require to upgrade
353+
:param raise_error:
343354
:param exploit_func:
344355
:param enable_cache:
345356
:param output_df:
@@ -348,12 +359,9 @@ def query(self, *sql: str, loop=None, output_df=True, enable_cache=False, exploi
348359
:return:
349360
"""
350361

351-
func = file_cache(enable_cache=enable_cache, exploit_func=exploit_func)(self.execute)
352-
result = func(*sql, convert_to='dataframe', loop=loop, output_df=output_df)
353-
if len(sql) == 1:
354-
return result[0]
355-
else:
356-
return result
362+
result = self.execute(*sql, convert_to='dataframe', loop=loop, output_df=output_df, enable_cache=False,
363+
exploit_func=False, raise_error=raise_error)
364+
return result
357365

358366

359367
class ClickHouseTableNode(ClickHouseBaseNode):

0 commit comments

Comments
 (0)