6666//! | Layer + kind | Family prefix |
6767//! | ----------------------------- | ------------- |
6868//! | `PythonLogicalCodec` scalar | `DFPYUDF` |
69+ //! | `PythonLogicalCodec` agg | `DFPYUDA` |
70+ //! | `PythonLogicalCodec` window | `DFPYUDW` |
6971//! | `PythonPhysicalCodec` scalar | `DFPYUDF` |
72+ //! | `PythonPhysicalCodec` agg | `DFPYUDA` |
73+ //! | `PythonPhysicalCodec` window | `DFPYUDW` |
7074//! | User FFI extension codec | user-chosen |
7175//! | Default codec | (none) |
7276//!
73- //! Aggregate and window UDF families are reserved for follow-on work.
74- //!
7577//! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported
7678//! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`.
7779//! Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape
@@ -94,8 +96,8 @@ use datafusion::datasource::TableProvider;
9496use datafusion:: datasource:: file_format:: FileFormatFactory ;
9597use datafusion:: execution:: TaskContext ;
9698use datafusion:: logical_expr:: {
97- AggregateUDF , Extension , LogicalPlan , ScalarUDF , ScalarUDFImpl , Signature , TypeSignature ,
98- Volatility , WindowUDF ,
99+ AggregateUDF , AggregateUDFImpl , Extension , LogicalPlan , ScalarUDF , ScalarUDFImpl , Signature ,
100+ TypeSignature , Volatility , WindowUDF , WindowUDFImpl ,
99101} ;
100102use datafusion:: physical_expr:: PhysicalExpr ;
101103use datafusion:: physical_plan:: ExecutionPlan ;
@@ -105,7 +107,9 @@ use pyo3::prelude::*;
105107use pyo3:: sync:: PyOnceLock ;
106108use pyo3:: types:: { PyBytes , PyTuple } ;
107109
110+ use crate :: udaf:: PythonFunctionAggregateUDF ;
108111use crate :: udf:: PythonFunctionScalarUDF ;
112+ use crate :: udwf:: PythonFunctionWindowUDF ;
109113
110114// Wire-format framing for inlined Python UDF payloads.
111115//
@@ -126,6 +130,16 @@ use crate::udf::PythonFunctionScalarUDF;
126130/// volatility).
127131pub ( crate ) const PY_SCALAR_UDF_FAMILY : & [ u8 ] = b"DFPYUDF" ;
128132
/// Magic family tag that marks a payload as an inlined Python
/// *aggregate* UDF.
///
/// The bytes following the wire header are a cloudpickled tuple holding,
/// in order: the UDF name, the accumulator factory, the input schema,
/// the return type, the state-field schema, and the volatility.
pub(crate) const PY_AGG_UDF_FAMILY: &[u8] = b"DFPYUDA";
137+
/// Magic family tag that marks a payload as an inlined Python
/// *window* UDF.
///
/// The bytes following the wire header are a cloudpickled tuple holding,
/// in order: the UDF name, the evaluator factory, the input schema,
/// the return type, and the volatility.
pub(crate) const PY_WINDOW_UDF_FAMILY: &[u8] = b"DFPYUDW";
142+
/// The wire-format version stamped on payloads produced by this build.
pub(crate) const WIRE_VERSION_CURRENT: u8 = 1;
131145
@@ -299,18 +313,30 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
299313 }
300314
301315 fn try_encode_udaf ( & self , node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
316+ if try_encode_python_agg_udf ( node, buf) ? {
317+ return Ok ( ( ) ) ;
318+ }
302319 self . inner . try_encode_udaf ( node, buf)
303320 }
304321
305322 fn try_decode_udaf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < AggregateUDF > > {
323+ if let Some ( udaf) = try_decode_python_agg_udf ( buf) ? {
324+ return Ok ( udaf) ;
325+ }
306326 self . inner . try_decode_udaf ( name, buf)
307327 }
308328
309329 fn try_encode_udwf ( & self , node : & WindowUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
330+ if try_encode_python_window_udf ( node, buf) ? {
331+ return Ok ( ( ) ) ;
332+ }
310333 self . inner . try_encode_udwf ( node, buf)
311334 }
312335
313336 fn try_decode_udwf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < WindowUDF > > {
337+ if let Some ( udwf) = try_decode_python_window_udf ( buf) ? {
338+ return Ok ( udwf) ;
339+ }
314340 self . inner . try_decode_udwf ( name, buf)
315341 }
316342}
@@ -389,18 +415,30 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
389415 }
390416
391417 fn try_encode_udaf ( & self , node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
418+ if try_encode_python_agg_udf ( node, buf) ? {
419+ return Ok ( ( ) ) ;
420+ }
392421 self . inner . try_encode_udaf ( node, buf)
393422 }
394423
395424 fn try_decode_udaf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < AggregateUDF > > {
425+ if let Some ( udaf) = try_decode_python_agg_udf ( buf) ? {
426+ return Ok ( udaf) ;
427+ }
396428 self . inner . try_decode_udaf ( name, buf)
397429 }
398430
399431 fn try_encode_udwf ( & self , node : & WindowUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
432+ if try_encode_python_window_udf ( node, buf) ? {
433+ return Ok ( ( ) ) ;
434+ }
400435 self . inner . try_encode_udwf ( node, buf)
401436 }
402437
403438 fn try_decode_udwf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < WindowUDF > > {
439+ if let Some ( udwf) = try_decode_python_window_udf ( buf) ? {
440+ return Ok ( udwf) ;
441+ }
404442 self . inner . try_decode_udwf ( name, buf)
405443 }
406444}
@@ -564,6 +602,11 @@ fn build_single_field_schema_bytes(field: &Field) -> PyResult<Vec<u8>> {
564602 schema_to_ipc_bytes ( & Schema :: new ( vec ! [ field. clone( ) ] ) ) . map_err ( arrow_to_py_err)
565603}
566604
605+ /// Emit a multi-field IPC schema blob.
606+ fn build_schema_bytes ( fields : Vec < Field > ) -> PyResult < Vec < u8 > > {
607+ schema_to_ipc_bytes ( & Schema :: new ( fields) ) . map_err ( arrow_to_py_err)
608+ }
609+
567610/// Decode the per-arg `DataType`s the encoder wrote via
568611/// [`build_input_schema_bytes`].
569612fn read_input_dtypes ( bytes : & [ u8 ] ) -> PyResult < Vec < DataType > > {
@@ -642,6 +685,200 @@ fn cloudpickle<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
642685 . map ( |cached| cached. bind ( py) . clone ( ) )
643686}
644687
688+ // =============================================================================
689+ // Shared Python window UDF encode / decode helpers
690+ //
691+ // Cloudpickle tuple shape: `(name, evaluator_factory, input_schema_bytes,
692+ // return_schema_bytes, volatility_str)`. The evaluator factory is the
693+ // Python callable that produces a new evaluator instance per partition.
694+ // =============================================================================
695+
696+ pub ( crate ) fn try_encode_python_window_udf ( node : & WindowUDF , buf : & mut Vec < u8 > ) -> Result < bool > {
697+ let Some ( py_udf) = node. inner ( ) . downcast_ref :: < PythonFunctionWindowUDF > ( ) else {
698+ return Ok ( false ) ;
699+ } ;
700+
701+ Python :: attach ( |py| -> Result < bool > {
702+ let py_version = current_python_version ( py)
703+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
704+ let bytes = encode_python_window_udf ( py, py_udf)
705+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
706+ write_wire_header ( buf, PY_WINDOW_UDF_FAMILY , py_version) ;
707+ buf. extend_from_slice ( & bytes) ;
708+ Ok ( true )
709+ } )
710+ }
711+
712+ pub ( crate ) fn try_decode_python_window_udf ( buf : & [ u8 ] ) -> Result < Option < Arc < WindowUDF > > > {
713+ Python :: attach ( |py| -> Result < Option < Arc < WindowUDF > > > {
714+ let py_version = current_python_version ( py)
715+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
716+ let Some ( payload) = strip_wire_header ( buf, PY_WINDOW_UDF_FAMILY , "window UDF" , py_version) ?
717+ else {
718+ return Ok ( None ) ;
719+ } ;
720+ let udf = decode_python_window_udf ( py, payload)
721+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
722+ Ok ( Some ( Arc :: new ( WindowUDF :: new_from_impl ( udf) ) ) )
723+ } )
724+ }
725+
726+ fn encode_python_window_udf ( py : Python < ' _ > , udf : & PythonFunctionWindowUDF ) -> PyResult < Vec < u8 > > {
727+ let signature = WindowUDFImpl :: signature ( udf) ;
728+ let input_dtypes = signature_input_dtypes ( signature, "PythonFunctionWindowUDF" ) ?;
729+ let input_schema_bytes = build_input_schema_bytes ( & input_dtypes) ?;
730+ let return_field = Field :: new ( "result" , udf. return_type ( ) . clone ( ) , true ) ;
731+ let return_schema_bytes = build_single_field_schema_bytes ( & return_field) ?;
732+ let volatility = volatility_wire_str ( signature. volatility ) ;
733+
734+ let payload = PyTuple :: new (
735+ py,
736+ [
737+ WindowUDFImpl :: name ( udf) . into_pyobject ( py) ?. into_any ( ) ,
738+ udf. evaluator ( ) . bind ( py) . clone ( ) . into_any ( ) ,
739+ PyBytes :: new ( py, & input_schema_bytes) . into_any ( ) ,
740+ PyBytes :: new ( py, & return_schema_bytes) . into_any ( ) ,
741+ volatility. into_pyobject ( py) ?. into_any ( ) ,
742+ ] ,
743+ ) ?;
744+
745+ cloudpickle ( py) ?
746+ . call_method1 ( "dumps" , ( payload, ) ) ?
747+ . extract :: < Vec < u8 > > ( )
748+ }
749+
750+ fn decode_python_window_udf ( py : Python < ' _ > , payload : & [ u8 ] ) -> PyResult < PythonFunctionWindowUDF > {
751+ let tuple = cloudpickle ( py) ?
752+ . call_method1 ( "loads" , ( PyBytes :: new ( py, payload) , ) ) ?
753+ . cast_into :: < PyTuple > ( ) ?;
754+
755+ let name: String = tuple. get_item ( 0 ) ?. extract ( ) ?;
756+ let evaluator: Py < PyAny > = tuple. get_item ( 1 ) ?. unbind ( ) ;
757+ let input_schema_bytes: Vec < u8 > = tuple. get_item ( 2 ) ?. extract ( ) ?;
758+ let return_schema_bytes: Vec < u8 > = tuple. get_item ( 3 ) ?. extract ( ) ?;
759+ let volatility_str: String = tuple. get_item ( 4 ) ?. extract ( ) ?;
760+
761+ let input_types = read_input_dtypes ( & input_schema_bytes) ?;
762+ let return_type = read_single_return_field ( & return_schema_bytes, "PythonFunctionWindowUDF" ) ?
763+ . data_type ( )
764+ . clone ( ) ;
765+ let volatility = parse_volatility_str ( & volatility_str) ?;
766+
767+ Ok ( PythonFunctionWindowUDF :: new (
768+ name,
769+ evaluator,
770+ input_types,
771+ return_type,
772+ volatility,
773+ ) )
774+ }
775+
776+ // =============================================================================
777+ // Shared Python aggregate UDF encode / decode helpers
778+ //
779+ // Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes,
780+ // return_schema_bytes, state_schema_bytes, volatility_str)`. The accumulator
781+ // factory is the Python callable that produces a new accumulator instance
782+ // per partition.
783+ // =============================================================================
784+
785+ pub ( crate ) fn try_encode_python_agg_udf ( node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < bool > {
786+ let Some ( py_udf) = node. inner ( ) . downcast_ref :: < PythonFunctionAggregateUDF > ( ) else {
787+ return Ok ( false ) ;
788+ } ;
789+
790+ Python :: attach ( |py| -> Result < bool > {
791+ let py_version = current_python_version ( py)
792+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
793+ let bytes = encode_python_agg_udf ( py, py_udf)
794+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
795+ write_wire_header ( buf, PY_AGG_UDF_FAMILY , py_version) ;
796+ buf. extend_from_slice ( & bytes) ;
797+ Ok ( true )
798+ } )
799+ }
800+
801+ pub ( crate ) fn try_decode_python_agg_udf ( buf : & [ u8 ] ) -> Result < Option < Arc < AggregateUDF > > > {
802+ Python :: attach ( |py| -> Result < Option < Arc < AggregateUDF > > > {
803+ let py_version = current_python_version ( py)
804+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
805+ let Some ( payload) = strip_wire_header ( buf, PY_AGG_UDF_FAMILY , "aggregate UDF" , py_version) ?
806+ else {
807+ return Ok ( None ) ;
808+ } ;
809+ let udf = decode_python_agg_udf ( py, payload)
810+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
811+ Ok ( Some ( Arc :: new ( AggregateUDF :: new_from_impl ( udf) ) ) )
812+ } )
813+ }
814+
815+ fn encode_python_agg_udf ( py : Python < ' _ > , udf : & PythonFunctionAggregateUDF ) -> PyResult < Vec < u8 > > {
816+ let signature = AggregateUDFImpl :: signature ( udf) ;
817+ let input_dtypes = signature_input_dtypes ( signature, "PythonFunctionAggregateUDF" ) ?;
818+ let input_schema_bytes = build_input_schema_bytes ( & input_dtypes) ?;
819+ let return_field = Field :: new ( "result" , udf. return_type ( ) . clone ( ) , true ) ;
820+ let return_schema_bytes = build_single_field_schema_bytes ( & return_field) ?;
821+ let state_fields: Vec < Field > = udf
822+ . state_fields_ref ( )
823+ . iter ( )
824+ . map ( |f| f. as_ref ( ) . clone ( ) )
825+ . collect ( ) ;
826+ let state_schema_bytes = build_schema_bytes ( state_fields) ?;
827+ let volatility = volatility_wire_str ( signature. volatility ) ;
828+
829+ let payload = PyTuple :: new (
830+ py,
831+ [
832+ AggregateUDFImpl :: name ( udf) . into_pyobject ( py) ?. into_any ( ) ,
833+ udf. accumulator ( ) . bind ( py) . clone ( ) . into_any ( ) ,
834+ PyBytes :: new ( py, & input_schema_bytes) . into_any ( ) ,
835+ PyBytes :: new ( py, & return_schema_bytes) . into_any ( ) ,
836+ PyBytes :: new ( py, & state_schema_bytes) . into_any ( ) ,
837+ volatility. into_pyobject ( py) ?. into_any ( ) ,
838+ ] ,
839+ ) ?;
840+
841+ cloudpickle ( py) ?
842+ . call_method1 ( "dumps" , ( payload, ) ) ?
843+ . extract :: < Vec < u8 > > ( )
844+ }
845+
846+ fn decode_python_agg_udf ( py : Python < ' _ > , payload : & [ u8 ] ) -> PyResult < PythonFunctionAggregateUDF > {
847+ let tuple = cloudpickle ( py) ?
848+ . call_method1 ( "loads" , ( PyBytes :: new ( py, payload) , ) ) ?
849+ . cast_into :: < PyTuple > ( ) ?;
850+
851+ let name: String = tuple. get_item ( 0 ) ?. extract ( ) ?;
852+ let accumulator: Py < PyAny > = tuple. get_item ( 1 ) ?. unbind ( ) ;
853+ let input_schema_bytes: Vec < u8 > = tuple. get_item ( 2 ) ?. extract ( ) ?;
854+ let return_schema_bytes: Vec < u8 > = tuple. get_item ( 3 ) ?. extract ( ) ?;
855+ let state_schema_bytes: Vec < u8 > = tuple. get_item ( 4 ) ?. extract ( ) ?;
856+ let volatility_str: String = tuple. get_item ( 5 ) ?. extract ( ) ?;
857+
858+ let input_types = read_input_dtypes ( & input_schema_bytes) ?;
859+ let return_type = read_single_return_field ( & return_schema_bytes, "PythonFunctionAggregateUDF" ) ?
860+ . data_type ( )
861+ . clone ( ) ;
862+ // Preserve the encoded state field metadata (names, nullability,
863+ // arbitrary key/value attributes) so the post-decode UDF reports
864+ // the same state schema as the sender's instance — important for
865+ // accumulators whose `StateFieldsArgs` consumers key off names or
866+ // nullability rather than positional `DataType`.
867+ let state_schema = schema_from_ipc_bytes ( & state_schema_bytes) . map_err ( arrow_to_py_err) ?;
868+ let state_fields: Vec < arrow:: datatypes:: FieldRef > =
869+ state_schema. fields ( ) . iter ( ) . cloned ( ) . collect ( ) ;
870+ let volatility = parse_volatility_str ( & volatility_str) ?;
871+
872+ Ok ( PythonFunctionAggregateUDF :: from_parts (
873+ name,
874+ accumulator,
875+ input_types,
876+ return_type,
877+ state_fields,
878+ volatility,
879+ ) )
880+ }
881+
645882#[ cfg( test) ]
646883mod wire_header_tests {
647884 use super :: * ;
@@ -729,7 +966,7 @@ mod wire_header_tests {
729966 }
730967
731968 #[ test]
732- fn write_then_strip_round_trips_payload ( ) {
969+ fn write_then_strip_round_trips_scalar_payload ( ) {
733970 let mut buf = Vec :: new ( ) ;
734971 write_wire_header ( & mut buf, PY_SCALAR_UDF_FAMILY , TEST_PY ) ;
735972 buf. extend_from_slice ( b"scalar-payload" ) ;
@@ -739,4 +976,39 @@ mod wire_header_tests {
739976 . unwrap ( ) ;
740977 assert_eq ! ( payload, b"scalar-payload" ) ;
741978 }
979+
/// A buffer framed with the aggregate family header must hand back
/// exactly the bytes that were appended after the header.
#[test]
fn write_then_strip_round_trips_agg_payload() {
    let body: &[u8] = b"agg-payload";

    let mut framed = Vec::new();
    write_wire_header(&mut framed, PY_AGG_UDF_FAMILY, TEST_PY);
    framed.extend_from_slice(body);

    let stripped = strip_wire_header(&framed, PY_AGG_UDF_FAMILY, "aggregate UDF", TEST_PY);
    let recovered = stripped.unwrap().unwrap();
    assert_eq!(recovered, body);
}
991+
/// A buffer framed with the window family header must hand back
/// exactly the bytes that were appended after the header.
#[test]
fn write_then_strip_round_trips_window_payload() {
    let body: &[u8] = b"window-payload";

    let mut framed = Vec::new();
    write_wire_header(&mut framed, PY_WINDOW_UDF_FAMILY, TEST_PY);
    framed.extend_from_slice(body);

    let stripped = strip_wire_header(&framed, PY_WINDOW_UDF_FAMILY, "window UDF", TEST_PY);
    let recovered = stripped.unwrap().unwrap();
    assert_eq!(recovered, body);
}
1003+
/// A scalar-framed buffer probed with the window family must come back
/// as `Ok(None)`: neither an error nor a payload.
#[test]
fn strip_does_not_match_a_different_family() {
    let mut framed = Vec::new();
    write_wire_header(&mut framed, PY_SCALAR_UDF_FAMILY, TEST_PY);
    framed.extend_from_slice(b"payload");

    let outcome = strip_wire_header(&framed, PY_WINDOW_UDF_FAMILY, "window UDF", TEST_PY);
    assert!(matches!(outcome, Ok(None)));
}
7421014}
0 commit comments