Skip to content

Commit 4754b49

Browse files
committed
Produce rows as slices of pyarrow.Table
1 parent d199a1b commit 4754b49

File tree

3 files changed

+129
-35
lines changed

3 files changed

+129
-35
lines changed

TCLIService/ttypes.py

Lines changed: 7 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyhive/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _reset_state(self):
3838

3939
# Internal helper state
4040
self._state = self._STATE_NONE
41-
self._data = collections.deque()
41+
self._data = None
4242
self._columns = None
4343

4444
def _fetch_while(self, fn):

pyhive/hive.py

Lines changed: 121 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
import base64
1212
import datetime
13+
import io
14+
import numpy as np
15+
import pyarrow as pa
16+
import pyarrow.json
1317
import re
1418
from decimal import Decimal
1519
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
@@ -41,6 +45,7 @@
4145
_logger = logging.getLogger(__name__)
4246

4347
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
48+
_INTERVAL_DAY_TIME_PATTERN = re.compile(r'(\d+) (\d+):(\d+):(\d+(?:.\d+)?)')
4449

4550
ssl_cert_parameter_map = {
4651
"none": CERT_NONE,
@@ -67,9 +72,36 @@ def _parse_timestamp(value):
6772
value = None
6873
return value
6974

75+
def _parse_date(value):
76+
if value:
77+
format = '%Y-%m-%d'
78+
value = datetime.datetime.strptime(value, format).date()
79+
else:
80+
value = None
81+
return value
7082

71-
TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
72-
"TIMESTAMP_TYPE": _parse_timestamp}
83+
def _parse_interval_day_time(value):
84+
if value:
85+
match = _INTERVAL_DAY_TIME_PATTERN.match(value)
86+
if match:
87+
days = int(match.group(1))
88+
hours = int(match.group(2))
89+
minutes = int(match.group(3))
90+
seconds = float(match.group(4))
91+
value = datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
92+
else:
93+
raise Exception(
94+
'Cannot convert "{}" into an interval_day_time'.format(value))
95+
else:
96+
value = None
97+
return value
98+
99+
TYPES_CONVERTER = {
100+
"DECIMAL_TYPE": Decimal,
101+
"TIMESTAMP_TYPE": _parse_timestamp,
102+
"DATE_TYPE": _parse_date,
103+
"INTERVAL_DAY_TIME_TYPE": _parse_interval_day_time,
104+
}
73105

74106

75107
class HiveParamEscaper(common.ParamEscaper):
@@ -462,6 +494,48 @@ def cancel(self):
462494
response = self._connection.client.CancelOperation(req)
463495
_check_status(response)
464496

497+
def fetchone(self):
498+
return self.fetchmany(1)
499+
500+
def fetchall(self):
501+
return self.fetchmany(-1)
502+
503+
def fetchmany(self, size=None):
504+
if size is None:
505+
size = self.arraysize
506+
507+
if self._state == self._STATE_NONE:
508+
raise exc.ProgrammingError("No query yet")
509+
510+
if size == -1:
511+
# Fetch everything
512+
self._fetch_while(lambda: self._state != self._STATE_FINISHED)
513+
else:
514+
self._fetch_while(lambda:
515+
(self._state != self._STATE_FINISHED) and
516+
(self._data is None or self._data.num_rows < size)
517+
)
518+
519+
if not self._data:
520+
return None
521+
522+
if size == -1:
523+
# Fetch everything
524+
size = self._data.num_rows
525+
else:
526+
size = min(size, self._data.num_rows)
527+
528+
self._rownumber += size
529+
rows = self._data[:size]
530+
531+
if size == self._data.num_rows:
532+
# Fetch everything
533+
self._data = None
534+
else:
535+
self._data = self._data[size:]
536+
537+
return rows
538+
465539
def _fetch_more(self):
466540
"""Send another TFetchResultsReq and update state"""
467541
assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
@@ -479,13 +553,19 @@ def _fetch_more(self):
479553
assert not response.results.rows, 'expected data in columnar format'
480554
columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
481555
zip(response.results.columns, schema)]
482-
new_data = list(zip(*columns))
483-
self._data += new_data
556+
names = [col[0] for col in schema]
557+
new_data = pa.Table.from_batches([pa.RecordBatch.from_arrays(columns, names=names)])
484558
# response.hasMoreRows seems to always be False, so we instead check the number of rows
485559
# https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
486560
# if not response.hasMoreRows:
487-
if not new_data:
561+
if new_data.num_rows == 0:
488562
self._state = self._STATE_FINISHED
563+
return
564+
565+
if self._data is None:
566+
self._data = new_data
567+
else:
568+
self._data = pa.concat_tables([self._data, new_data])
489569

490570
def poll(self, get_progress_update=True):
491571
"""Poll for and return the raw status data provided by the Hive Thrift REST API.
@@ -563,17 +643,42 @@ def _unwrap_column(col, type_=None):
563643
"""Return a list of raw values from a TColumn instance."""
564644
for attr, wrapper in iteritems(col.__dict__):
565645
if wrapper is not None:
566-
result = wrapper.values
567-
nulls = wrapper.nulls # bit set describing what's null
568-
assert isinstance(nulls, bytes)
569-
for i, char in enumerate(nulls):
570-
byte = ord(char) if sys.version_info[0] == 2 else char
571-
for b in range(8):
572-
if byte & (1 << b):
573-
result[i * 8 + b] = None
574-
converter = TYPES_CONVERTER.get(type_, None)
575-
if converter and type_:
576-
result = [converter(row) if row else row for row in result]
646+
if attr in ['boolVal', 'byteVal', 'i16Val', 'i32Val', 'i64Val', 'doubleVal']:
647+
values = wrapper.values
648+
# unpack nulls as a byte array
649+
nulls = np.unpackbits(np.frombuffer(wrapper.nulls, dtype='uint8')).view(bool)
650+
# override a full mask as trailing False values are not sent
651+
mask = np.zeros(values.shape, dtype='?')
652+
end = min(len(mask), len(nulls))
653+
mask[:end] = nulls[:end]
654+
655+
# float values are transferred as double
656+
if type_ == 'FLOAT_TYPE':
657+
values = values.astype('>f4')
658+
659+
result = pa.array(values.byteswap().newbyteorder(), mask=mask)
660+
else:
661+
result = wrapper.values
662+
nulls = wrapper.nulls # bit set describing what's null
663+
if len(result) == 0:
664+
return pa.array([])
665+
assert isinstance(nulls, bytes)
666+
for i, char in enumerate(nulls):
667+
byte = ord(char) if sys.version_info[0] == 2 else char
668+
for b in range(8):
669+
if byte & (1 << b):
670+
result[i * 8 + b] = None
671+
converter = TYPES_CONVERTER.get(type_, None)
672+
if converter and type_:
673+
result = [converter(row) if row else row for row in result]
674+
if type_ in ['ARRAY_TYPE', 'MAP_TYPE', 'STRUCT_TYPE']:
675+
fd = io.BytesIO()
676+
for row in result:
677+
if row is None:
678+
row = 'null'
679+
fd.write(f'{{"c":{row}}}\n'.encode('utf8'))
680+
fd.seek(0)
681+
result = pa.json.read_json(fd)[0].combine_chunks()
577682
return result
578683
raise DataError("Got empty column value {}".format(col)) # pragma: no cover
579684

0 commit comments

Comments
 (0)