Skip to content

Commit 9e85d06

Browse files
Yicong-Huang authored and HyukjinKwon committed
[SPARK-56973][PYTHON] Consolidate verify_pandas_result with verify_arrow_result via shared helper
### What changes were proposed in this pull request?

Consolidate the two UDF-result-verification paths in `worker.py`:

- Extract `_verify_column_schema(actual_names, expected_names, *, assign_cols_by_name)` that raises `RESULT_COLUMN_NAMES_MISMATCH` (by-name) or `RESULT_COLUMN_SCHEMA_MISMATCH` (by-position).
- `verify_pandas_result` now uses `verify_return_type` for the container check and the new helper for the schema check.
- `verify_arrow_result` uses the same helper, keeping only its own `RESULT_COLUMN_TYPES_MISMATCH` check inline.
- Fix `verify_return_type` to derive the top-level package (`pandas` instead of `pandas.core` for `pd.DataFrame`).

### Why are the changes needed?

After SPARK-56937 added the column-count check to the arrow path, the pandas and arrow verifiers raise the same set of error classes but duplicate the name/count logic. A shared helper prevents drift as more pandas eval types are refactored under SPARK-55388.

### Does this PR introduce _any_ user-facing change?

No. Same error classes raised under the same conditions with the same `messageParameters`.

### How was this patch tested?

Existing tests. Verified locally: `test_pandas_map`, `test_pandas_grouped_map`, `test_pandas_cogrouped_map`, `test_arrow_grouped_map`, `test_arrow_cogrouped_map`, and `test_udtf::LegacyUDTFArrowTests` all pass.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #56021 from Yicong-Huang/refactor/consolidate-verify-pandas-result.

Authored-by: Yicong Huang <17627829+Yicong-Huang@users.noreply.github.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 411dedc commit 9e85d06

1 file changed

Lines changed: 100 additions & 119 deletions

File tree

python/pyspark/worker.py

Lines changed: 100 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
T = TypeVar("T")
4545

4646
if TYPE_CHECKING:
47+
import pandas as pd
48+
import pyarrow as pa
49+
4750
from pyspark.sql.pandas._typing import GroupedBatch
4851

4952
from pyspark.accumulators import (
@@ -256,8 +259,7 @@ def verify_return_type(result: T, expected_type: Type[T]) -> T:
256259
"""
257260
if get_origin(expected_type) is Iterator:
258261
(element_type,) = get_args(expected_type)
259-
package = getattr(inspect.getmodule(element_type), "__package__", "")
260-
label = f"iterator of {package}.{element_type.__name__}"
262+
label = f"iterator of {_top_level_package(element_type)}.{element_type.__name__}"
261263

262264
if not isinstance(result, Iterator):
263265
raise PySparkTypeError(
@@ -279,17 +281,21 @@ def check_element(element: T) -> T:
279281
return map(check_element, result) # type: ignore[return-value]
280282

281283
if not isinstance(result, expected_type):
282-
package = getattr(inspect.getmodule(expected_type), "__package__", "")
283284
raise PySparkTypeError(
284285
errorClass="UDF_RETURN_TYPE",
285286
messageParameters={
286-
"expected": f"{package}.{expected_type.__name__}",
287+
"expected": f"{_top_level_package(expected_type)}.{expected_type.__name__}",
287288
"actual": type(result).__name__,
288289
},
289290
)
290291
return result
291292

292293

294+
def _top_level_package(t: type) -> str:
295+
"""Return the top-level package of ``t`` (``pandas`` for ``pd.DataFrame``)."""
296+
return (t.__module__ or "").split(".", 1)[0]
297+
298+
293299
def verify_result_row_count(result_length: int, expected: int) -> None:
294300
"""Raise if the result row count doesn't match the expected input row count."""
295301
if result_length != expected:
@@ -429,64 +435,59 @@ def verify_element(elem):
429435
)
430436

431437

