1010
1111import base64
1212import datetime
13+ import io
14+ import numpy as np
15+ import pyarrow as pa
16+ import pyarrow .json
1317import re
1418from decimal import Decimal
1519from ssl import CERT_NONE , CERT_OPTIONAL , CERT_REQUIRED , create_default_context
_logger = logging.getLogger(__name__)

# Matches "Y-M-D H:M:S" with an optional fractional part of up to six digits.
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
# Matches "<days> <hours>:<minutes>:<seconds>[.<fraction>]".
# Groups: 1=days, 2=hours, 3=minutes, 4=seconds including the optional fraction.
# The decimal point is escaped so that only a literal "." introduces the
# fraction; an unescaped "." would let any character act as the separator and
# hand text such as "04x5" to float().
_INTERVAL_DAY_TIME_PATTERN = re.compile(r'(\d+) (\d+):(\d+):(\d+(?:\.\d+)?)')
4449
4550ssl_cert_parameter_map = {
4651 "none" : CERT_NONE ,
@@ -67,9 +72,36 @@ def _parse_timestamp(value):
6772 value = None
6873 return value
6974
def _parse_date(value):
    """Turn a ``YYYY-MM-DD`` string into a ``datetime.date``.

    Any falsy input (``None``, empty string) is returned as ``None``.
    A non-empty string that does not follow the expected layout makes
    ``strptime`` raise ``ValueError``.
    """
    if not value:
        return None
    parsed = datetime.datetime.strptime(value, '%Y-%m-%d')
    return parsed.date()
7082
71- TYPES_CONVERTER = {"DECIMAL_TYPE" : Decimal ,
72- "TIMESTAMP_TYPE" : _parse_timestamp }
def _parse_interval_day_time(value):
    """Convert an INTERVAL_DAY_TIME string into a ``datetime.timedelta``.

    The text is matched against ``_INTERVAL_DAY_TIME_PATTERN``, i.e.
    ``<days> <hours>:<minutes>:<seconds>[.<fraction>]``. A single leading
    ``-`` is accepted and negates the whole interval.

    :param value: textual interval, or a falsy value (``None`` / ``''``)
    :returns: the matching ``datetime.timedelta``, or ``None`` for falsy input
    :raises ValueError: if the text does not start with a recognisable interval
    """
    if not value:
        return None

    # NOTE(review): Hive is expected to render negative intervals with one
    # leading "-" that applies to every component (see
    # HiveIntervalDayTime.toString) -- confirm against the server in use.
    # The pattern itself only knows unsigned digits, so the sign is peeled
    # off here and re-applied to the finished timedelta.
    negative = value.startswith('-')
    match = _INTERVAL_DAY_TIME_PATTERN.match(value[1:] if negative else value)
    if match is None:
        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        raise ValueError(
            'Cannot convert "{}" into an interval_day_time'.format(value))

    days, hours, minutes = (int(match.group(index)) for index in (1, 2, 3))
    # Seconds may carry a fractional part, hence float rather than int.
    seconds = float(match.group(4))
    delta = datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
    return -delta if negative else delta
98+
# Lookup table from a column type name to the callable that converts one raw
# value of that type. _unwrap_column consults it with
# TYPES_CONVERTER.get(type_); type names without an entry are left unconverted.
TYPES_CONVERTER = {
    # exact decimal values
    "DECIMAL_TYPE": Decimal,
    # parsed by _parse_timestamp
    "TIMESTAMP_TYPE": _parse_timestamp,
    # 'YYYY-MM-DD' text -> datetime.date
    "DATE_TYPE": _parse_date,
    # 'D H:M:S[.f]' text -> datetime.timedelta
    "INTERVAL_DAY_TIME_TYPE": _parse_interval_day_time,
}
73105
74106
75107class HiveParamEscaper (common .ParamEscaper ):
@@ -462,6 +494,48 @@ def cancel(self):
462494 response = self ._connection .client .CancelOperation (req )
463495 _check_status (response )
464496
def fetchone(self):
    """Fetch a single row.

    Delegates to :meth:`fetchmany` with a size of one and returns its result
    unchanged: a slice of the buffered data holding at most one row, or
    ``None`` when no data is buffered.
    """
    one_row = self.fetchmany(1)
    return one_row
499+
def fetchall(self):
    """Fetch every remaining row.

    Delegates to :meth:`fetchmany` with the sentinel size ``-1``, which that
    method treats as "fetch everything", and returns its result unchanged.
    """
    everything = self.fetchmany(-1)
    return everything
502+
def fetchmany(self, size=None):
    """Return up to ``size`` buffered rows, fetching more from the server as needed.

    :param size: number of rows wanted; ``None`` falls back to
        ``self.arraysize`` and ``-1`` means "everything that is left"
    :returns: the leading slice of ``self._data`` with the requested rows, or
        ``None`` when the buffer is empty after fetching
    :raises exc.ProgrammingError: if no query has been issued yet
    """
    wanted = self.arraysize if size is None else size

    if self._state == self._STATE_NONE:
        raise exc.ProgrammingError("No query yet")

    take_everything = wanted == -1
    if take_everything:
        # Keep pulling batches until the operation reports it is finished.
        self._fetch_while(lambda: self._state != self._STATE_FINISHED)
    else:
        # Stop as soon as enough rows are buffered or the source runs dry.
        self._fetch_while(
            lambda: self._state != self._STATE_FINISHED
            and (self._data is None or self._data.num_rows < wanted)
        )

    if not self._data:
        return None

    buffered = self._data.num_rows
    count = buffered if take_everything else min(wanted, buffered)

    self._rownumber += count
    rows = self._data[:count]
    # Drop the buffer entirely once it has been handed out in full;
    # otherwise keep only the rows that were not returned.
    self._data = None if count == buffered else self._data[count:]
    return rows
538+
465539 def _fetch_more (self ):
466540 """Send another TFetchResultsReq and update state"""
467541 assert (self ._state == self ._STATE_RUNNING ), "Should be running when in _fetch_more"
@@ -479,13 +553,19 @@ def _fetch_more(self):
479553 assert not response .results .rows , 'expected data in columnar format'
480554 columns = [_unwrap_column (col , col_schema [1 ]) for col , col_schema in
481555 zip (response .results .columns , schema )]
482- new_data = list ( zip ( * columns ))
483- self . _data += new_data
556+ names = [ col [ 0 ] for col in schema ]
557+ new_data = pa . Table . from_batches ([ pa . RecordBatch . from_arrays ( columns , names = names )])
484558 # response.hasMoreRows seems to always be False, so we instead check the number of rows
485559 # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
486560 # if not response.hasMoreRows:
487- if not new_data :
561+ if new_data . num_rows == 0 :
488562 self ._state = self ._STATE_FINISHED
563+ return
564+
565+ if self ._data is None :
566+ self ._data = new_data
567+ else :
568+ self ._data = pa .concat_tables ([self ._data , new_data ])
489569
490570 def poll (self , get_progress_update = True ):
491571 """Poll for and return the raw status data provided by the Hive Thrift REST API.
@@ -563,17 +643,42 @@ def _unwrap_column(col, type_=None):
563643 """Return a list of raw values from a TColumn instance."""
564644 for attr , wrapper in iteritems (col .__dict__ ):
565645 if wrapper is not None :
566- result = wrapper .values
567- nulls = wrapper .nulls # bit set describing what's null
568- assert isinstance (nulls , bytes )
569- for i , char in enumerate (nulls ):
570- byte = ord (char ) if sys .version_info [0 ] == 2 else char
571- for b in range (8 ):
572- if byte & (1 << b ):
573- result [i * 8 + b ] = None
574- converter = TYPES_CONVERTER .get (type_ , None )
575- if converter and type_ :
576- result = [converter (row ) if row else row for row in result ]
646+ if attr in ['boolVal' , 'byteVal' , 'i16Val' , 'i32Val' , 'i64Val' , 'doubleVal' ]:
647+ values = wrapper .values
648+ # unpack nulls as a byte array
649+ nulls = np .unpackbits (np .frombuffer (wrapper .nulls , dtype = 'uint8' )).view (bool )
650+ # override a full mask as trailing False values are not sent
651+ mask = np .zeros (values .shape , dtype = '?' )
652+ end = min (len (mask ), len (nulls ))
653+ mask [:end ] = nulls [:end ]
654+
655+ # float values are transferred as double
656+ if type_ == 'FLOAT_TYPE' :
657+ values = values .astype ('>f4' )
658+
659+ result = pa .array (values .byteswap ().newbyteorder (), mask = mask )
660+ else :
661+ result = wrapper .values
662+ nulls = wrapper .nulls # bit set describing what's null
663+ if len (result ) == 0 :
664+ return pa .array ([])
665+ assert isinstance (nulls , bytes )
666+ for i , char in enumerate (nulls ):
667+ byte = ord (char ) if sys .version_info [0 ] == 2 else char
668+ for b in range (8 ):
669+ if byte & (1 << b ):
670+ result [i * 8 + b ] = None
671+ converter = TYPES_CONVERTER .get (type_ , None )
672+ if converter and type_ :
673+ result = [converter (row ) if row else row for row in result ]
674+ if type_ in ['ARRAY_TYPE' , 'MAP_TYPE' , 'STRUCT_TYPE' ]:
675+ fd = io .BytesIO ()
676+ for row in result :
677+ if row is None :
678+ row = 'null'
679+ fd .write (f'{{"c":{ row } }}\n ' .encode ('utf8' ))
680+ fd .seek (0 )
681+ result = pa .json .read_json (fd )[0 ].combine_chunks ()
577682 return result
578683 raise DataError ("Got empty column value {}" .format (col )) # pragma: no cover
579684
0 commit comments