Skip to content

Commit 9dbe381

Browse files
gaogaotiantian authored and zhengruifeng committed
[SPARK-56340][PYTHON] Move input_type schema to eval conf
### What changes were proposed in this pull request? Use eval conf to pass the schema json, instead of sending a random string before UDF. ### Why are the changes needed? Clean up JVM <-> python worker protocol. We should not randomly pass data. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? `test_udf` passed locally, the rest is on CI. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #55170 from gaogaotiantian/move-input-type. Authored-by: Tian Gao <gaogaotiantian@hotmail.com> Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
1 parent 491add8 commit 9dbe381

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

python/pyspark/worker.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,13 @@ def state_server_socket_port(self) -> Optional[int | str]:
207207
except ValueError:
208208
return port
209209

210+
@property
211+
def input_type(self) -> Optional[DataType]:
212+
input_type = self.get("input_type", None, lower_str=False)
213+
if input_type is None:
214+
return None
215+
return _parse_datatype_json_string(input_type)
216+
210217

211218
def report_times(outfile, boot, init, finish, processing_time_ms):
212219
write_int(SpecialLengths.TIMING_DATA, outfile)
@@ -2532,11 +2539,9 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
25322539
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
25332540
PythonEvalType.SQL_SCALAR_ARROW_UDF,
25342541
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
2542+
PythonEvalType.SQL_ARROW_BATCHED_UDF,
25352543
):
25362544
ser = ArrowStreamSerializer(write_start_stream=True)
2537-
elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
2538-
input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
2539-
ser = ArrowStreamSerializer(write_start_stream=True)
25402545
else:
25412546
# Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
25422547
# pandas Series. See SPARK-27240.
@@ -2865,7 +2870,7 @@ def grouped_func(
28652870
ArrowTableToRowsConversion._create_converter(
28662871
f.dataType, none_on_identity=True, binary_as_bytes=runner_conf.binary_as_bytes
28672872
)
2868-
for f in input_type
2873+
for f in eval_conf.input_type
28692874
]
28702875

28712876
@fail_on_stopiteration
@@ -2968,7 +2973,7 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record
29682973
pandas_columns = ArrowBatchTransformer.to_pandas(
29692974
input_batch,
29702975
timezone=runner_conf.timezone,
2971-
schema=input_type,
2976+
schema=eval_conf.input_type,
29722977
struct_in_pandas="row",
29732978
ndarray_as_list=True,
29742979
prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,17 @@ class ArrowPythonWithNamedArgumentRunner(
136136

137137
override protected def runnerConf: Map[String, String] = super.runnerConf ++ pythonRunnerConf
138138

139-
override protected def writeUDF(dataOut: DataOutputStream): Unit = {
139+
override protected def evalConf: Map[String, String] = {
140140
if (evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF) {
141-
PythonWorkerUtils.writeUTF(schema.json, dataOut)
141+
super.evalConf ++ Map(
142+
"input_type" -> schema.json
143+
)
144+
} else {
145+
super.evalConf
142146
}
147+
}
148+
149+
override protected def writeUDF(dataOut: DataOutputStream): Unit = {
143150
PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
144151
}
145152
}

0 commit comments

Comments
 (0)