432-
def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_return_schema):
433-
import pandas as pd
434-
435-
if isinstance(return_type, StructType):
436-
if not isinstance(result, pd.DataFrame):
437-
raise PySparkTypeError(
438-
errorClass="UDF_RETURN_TYPE",
438+
def _verify_column_schema(
439+
actual_names: list, expected_names: list, *, assign_cols_by_name: bool
440+
) -> None:
441+
"""Check column names (by-name) or count (by-position) match the expected schema."""
442+
if assign_cols_by_name:
443+
actual_set = set(actual_names)
444+
expected_set = set(expected_names)
445+
missing = sorted(expected_set.difference(actual_set))
446+
extra = sorted(actual_set.difference(expected_set))
447+
if missing or extra:
448+
raise PySparkRuntimeError(
449+
errorClass="RESULT_COLUMN_NAMES_MISMATCH",
439450
messageParameters={
440-
"expected": "pandas.DataFrame",
441-
"actual": type(result).__name__,
451+
"missing": f" Missing: {', '.join(missing)}." if missing else "",
452+
"extra": f" Unexpected: {', '.join(extra)}." if extra else "",
442453
},
443454
)
455+
elif len(actual_names) != len(expected_names):
456+
raise PySparkRuntimeError(
457+
errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
458+
messageParameters={
459+
"expected": str(len(expected_names)),
460+
"actual": str(len(actual_names)),
461+
},
462+
)
444463

445-
# check the schema of the result only if it is not empty or has columns
446-
if not result.empty or len(result.columns) != 0:
447-
# if any column name of the result is a string
448-
# the column names of the result have to match the return type
449-
# see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
450-
field_names = set([field.name for field in return_type.fields])
451-
# only the first len(field_names) result columns are considered
452-
# when truncating the return schema
453-
result_columns = (
454-
result.columns[: len(field_names)] if truncate_return_schema else result.columns
455-
)
456-
column_names = set(result_columns)
457-
if (
458-
assign_cols_by_name
459-
and any(isinstance(name, str) for name in result.columns)
460-
and column_names != field_names
461-
):
462-
missing = sorted(list(field_names.difference(column_names)))
463-
missing = f" Missing: {', '.join(missing)}." if missing else ""
464464

465-
extra = sorted(list(column_names.difference(field_names)))
466-
extra = f" Unexpected: {', '.join(extra)}." if extra else ""
465+
def verify_pandas_result(
    result: Union["pd.DataFrame", "pd.Series"],
    return_type: DataType,
    assign_cols_by_name: bool,
    truncate_return_schema: bool,
) -> None:
    """
    Validate the value returned by a pandas UDF against its return type.

    When ``return_type`` is not a ``StructType`` the result only has to pass
    ``verify_return_type`` as a ``pandas.Series``. For a ``StructType`` the
    result has to pass ``verify_return_type`` as a ``pandas.DataFrame``, and
    its columns are then checked with ``_verify_column_schema`` unless the
    frame has neither rows nor columns.
    """
    import pandas as pd

    if isinstance(return_type, StructType):
        verify_return_type(result, pd.DataFrame)

        columns = result.columns
        # A frame with no rows and no columns carries no schema to verify.
        if len(columns) != 0 or not result.empty:
            expected = [f.name for f in return_type.fields]
            # When truncating, only the first len(expected) result columns count.
            considered = columns[: len(expected)] if truncate_return_schema else columns
            # Names are compared only if at least one result label is a string;
            # otherwise the check falls back to comparing column counts.
            _verify_column_schema(
                list(considered),
                expected,
                assign_cols_by_name=(
                    assign_cols_by_name and any(isinstance(label, str) for label in columns)
                ),
            )
    else:
        verify_return_type(result, pd.Series)
490491

491492

492493
def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
@@ -511,73 +512,53 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_se
511512
return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), return_type)]
512513

513514

514-
def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
515-
# the types of the fields have to be identical to return type
516-
# an empty table can have no columns; if there are columns, they have to match
517-
if result.num_columns != 0 or result.num_rows != 0:
518-
# columns are either mapped by name or position
519-
if assign_cols_by_name:
520-
actual_cols_and_types = {
521-
name: dataType for name, dataType in zip(result.schema.names, result.schema.types)
522-
}
523-
missing = sorted(
524-
list(set(expected_cols_and_types.keys()).difference(actual_cols_and_types.keys()))
525-
)
526-
extra = sorted(
527-
list(set(actual_cols_and_types.keys()).difference(expected_cols_and_types.keys()))
528-
)
529-
530-
if missing or extra:
531-
missing = f" Missing: {', '.join(missing)}." if missing else ""
532-
extra = f" Unexpected: {', '.join(extra)}." if extra else ""
533-
534-
raise PySparkRuntimeError(
535-
errorClass="RESULT_COLUMN_NAMES_MISMATCH",
536-
messageParameters={
537-
"missing": missing,
538-
"extra": extra,
539-
},
540-
)
515+
def verify_arrow_result(
    result: Union["pa.Table", "pa.RecordBatch"],
    assign_cols_by_name: bool,
    expected_cols_and_types: Union[dict[str, "pa.DataType"], list[tuple[str, "pa.DataType"]]],
) -> None:
    """
    Validate the columns and column types of an Arrow UDF result.

    A result with no rows and no columns is accepted as is. Otherwise the
    column names (or count) are checked with ``_verify_column_schema`` and
    every column type is then compared with the expected one: by name when
    ``expected_cols_and_types`` is a dict, by position when it is a list of
    ``(name, type)`` pairs.

    Raises
    ------
    PySparkRuntimeError
        With error class ``RESULT_COLUMN_TYPES_MISMATCH`` if any column type
        differs from the expected type.
    """
    if result.num_columns == 0 and result.num_rows == 0:
        return

    schema = result.schema
    names = list(schema.names)
    arrow_types = list(schema.types)
    expected_is_mapping = isinstance(expected_cols_and_types, dict)

    if expected_is_mapping:
        expected_names = list(expected_cols_and_types)
    else:
        expected_names = [col for col, _ in expected_cols_and_types]

    _verify_column_schema(names, expected_names, assign_cols_by_name=assign_cols_by_name)

    # Build (name, expected type, actual type) triples for every expected column.
    if expected_is_mapping:
        type_of = dict(zip(names, arrow_types))
        triples = [
            (col, expected_cols_and_types[col], type_of[col])
            for col in sorted(expected_cols_and_types)
        ]
    else:
        triples = [
            (col, wanted, got)
            for (col, wanted), got in zip(expected_cols_and_types, arrow_types)
        ]

    mismatches = [triple for triple in triples if triple[2] != triple[1]]
    if not mismatches:
        return

    details = ", ".join(
        "column '{}' (expected {}, actual {})".format(*triple) for triple in mismatches
    )
    raise PySparkRuntimeError(
        errorClass="RESULT_COLUMN_TYPES_MISMATCH",
        messageParameters={"mismatch": details},
    )
581562

582563

583564
def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):

0 commit comments

Comments
 (0)