Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions python/benchmarks/bench_eval_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pyspark.cloudpickle import dumps as cloudpickle_dumps
from pyspark.serializers import write_int, write_long, SpecialLengths
from pyspark.sql.types import (
ArrayType,
BinaryType,
BooleanType,
DoubleType,
Expand Down Expand Up @@ -251,6 +252,19 @@ class MockDataFactory:
"string": (lambda r: pa.array([f"s{j}" for j in range(r)]), StringType()),
"binary": (lambda r: pa.array([f"b{j}".encode() for j in range(r)]), BinaryType()),
"boolean": (lambda r: pa.array(np.random.choice([True, False], r)), BooleanType()),
"string_array": (
lambda r: pa.array(
[[f"s{j}", f"t{j}"] for j in range(r)], type=pa.list_(pa.string())
),
ArrayType(StringType()),
),
"nested_int_array": (
lambda r: pa.array(
[[[j, j + 1], [j + 2]] for j in range(r)],
type=pa.list_(pa.list_(pa.int32())),
),
ArrayType(ArrayType(IntegerType())),
),
}

MIXED_TYPES = [
Expand All @@ -266,6 +280,8 @@ class MockDataFactory:
"pure_ints": [TYPE_REGISTRY["int"]],
"pure_floats": [TYPE_REGISTRY["double"]],
"pure_strings": [TYPE_REGISTRY["string"]],
"pure_string_arrays": [TYPE_REGISTRY["string_array"]],
"pure_nested_int_arrays": [TYPE_REGISTRY["nested_int_array"]],
"pure_ts": [
(
lambda r: pa.array(
Expand Down Expand Up @@ -480,6 +496,8 @@ class _ArrowBatchedBenchMixin:
"pure_ints": ("pure_ints", 50_000, 10, 5_000),
"pure_floats": ("pure_floats", 50_000, 10, 5_000),
"pure_strings": ("pure_strings", 50_000, 10, 5_000),
"pure_string_arrays": ("pure_string_arrays", 50_000, 10, 5_000),
"pure_nested_int_arrays": ("pure_nested_int_arrays", 50_000, 10, 5_000),
"mixed_types": ("mixed", 50_000, 10, 5_000),
}

Expand All @@ -502,6 +520,16 @@ def _build_scenario(cls, name):
"identity_udf": (lambda x: x, None, [0]),
"stringify_udf": (lambda x: str(x), StringType(), [0]),
"nullcheck_udf": (lambda x: x is not None, BooleanType(), [0]),
# Input-focused: consumes the value and returns a scalar (trivial output),
# so the Arrow->Python input conversion dominates. Type-agnostic (works on
# scalars and arrays), so it stays valid across the whole cross product;
# pair with the pure_string_arrays / pure_nested_int_arrays scenarios to
# measure array (and nested-array) input conversion.
"consume_udf": (
lambda x: len(x) if isinstance(x, (list, tuple)) else (0 if x is None else 1),
IntegerType(),
[0],
),
}
params = [list(_scenario_configs), list(_udfs)]
param_names = ["scenario", "udf"]
Expand Down
45 changes: 45 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,51 @@ def test_nested_array_input(self):
],
)

def test_array_input_is_python_list(self):
    """Array-typed UDF inputs must arrive as Python lists at every depth.

    Guards the fast input path that builds columns via to_pandas() instead
    of to_pylist(): the UDF must still see ``list`` objects, never numpy
    ndarrays, for both the outer array and the arrays nested inside it.
    """
    legacy_conversion_off = {
        "spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": False
    }
    with self.sql_conf(legacy_conversion_off):
        source = self.spark.range(0, 3).selectExpr(
            "transform(sequence(0, 2), i -> cast(id + i as int)) as arr",
            "transform(sequence(0, 1), "
            "i -> array(cast(id + i as int), cast(id as int))) as nested",
        )

        @udf(returnType=StringType())
        def outer_type_name(value):
            # Report the Python type the UDF actually received.
            return type(value).__name__

        @udf(returnType=StringType())
        def inner_type_name(value):
            # Report the type of the first nested element, if there is one.
            if not value:
                return "empty"
            return type(value[0]).__name__

        first_row = source.select(
            outer_type_name("arr").alias("outer"),
            inner_type_name("nested").alias("inner"),
        ).first()
        self.assertEqual(first_row.outer, "list")
        self.assertEqual(first_row.inner, "list")

def test_array_string_input_values(self):
    """array<string> input values must pass through the fast input path unchanged."""
    legacy_conversion_off = {
        "spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": False
    }
    with self.sql_conf(legacy_conversion_off):
        source = self.spark.range(0, 3).selectExpr(
            "transform(sequence(0, 1), i -> cast(id + i as string)) as arr"
        )

        @udf(returnType=StringType())
        def joined(items):
            # Joining fails loudly if the elements are not plain strings.
            return ",".join(items)

        collected = source.select(joined("arr").alias("res")).collect()
        self.assertEqual([row.res for row in collected], ["0,1", "1,2", "2,3"])

