Skip to content

Commit b860ef2

Browse files
committed
feat: support for decimal type in pandas and spark
1 parent 0ee95a4 commit b860ef2

5 files changed

Lines changed: 29 additions & 18 deletions

File tree

histogrammar/dfinterface/filling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def check_column(col, sep=":"):
3838
return col
3939

4040

41-
def check_dtype(dtype):
41+
def normalize_dtype(dtype):
4242
"""Convert datatype to consistent numpy datatype
4343
4444
:param dtype: input datatype

histogrammar/dfinterface/histogram_filler_base.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ..primitives.stack import Stack
2828
from ..primitives.sum import Sum
2929

30-
from .filling_utils import check_column, check_dtype
30+
from .filling_utils import check_column, normalize_dtype
3131

3232

3333
class HistogramFillerBase(object):
@@ -111,7 +111,7 @@ def __init__(
111111
self.bin_specs = bin_specs or {}
112112
self.time_axis = time_axis
113113
var_dtype = var_dtype or {}
114-
self.var_dtype = {k: check_dtype(v) for k, v in var_dtype.items()}
114+
self.var_dtype = {k: normalize_dtype(v) for k, v in var_dtype.items()}
115115
self.read_key = read_key
116116
self.store_key = store_key
117117

@@ -404,32 +404,31 @@ def categorize_features(self, df):
404404

405405
for col_list in features:
406406
for col in col_list:
407+
# raw column data type, which may carry metadata (e.g. a decimal flag)
408+
dt_col = self.get_data_type(df, col)
407409

408-
dt = self.var_dtype.get(col, check_dtype(self.get_data_type(df, col)))
410+
# normalized data type
411+
dt = self.var_dtype.get(col, normalize_dtype(dt_col))
409412

410413
if col not in self.var_dtype:
411414
self.var_dtype[col] = dt
412415

416+
# metadata indicates decimal
417+
if hasattr(dt_col, 'metadata') and dt_col.metadata is not None and dt_col.metadata["decimal"]:
418+
cols_by_type["decimal"].add(col)
419+
413420
if np.issubdtype(dt, np.integer):
414-
colset = cols_by_type["int"]
415-
if col not in colset:
416-
colset.add(col)
421+
cols_by_type["int"].add(col)
422+
417423
if np.issubdtype(dt, np.number):
418424
colset = cols_by_type["num"]
419-
if col not in colset:
420-
colset.add(col)
421425
elif np.issubdtype(dt, np.datetime64):
422426
colset = cols_by_type["dt"]
423-
if col not in colset:
424-
colset.add(col)
425427
elif np.issubdtype(dt, np.bool_):
426428
colset = cols_by_type["bool"]
427-
if col not in colset:
428-
colset.add(col)
429429
else:
430430
colset = cols_by_type["str"]
431-
if col not in colset:
432-
colset.add(col)
431+
colset.add(col)
433432

434433
self.logger.debug(
435434
'Data type of column "{col}" is "{type}".'.format(

histogrammar/dfinterface/make_histograms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
from .pandas_histogrammar import PandasHistogrammar
4444
from .spark_histogrammar import SparkHistogrammar
45-
from .filling_utils import check_dtype
45+
from .filling_utils import normalize_dtype
4646
from ..util import _get_sub_hist
4747

4848
logger = logging.getLogger()
@@ -232,7 +232,7 @@ def get_time_axes(df):
232232
return [
233233
c
234234
for c in df.columns
235-
if np.issubdtype(check_dtype(get_data_type(df, c)), np.datetime64)
235+
if np.issubdtype(normalize_dtype(get_data_type(df, c)), np.datetime64)
236236
]
237237

238238

histogrammar/dfinterface/pandas_histogrammar.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ def get_data_type(self, df, col):
136136
elif inferred == 'boolean':
137137
data_type = 'bool'
138138
elif inferred in {'decimal', 'floating', 'mixed-integer-float'}:
139-
data_type = 'float'
139+
# decimal needs preprocessing (a cast to float); signal this via the dtype metadata
140+
if inferred == "decimal":
141+
data_type = np.dtype('float', metadata={"decimal": True})
142+
else:
143+
data_type = "float"
140144
elif inferred in {'date', 'datetime', 'datetime64'}:
141145
data_type = 'datetime64'
142146
else: # categorical, mixed, etc -> object uses to_string()
@@ -187,6 +191,12 @@ def process_features(self, df, cols_by_type):
187191
)
188192
)
189193
idf[col] = df[col].apply(to_ns)
194+
195+
# cast decimal columns to float, because decimal is not supported by .quantile
196+
# (https://github.com/pandas-dev/pandas/issues/13157)
197+
for col in cols_by_type["decimal"]:
198+
idf[col] = df[col].apply(float)
199+
190200
return idf
191201

192202
def fill_histograms(self, idf):

histogrammar/dfinterface/spark_histogrammar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def get_data_type(self, df, col):
169169
dt = bool
170170
elif dt == "bigint":
171171
dt = np.int64
172+
elif dt.startswith("decimal("):
173+
return np.dtype(float, metadata={"decimal": True})
172174

173175
return np.dtype(dt)
174176

0 commit comments

Comments
 (0)