4444T = TypeVar ("T" )
4545
4646if TYPE_CHECKING :
47+ import pandas as pd
48+ import pyarrow as pa
49+
4750 from pyspark .sql .pandas ._typing import GroupedBatch
4851
4952from pyspark .accumulators import (
@@ -256,8 +259,7 @@ def verify_return_type(result: T, expected_type: Type[T]) -> T:
256259 """
257260 if get_origin (expected_type ) is Iterator :
258261 (element_type ,) = get_args (expected_type )
259- package = getattr (inspect .getmodule (element_type ), "__package__" , "" )
260- label = f"iterator of { package } .{ element_type .__name__ } "
262+ label = f"iterator of { _top_level_package (element_type )} .{ element_type .__name__ } "
261263
262264 if not isinstance (result , Iterator ):
263265 raise PySparkTypeError (
@@ -279,17 +281,21 @@ def check_element(element: T) -> T:
279281 return map (check_element , result ) # type: ignore[return-value]
280282
281283 if not isinstance (result , expected_type ):
282- package = getattr (inspect .getmodule (expected_type ), "__package__" , "" )
283284 raise PySparkTypeError (
284285 errorClass = "UDF_RETURN_TYPE" ,
285286 messageParameters = {
286- "expected" : f"{ package } .{ expected_type .__name__ } " ,
287+ "expected" : f"{ _top_level_package ( expected_type ) } .{ expected_type .__name__ } " ,
287288 "actual" : type (result ).__name__ ,
288289 },
289290 )
290291 return result
291292
292293
def _top_level_package(t: type) -> str:
    """Name the root package of the module that defines ``t``.

    Reads ``t.__module__`` (treated as an empty string when falsy) and keeps
    only the text before the first dot, e.g. ``pandas`` for ``pd.DataFrame``.
    """
    module_path = t.__module__ or ""
    root, _, _ = module_path.partition(".")
    return root
297+
298+
293299def verify_result_row_count (result_length : int , expected : int ) -> None :
294300 """Raise if the result row count doesn't match the expected input row count."""
295301 if result_length != expected :
@@ -429,64 +435,59 @@ def verify_element(elem):
429435 )
430436
431437
432- def verify_pandas_result (result , return_type , assign_cols_by_name , truncate_return_schema ):
433- import pandas as pd
434-
435- if isinstance (return_type , StructType ):
436- if not isinstance (result , pd .DataFrame ):
437- raise PySparkTypeError (
438- errorClass = "UDF_RETURN_TYPE" ,
def _verify_column_schema(
    actual_names: list, expected_names: list, *, assign_cols_by_name: bool
) -> None:
    """Check that a result's columns line up with the expected schema.

    Parameters
    ----------
    actual_names : list
        Column names found on the UDF result.
    expected_names : list
        Column names declared by the expected schema.
    assign_cols_by_name : bool
        When true, the two name *sets* must be equal (order is ignored).
        When false, only the number of columns is compared.

    Raises
    ------
    PySparkRuntimeError
        ``RESULT_COLUMN_NAMES_MISMATCH`` when matching by name and some names
        are missing or unexpected; ``RESULT_COLUMN_SCHEMA_MISMATCH`` when
        matching by position and the column counts differ.
    """
    if assign_cols_by_name:
        actual_set = set(actual_names)
        expected_set = set(expected_names)
        # Column labels are not guaranteed to be strings (a pandas result may mix
        # str and non-str labels). Stringify before sorting/joining so that such
        # a result yields the intended mismatch error instead of a TypeError.
        missing = sorted(str(name) for name in expected_set - actual_set)
        extra = sorted(str(name) for name in actual_set - expected_set)
        if missing or extra:
            raise PySparkRuntimeError(
                errorClass="RESULT_COLUMN_NAMES_MISMATCH",
                messageParameters={
                    "missing": f" Missing: {', '.join(missing)}." if missing else "",
                    "extra": f" Unexpected: {', '.join(extra)}." if extra else "",
                },
            )
    elif len(actual_names) != len(expected_names):
        raise PySparkRuntimeError(
            errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
            messageParameters={
                "expected": str(len(expected_names)),
                "actual": str(len(actual_names)),
            },
        )
444463
445- # check the schema of the result only if it is not empty or has columns
446- if not result .empty or len (result .columns ) != 0 :
447- # if any column name of the result is a string
448- # the column names of the result have to match the return type
449- # see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
450- field_names = set ([field .name for field in return_type .fields ])
451- # only the first len(field_names) result columns are considered
452- # when truncating the return schema
453- result_columns = (
454- result .columns [: len (field_names )] if truncate_return_schema else result .columns
455- )
456- column_names = set (result_columns )
457- if (
458- assign_cols_by_name
459- and any (isinstance (name , str ) for name in result .columns )
460- and column_names != field_names
461- ):
462- missing = sorted (list (field_names .difference (column_names )))
463- missing = f" Missing: { ', ' .join (missing )} ." if missing else ""
464464
465- extra = sorted (list (column_names .difference (field_names )))
466- extra = f" Unexpected: { ', ' .join (extra )} ." if extra else ""
def verify_pandas_result(
    result: Union["pd.DataFrame", "pd.Series"],
    return_type: DataType,
    assign_cols_by_name: bool,
    truncate_return_schema: bool,
) -> None:
    """Validate the object returned by a pandas UDF against ``return_type``.

    For a non-struct ``return_type`` the result only has to be a
    ``pandas.Series``. For a ``StructType`` it has to be a ``pandas.DataFrame``
    whose columns agree with the struct's fields, unless the frame has neither
    rows nor columns, in which case the column check is skipped.

    Type and column mismatches are reported by ``verify_return_type`` and
    ``_verify_column_schema`` respectively.
    """
    import pandas as pd

    if isinstance(return_type, StructType):
        verify_return_type(result, pd.DataFrame)

        # A frame with no rows and no columns carries no schema to validate.
        if not result.empty or len(result.columns) != 0:
            expected = [struct_field.name for struct_field in return_type.fields]

            # When truncating, only the leading len(expected) columns are considered.
            if truncate_return_schema:
                considered = list(result.columns[: len(expected)])
            else:
                considered = list(result.columns)

            # Name-based matching needs at least one string column label;
            # otherwise only the column count is compared.
            match_by_name = assign_cols_by_name and any(
                isinstance(label, str) for label in result.columns
            )
            _verify_column_schema(considered, expected, assign_cols_by_name=match_by_name)
    else:
        verify_return_type(result, pd.Series)
490491
491492
492493def wrap_cogrouped_map_pandas_udf (f , return_type , argspec , runner_conf ):
@@ -511,73 +512,53 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_se
511512 return lambda kl , vl , kr , vr : [(wrapped (kl , vl , kr , vr ), return_type )]
512513
513514
514- def verify_arrow_result (result , assign_cols_by_name , expected_cols_and_types ):
515- # the types of the fields have to be identical to return type
516- # an empty table can have no columns; if there are columns, they have to match
517- if result .num_columns != 0 or result .num_rows != 0 :
518- # columns are either mapped by name or position
519- if assign_cols_by_name :
520- actual_cols_and_types = {
521- name : dataType for name , dataType in zip (result .schema .names , result .schema .types )
522- }
523- missing = sorted (
524- list (set (expected_cols_and_types .keys ()).difference (actual_cols_and_types .keys ()))
525- )
526- extra = sorted (
527- list (set (actual_cols_and_types .keys ()).difference (expected_cols_and_types .keys ()))
528- )
529-
530- if missing or extra :
531- missing = f" Missing: { ', ' .join (missing )} ." if missing else ""
532- extra = f" Unexpected: { ', ' .join (extra )} ." if extra else ""
533-
534- raise PySparkRuntimeError (
535- errorClass = "RESULT_COLUMN_NAMES_MISMATCH" ,
536- messageParameters = {
537- "missing" : missing ,
538- "extra" : extra ,
539- },
540- )
def verify_arrow_result(
    result: Union["pa.Table", "pa.RecordBatch"],
    assign_cols_by_name: bool,
    expected_cols_and_types: Union[dict[str, "pa.DataType"], list[tuple[str, "pa.DataType"]]],
) -> None:
    """Validate the schema of an Arrow UDF result against the expected columns.

    Parameters
    ----------
    result
        Arrow object exposing ``num_columns``, ``num_rows`` and a ``schema``
        with ``names`` and ``types``.
    assign_cols_by_name
        When true, columns are matched to the expected schema by name;
        otherwise they are matched by position.
    expected_cols_and_types
        Expected ``name -> type`` mapping, or a list of ``(name, type)`` pairs.
        Either container is accepted in either mode.

    Raises
    ------
    PySparkRuntimeError
        ``RESULT_COLUMN_TYPES_MISMATCH`` when a matched column's type differs
        from the expected one. Name/count mismatches are raised by
        ``_verify_column_schema``.
    """
    # Skip schema check on a fully empty result (no rows and no columns).
    if result.num_columns == 0 and result.num_rows == 0:
        return

    actual_names = list(result.schema.names)
    actual_types = list(result.schema.types)

    # Normalise to (name, type) pairs so the matching mode below depends on
    # assign_cols_by_name alone, not on which container the caller passed.
    if isinstance(expected_cols_and_types, dict):
        expected_pairs = list(expected_cols_and_types.items())
    else:
        expected_pairs = list(expected_cols_and_types)
    expected_names = [name for name, _ in expected_pairs]

    _verify_column_schema(actual_names, expected_names, assign_cols_by_name=assign_cols_by_name)

    if assign_cols_by_name:
        # The name sets were verified equal above, so every lookup succeeds.
        expected_by_name = dict(expected_pairs)
        actual_by_name = dict(zip(actual_names, actual_types))
        column_types = [
            (name, expected_by_name[name], actual_by_name[name])
            for name in sorted(expected_by_name)
        ]
    else:
        # The column counts were verified equal above; compare pairwise in order.
        column_types = [
            (expected_name, expected_type, actual_type)
            for (expected_name, expected_type), actual_type in zip(expected_pairs, actual_types)
        ]

    type_mismatch = [
        (name, expected, actual) for name, expected, actual in column_types if actual != expected
    ]

    if type_mismatch:
        raise PySparkRuntimeError(
            errorClass="RESULT_COLUMN_TYPES_MISMATCH",
            messageParameters={
                "mismatch": ", ".join(
                    "column '{}' (expected {}, actual {})".format(name, expected, actual)
                    for name, expected, actual in type_mismatch
                )
            },
        )
581562
582563
583564def wrap_grouped_transform_with_state_pandas_udf (f , return_type , runner_conf ):
0 commit comments