
Commit 4a52804

Add handling for spark correlations with no numeric fields
This change addresses issue #1722. Assembling a vector column in Spark from a DataFrame with no numeric columns produces features with a NULL size, NULL indices, and an empty list of values, which raises an exception when the correlation matrix is computed. The fix is to skip computing the correlation matrix when there are no interval (numeric) columns.
1 parent 0450215 commit 4a52804
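
For context, a minimal sketch of the failure mode described in the message, assuming the pyspark.ml VectorAssembler / Correlation path that the wording "assembling a vector column" points at; the session setup and DataFrame below are illustrative, not taken from the profiler's code:

    from pyspark.ml.feature import VectorAssembler
    from pyspark.ml.stat import Correlation
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # A DataFrame with no numeric columns (illustrative data)
    df = spark.createDataFrame([("a",), ("b",)], ["name"])

    # Assembling zero input columns yields degenerate "features" vectors
    assembled = VectorAssembler(inputCols=[], outputCol="features").transform(df)

    # Computing a correlation matrix over such vectors is what raised the exception
    Correlation.corr(assembled, "features", "pearson").collect()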

2 files changed

Lines changed: 79 additions & 0 deletions


src/ydata_profiling/model/spark/correlations_spark.py

Lines changed: 4 additions & 0 deletions
@@ -49,6 +49,10 @@ def _compute_corr_natively(df: DataFrame, summary: dict, corr_type: str) -> Arra
     interval_columns = [
         column for column, type_name in variables.items() if type_name == "Numeric"
     ]
+
+    if not interval_columns:
+        return [], interval_columns
+
     df = df.select(*interval_columns)
 
     # convert to vector column first
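
The guard returns before any vector assembly happens. A standalone sketch of the new early exit, using a hypothetical variables mapping (in the real function it is presumably derived from the summary argument):

    # Hypothetical mapping of column name to inferred type; only "Numeric" matters here
    variables = {"id": "Categorical", "d": "DateTime", "arr": "Text"}

    interval_columns = [
        column for column, type_name in variables.items() if type_name == "Numeric"
    ]

    if not interval_columns:
        # Early exit: an empty correlation matrix plus the (empty) column list,
        # so no vector column is ever assembled
        result = ([], interval_columns)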

tests/issues/test_issue1722.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+"""
+Test for issue 1722:
+https://github.com/ydataai/ydata-profiling/issues/1722
+"""
+from datetime import date, datetime
+
+from pyspark.sql import SparkSession
+from pyspark.sql import functions as F
+from pyspark.sql import types as T
+
+from ydata_profiling import ProfileReport
+
+
+def make_non_numeric_df(spark: SparkSession):
+    # Intentionally not including any numeric types
+    schema = T.StructType(
+        [
+            T.StructField("id", T.StringType(), False),
+            T.StructField("d", T.DateType(), False),
+            T.StructField("ts", T.TimestampType(), False),
+            T.StructField("arr", T.ArrayType(T.IntegerType()), False),
+            T.StructField("mp", T.MapType(T.StringType(), T.IntegerType()), False),
+            T.StructField(
+                "struct",
+                T.StructType(
+                    [
+                        T.StructField("a", T.IntegerType(), False),
+                        T.StructField("b", T.StringType(), False),
+                    ]
+                ),
+                False,
+            ),
+        ]
+    )
+
+    data = [
+        ("r1", date(2020, 1, 1), datetime(2020, 1, 1, 12, 0), [1, 2], {"x": 1, "y": 2}, (10, "aa")),
+        ("r2", date(2021, 6, 15), datetime(2021, 6, 15, 8, 30), [3], {"z": 3}, (20, "bb")),
+        ("r3", date(2022, 12, 31), datetime(2022, 12, 31, 23, 59), [], {}, (30, "cc")),
+    ]
+
+    return spark.createDataFrame(data, schema=schema)
+
+
+def test_issue1722(test_output_dir, spark_session):
+    spark = spark_session
+
+    non_numeric_df = make_non_numeric_df(spark)
+
+    # Cast date/timestamp columns to plain strings
+    df_casted = non_numeric_df.select(
+        [
+            (
+                F.col(field.name).cast("string").alias(field.name)
+                if isinstance(field.dataType, (T.DateType, T.TimestampType))
+                else F.col(field.name)
+            )
+            for field in non_numeric_df.schema
+        ]
+    )
+
+    # Serialize complex (array/map/struct) columns to JSON strings
+    complex_columns = [
+        field.name
+        for field in non_numeric_df.schema.fields
+        if isinstance(field.dataType, (T.ArrayType, T.MapType, T.StructType))
+    ]
+    for col_name in complex_columns:
+        df_casted = df_casted.withColumn(col_name, F.to_json(F.col(col_name)))
+
+    profile = ProfileReport(df_casted, title="non_numeric_1722", explorative=True)
+    output_file = test_output_dir / "non_numeric_1722.html"
+    profile.to_file(output_file)
+
+    assert output_file.exists()
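
With the guard in place, profiling a DataFrame like the one above no longer attempts a correlation matrix, so the report renders and the file assertion passes. To verify locally, the test can be run with pytest tests/issues/test_issue1722.py, assuming the repository's conftest provides the spark_session and test_output_dir fixtures it relies on.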