def test_type_coercion_string_to_numeric(self):
df_int_value = self.spark.createDataFrame(["1", "2"], schema="string")
df_floating_value = self.spark.createDataFrame(["1.1", "2.2"], schema="string")
Expand Down
48 changes: 46 additions & 2 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3029,6 +3029,7 @@ def cogrouped_func(
eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
and not runner_conf.use_legacy_pandas_udf_conversion
):
import numpy as np
import pyarrow as pa

# --- UDF preparation ---
Expand Down Expand Up @@ -3065,6 +3066,49 @@ def cogrouped_func(
for f in eval_conf.input_type
]

def _input_fast_listify_safe(dt: DataType) -> bool:
    """Return True if an array column of type ``dt`` may use the pandas listify path.

    For an array column whose leaf elements need no per-element input
    converter, pa.Array.to_pylist() is markedly slower than arr.to_pandas()
    followed by turning the resulting numpy ndarrays back into Python lists.
    Only ArrayType is recursed through: the leaf must be a type for which
    ArrowTableToRowsConversion needs no converter (so the column-level
    converter is None); Map/Struct/other types either need a converter or
    do not benefit.
    """
    if not isinstance(dt, ArrayType):
        return False
    element_type = dt.elementType
    # Nested arrays qualify when their own elements qualify.
    if _input_fast_listify_safe(element_type):
        return True
    # Otherwise the element must be convertible without a per-element converter.
    return not ArrowTableToRowsConversion._need_converter(element_type)

# One (converter, use_pandas_listify) pair per input column, in column order.
# The pandas listify path is only enabled when the column has no per-element
# converter, so the two options never apply to the same column.
input_col_plan = [
    (conv, conv is None and _input_fast_listify_safe(field.dataType))
    for conv, field in zip(arrow_to_py_converters, eval_conf.input_type)
]

def _ndarray_to_list(value: Any) -> Any:
    """Recursively turn numpy ndarrays into plain Python lists.

    pa.Array.to_pandas() yields numpy ndarrays for list-typed values (and
    object ndarrays of ndarrays for nested lists). This converts them back so
    UDFs see the same object types (list, not ndarray) that to_pylist() would
    have produced.

    Parameters
    ----------
    value : Any
        An ndarray, a list that may contain ndarrays, or any other value.

    Returns
    -------
    Any
        A (nested) list for ndarray/list input; any other value is returned
        unchanged.
    """
    if isinstance(value, np.ndarray):
        if value.dtype.kind != "O":
            # A non-object ndarray cannot contain further ndarrays, and
            # tolist() already produces (nested) lists of Python scalars in
            # C. Recursing per element here would only add Python-level call
            # overhead on the hot input path.
            return value.tolist()
        # Object arrays may hold inner ndarrays, which tolist() leaves as-is.
        return [_ndarray_to_list(item) for item in value.tolist()]
    if isinstance(value, list):
        return [_ndarray_to_list(item) for item in value]
    return value

def _leaf_values_have_nulls(col: "pa.Array") -> bool:
    # Peel off list levels until the leaf values array is reached, then report
    # whether any leaf element is null. flatten() honours offsets and skips
    # null list slots, so only elements that are really present are counted.
    values = col
    while pa.types.is_list(values.type) or pa.types.is_large_list(values.type):
        values = values.flatten()
    return values.null_count > 0


def _column_to_pylist(col: "pa.Array", conv, listify: bool) -> list:
    """Convert one Arrow input column into a list of Python values.

    Parameters
    ----------
    col : pa.Array
        The input column of the current batch.
    conv : callable or None
        Per-element converter applied to each value from to_pylist(); ignored
        when ``listify`` is set.
    listify : bool
        Whether the column may be materialised through to_pandas() and then
        turned back into Python lists instead of calling to_pylist().

    Returns
    -------
    list
        One Python value per row of ``col``.
    """
    if listify:
        if not _leaf_values_have_nulls(col):
            return [_ndarray_to_list(v) for v in col.to_pandas()]
        # to_pandas() represents null numeric leaves as float NaN (promoting
        # integer elements to float), whereas to_pylist() yields None. Use the
        # slower to_pylist() so UDFs keep seeing exact values and None.
        return col.to_pylist()
    if conv is not None:
        return [conv(v) for v in col.to_pylist()]
    return col.to_pylist()

@fail_on_stopiteration
def _evaluate_batch_udf(udf_func, rows):
if runner_conf.arrow_concurrency_level <= 0:
Expand All @@ -3080,8 +3124,8 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record

# --- Input: Arrow -> Python columns ---
columns = [
[conv(v) for v in col.to_pylist()] if conv is not None else col.to_pylist()
for col, conv in zip(input_batch.itercolumns(), arrow_to_py_converters)
_column_to_pylist(col, conv, listify)
for col, (conv, listify) in zip(input_batch.itercolumns(), input_col_plan)
]
if not columns:
columns = [[_NoValue] * num_rows]
Expand Down