
Commit 588be6b

fix: add handling for spark correlations with no numeric fields
Assembling a vector column in Spark from no numeric columns yields features with a NULL size, NULL indices, and an empty list of values, which raises an exception when correlations are computed. The fix skips computing the correlation matrix when there are no interval (numeric) columns. This change addresses issue #1722.
1 parent abde1f2 · commit 588be6b
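To make the failure mode concrete, below is a minimal repro sketch reconstructed from the description above; it is not part of the commit. It assumes pyspark with a local session and mirrors what the profiler's native correlation path does: select the interval (numeric) columns, here none, and assemble them into a vector column for pyspark.ml's Correlation. Depending on the Spark version, VectorAssembler may reject an empty input list outright; where it does not, the correlation call is what raises.

from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation

spark = SparkSession.builder.master("local[1]").getOrCreate()

# A DataFrame with no numeric columns, similar to the test fixture below.
df = spark.createDataFrame([("r1",), ("r2",)], ["id"])

interval_columns = []  # no column is typed "Numeric"
assembler = VectorAssembler(inputCols=interval_columns, outputCol="features")
vector_df = assembler.transform(df.select(*interval_columns))

# Each assembled vector is degenerate (NULL size, NULL indices, no values),
# so forcing the correlation computation raises instead of returning a matrix:
Correlation.corr(vector_df, "features", "pearson").head()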

2 files changed: 76 additions & 0 deletions


src/ydata_profiling/model/spark/correlations_spark.py

Lines changed: 4 additions & 0 deletions
@@ -49,6 +49,10 @@ def _compute_corr_natively(df: DataFrame, summary: dict, corr_type: str) -> Arra
     interval_columns = [
         column for column, type_name in variables.items() if type_name == "Numeric"
     ]
+
+    if not interval_columns:
+        return [], interval_columns
+
     df = df.select(*interval_columns)
 
     # convert to vector column first
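The early return hands back an empty matrix together with the (empty) list of interval columns, which presumably matches the shape of the function's normal result, so callers can unpack it without special-casing DataFrames that have nothing to correlate.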
New test file · Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
"""
Test for issue 1722:
https://github.com/ydataai/ydata-profiling/issues/1722
"""

from ydata_profiling import ProfileReport
from datetime import date, datetime
from pyspark.sql import types as T, SparkSession

def make_non_numeric_df(spark: SparkSession):
    # Intentionally not including any numeric types
    schema = T.StructType(
        [
            T.StructField("id", T.StringType(), False),
            T.StructField("d", T.DateType(), False),
            T.StructField("ts", T.TimestampType(), False),
            T.StructField("arr", T.ArrayType(T.IntegerType()), False),
            T.StructField("mp", T.MapType(T.StringType(), T.IntegerType()), False),
            T.StructField(
                "struct",
                T.StructType(
                    [
                        T.StructField("a", T.IntegerType(), False),
                        T.StructField("b", T.StringType(), False),
                    ]
                ),
                False,
            ),
        ]
    )

    data = [
        ("r1", date(2020, 1, 1), datetime(2020, 1, 1, 12, 0), [1, 2], {"x": 1, "y": 2}, (10, "aa")),
        ("r2", date(2021, 6, 15), datetime(2021, 6, 15, 8, 30), [3], {"z": 3}, (20, "bb")),
        ("r3", date(2022, 12, 31), datetime(2022, 12, 31, 23, 59), [], {}, (30, "cc")),
    ]

    return spark.createDataFrame(data, schema=schema)


def test_issue1722(test_output_dir, spark_session):
    from pyspark.sql import functions as F

    spark = spark_session

    non_numeric_df = make_non_numeric_df(spark)

    # type casting 1: cast date/timestamp columns to plain strings
    df_casted = non_numeric_df.select(
        [
            (
                F.col(field.name).cast("string").alias(field.name)
                if isinstance(field.dataType, (T.DateType, T.TimestampType))
                else F.col(field.name)
            )
            for field in non_numeric_df.schema
        ]
    )
    # type casting 2: serialize complex (array/map/struct) columns to JSON
    complex_columns = [
        field.name
        for field in non_numeric_df.schema.fields
        if isinstance(field.dataType, (T.ArrayType, T.MapType, T.StructType))
    ]
    for col_name in complex_columns:
        df_casted = df_casted.withColumn(col_name, F.to_json(F.col(col_name)))

    profile = ProfileReport(df_casted, title="non_numeric_1722", explorative=True)
    output_file = test_output_dir / "non_numeric_1722.html"
    profile.to_file(output_file)

    assert output_file.exists()
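With the guard from the first diff in place, profiling this all-non-numeric DataFrame completes and the HTML report is written. Assuming the repository's spark_session and test_output_dir pytest fixtures are available, the regression test can be run on its own with pytest -k test_issue1722.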
