Skip to content

Commit 1ae429e

Browse files
committed
Add handling for spark DecimalType
This change addresses issue #1602 (#1602). Computations in the summarize process result in some floats when computing against decimal columns. To address this, we simply convert those types to a DoubleType when performing those numeric operations.
1 parent 0450215 commit 1ae429e

3 files changed

Lines changed: 43 additions & 1 deletion

File tree

src/ydata_profiling/model/spark/describe_counts_spark.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pandas as pd
77
from pyspark.sql import DataFrame
8-
from pyspark.sql import functions as F
8+
from pyspark.sql import functions as F, types as T
99

1010
from ydata_profiling.config import Settings
1111
from ydata_profiling.model.summary_algorithms import describe_counts
@@ -25,6 +25,9 @@ def describe_counts_spark(
2525
Returns:
2626
Updated settings, input series, and summary dictionary.
2727
"""
28+
# Cast Decimal types to DoubleType before counting
29+
if isinstance(series.schema.fields[0].dataType, T.DecimalType):
30+
series = series.select(F.col(series.columns[0]).cast(T.DoubleType()).alias(series.columns[0]))
2831

2932
# Count occurrences of each value
3033
value_counts = series.groupBy(series.columns[0]).count()

src/ydata_profiling/model/spark/summary_spark.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def spark_describe_1d(
4343

4444
if str(series.schema[0].dataType).startswith("ArrayType"):
4545
dtype = "ArrayType"
46+
elif str(series.schema[0].dataType).startswith("Decimal"):
47+
dtype = "decimal"
4648
else:
4749
dtype = series.schema[0].dataType.simpleString()
4850

@@ -56,6 +58,7 @@ def spark_describe_1d(
5658
"boolean": "Boolean",
5759
"date": "DateTime",
5860
"timestamp": "DateTime",
61+
"decimal": "Numeric",
5962
}[dtype]
6063

6164
return summarizer.summarize(config, series, dtype=vtype)

tests/issues/test_issue1602.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""
2+
Test for issue 1602:
3+
https://github.com/ydataai/ydata-profiling/issues/1602
4+
"""
5+
6+
from ydata_profiling import ProfileReport
7+
from pyspark.sql import types as T
8+
9+
def test_spark_handles_decimal_type(test_output_dir, spark_session):
10+
from decimal import Decimal
11+
spark = spark_session
12+
13+
schema = T.StructType(
14+
[
15+
T.StructField("number", T.StringType(), True),
16+
T.StructField("decimal", T.DecimalType(10, 2), True)
17+
]
18+
)
19+
20+
data = [
21+
(f"test_{num + 1}", Decimal(num + 1)) for num in range(205)
22+
]
23+
24+
data.extend(
25+
[
26+
("test_1", Decimal("1.05")) for _ in range(205)
27+
]
28+
)
29+
30+
test_df = spark.createDataFrame(data, schema=schema)
31+
32+
profile = ProfileReport(test_df, title="decimal_handling", explorative=True)
33+
output_file = test_output_dir / "decimal_handling.html"
34+
profile.to_file(output_file)
35+
36+
assert output_file.exists()

0 commit comments

Comments
 (0)