Skip to content

Commit 3226978

Browse files
timsaucerclaude
andcommitted
feat: inline encoding for Python aggregate and window UDFs
Extends the PythonLogicalCodec / PythonPhysicalCodec inline encoding introduced for scalar UDFs to also cover Python-defined aggregate and window UDFs. The cloudpickle tuple shape per family is: DFPYUDA (agg) (name, accumulator_factory, input_schema_bytes, return_schema_bytes, state_schema_bytes, volatility_str) DFPYUDW (window) (name, evaluator_factory, input_schema_bytes, return_schema_bytes, volatility_str) Same wire-framing as scalar (family magic + version byte + cloudpickle blob), same schema serde (arrow-rs native IPC), same cached cloudpickle handle. The agg state schema is encoded as a full IPC schema so the post-decode UDF reports the same names + nullability + metadata as the sender — relevant for accumulators whose StateFieldsArgs consumers key off names rather than positional DataType. Required restructuring two existing UDF impls so the codec can grab the Python callable directly: * udaf.rs: replaces create_udaf + AccumulatorFactoryFunction closure with a named PythonFunctionAggregateUDF that stores the Py<PyAny> accumulator factory. Synthesizes state_{i} field names when the Python constructor passes only Vec<DataType>; from_parts preserves the full state schema on the decode side. * udwf.rs: renames MultiColumnWindowUDF -> PythonFunctionWindowUDF, drops the PartitionEvaluatorFactory PtrEq wrapper, stores the Py<PyAny> evaluator directly. PartialEq and Hash get the same pointer-identity fast path + debug-log exception handling already on PythonFunctionScalarUDF. User-facing surface: * AggregateUDF.name and WindowUDF.name properties (parallel to the ScalarUDF.name shipped in PR1). * Existing UDAF/UDWF construction paths are unchanged. The per-session with_python_udf_inlining toggle, sender-side context, strict refusal, and user-guide docs land in PRs 3-4 of this series. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent afaeccb commit 3226978

7 files changed

Lines changed: 675 additions & 81 deletions

File tree

crates/core/src/codec.rs

Lines changed: 277 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,14 @@
6666
//! | Layer + kind | Family prefix |
6767
//! | ----------------------------- | ------------- |
6868
//! | `PythonLogicalCodec` scalar | `DFPYUDF` |
69+
//! | `PythonLogicalCodec` agg | `DFPYUDA` |
70+
//! | `PythonLogicalCodec` window | `DFPYUDW` |
6971
//! | `PythonPhysicalCodec` scalar | `DFPYUDF` |
72+
//! | `PythonPhysicalCodec` agg | `DFPYUDA` |
73+
//! | `PythonPhysicalCodec` window | `DFPYUDW` |
7074
//! | User FFI extension codec | user-chosen |
7175
//! | Default codec | (none) |
7276
//!
73-
//! Aggregate and window UDF families are reserved for follow-on work.
74-
//!
7577
//! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported
7678
//! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`.
7779
//! Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape
@@ -94,8 +96,8 @@ use datafusion::datasource::TableProvider;
9496
use datafusion::datasource::file_format::FileFormatFactory;
9597
use datafusion::execution::TaskContext;
9698
use datafusion::logical_expr::{
97-
AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
98-
Volatility, WindowUDF,
99+
AggregateUDF, AggregateUDFImpl, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature,
100+
TypeSignature, Volatility, WindowUDF, WindowUDFImpl,
99101
};
100102
use datafusion::physical_expr::PhysicalExpr;
101103
use datafusion::physical_plan::ExecutionPlan;
@@ -105,7 +107,9 @@ use pyo3::prelude::*;
105107
use pyo3::sync::PyOnceLock;
106108
use pyo3::types::{PyBytes, PyTuple};
107109

110+
use crate::udaf::PythonFunctionAggregateUDF;
108111
use crate::udf::PythonFunctionScalarUDF;
112+
use crate::udwf::PythonFunctionWindowUDF;
109113

110114
// Wire-format framing for inlined Python UDF payloads.
111115
//
@@ -126,6 +130,16 @@ use crate::udf::PythonFunctionScalarUDF;
126130
/// volatility).
127131
pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF";
128132

133+
/// Family prefix for an inlined Python aggregate UDF
134+
/// (cloudpickled tuple of name, accumulator factory, input schema,
135+
/// return type, state types schema, volatility).
136+
pub(crate) const PY_AGG_UDF_FAMILY: &[u8] = b"DFPYUDA";
137+
138+
/// Family prefix for an inlined Python window UDF
139+
/// (cloudpickled tuple of name, evaluator factory, input schema,
140+
/// return type, volatility).
141+
pub(crate) const PY_WINDOW_UDF_FAMILY: &[u8] = b"DFPYUDW";
142+
129143
/// Wire-format version this build emits.
130144
pub(crate) const WIRE_VERSION_CURRENT: u8 = 1;
131145

@@ -299,18 +313,30 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
299313
}
300314

301315
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
316+
if try_encode_python_agg_udf(node, buf)? {
317+
return Ok(());
318+
}
302319
self.inner.try_encode_udaf(node, buf)
303320
}
304321

305322
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
323+
if let Some(udaf) = try_decode_python_agg_udf(buf)? {
324+
return Ok(udaf);
325+
}
306326
self.inner.try_decode_udaf(name, buf)
307327
}
308328

309329
fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> {
330+
if try_encode_python_window_udf(node, buf)? {
331+
return Ok(());
332+
}
310333
self.inner.try_encode_udwf(node, buf)
311334
}
312335

313336
fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> {
337+
if let Some(udwf) = try_decode_python_window_udf(buf)? {
338+
return Ok(udwf);
339+
}
314340
self.inner.try_decode_udwf(name, buf)
315341
}
316342
}
@@ -389,18 +415,30 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
389415
}
390416

391417
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
418+
if try_encode_python_agg_udf(node, buf)? {
419+
return Ok(());
420+
}
392421
self.inner.try_encode_udaf(node, buf)
393422
}
394423

395424
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
425+
if let Some(udaf) = try_decode_python_agg_udf(buf)? {
426+
return Ok(udaf);
427+
}
396428
self.inner.try_decode_udaf(name, buf)
397429
}
398430

399431
fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> {
432+
if try_encode_python_window_udf(node, buf)? {
433+
return Ok(());
434+
}
400435
self.inner.try_encode_udwf(node, buf)
401436
}
402437

403438
fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> {
439+
if let Some(udwf) = try_decode_python_window_udf(buf)? {
440+
return Ok(udwf);
441+
}
404442
self.inner.try_decode_udwf(name, buf)
405443
}
406444
}
@@ -564,6 +602,11 @@ fn build_single_field_schema_bytes(field: &Field) -> PyResult<Vec<u8>> {
564602
schema_to_ipc_bytes(&Schema::new(vec![field.clone()])).map_err(arrow_to_py_err)
565603
}
566604

605+
/// Emit a multi-field IPC schema blob.
606+
fn build_schema_bytes(fields: Vec<Field>) -> PyResult<Vec<u8>> {
607+
schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err)
608+
}
609+
567610
/// Decode the per-arg `DataType`s the encoder wrote via
568611
/// [`build_input_schema_bytes`].
569612
fn read_input_dtypes(bytes: &[u8]) -> PyResult<Vec<DataType>> {
@@ -642,6 +685,200 @@ fn cloudpickle<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
642685
.map(|cached| cached.bind(py).clone())
643686
}
644687

688+
// =============================================================================
689+
// Shared Python window UDF encode / decode helpers
690+
//
691+
// Cloudpickle tuple shape: `(name, evaluator_factory, input_schema_bytes,
692+
// return_schema_bytes, volatility_str)`. The evaluator factory is the
693+
// Python callable that produces a new evaluator instance per partition.
694+
// =============================================================================
695+
696+
pub(crate) fn try_encode_python_window_udf(node: &WindowUDF, buf: &mut Vec<u8>) -> Result<bool> {
697+
let Some(py_udf) = node.inner().downcast_ref::<PythonFunctionWindowUDF>() else {
698+
return Ok(false);
699+
};
700+
701+
Python::attach(|py| -> Result<bool> {
702+
let py_version = current_python_version(py)
703+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
704+
let bytes = encode_python_window_udf(py, py_udf)
705+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
706+
write_wire_header(buf, PY_WINDOW_UDF_FAMILY, py_version);
707+
buf.extend_from_slice(&bytes);
708+
Ok(true)
709+
})
710+
}
711+
712+
pub(crate) fn try_decode_python_window_udf(buf: &[u8]) -> Result<Option<Arc<WindowUDF>>> {
713+
Python::attach(|py| -> Result<Option<Arc<WindowUDF>>> {
714+
let py_version = current_python_version(py)
715+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
716+
let Some(payload) = strip_wire_header(buf, PY_WINDOW_UDF_FAMILY, "window UDF", py_version)?
717+
else {
718+
return Ok(None);
719+
};
720+
let udf = decode_python_window_udf(py, payload)
721+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
722+
Ok(Some(Arc::new(WindowUDF::new_from_impl(udf))))
723+
})
724+
}
725+
726+
fn encode_python_window_udf(py: Python<'_>, udf: &PythonFunctionWindowUDF) -> PyResult<Vec<u8>> {
727+
let signature = WindowUDFImpl::signature(udf);
728+
let input_dtypes = signature_input_dtypes(signature, "PythonFunctionWindowUDF")?;
729+
let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?;
730+
let return_field = Field::new("result", udf.return_type().clone(), true);
731+
let return_schema_bytes = build_single_field_schema_bytes(&return_field)?;
732+
let volatility = volatility_wire_str(signature.volatility);
733+
734+
let payload = PyTuple::new(
735+
py,
736+
[
737+
WindowUDFImpl::name(udf).into_pyobject(py)?.into_any(),
738+
udf.evaluator().bind(py).clone().into_any(),
739+
PyBytes::new(py, &input_schema_bytes).into_any(),
740+
PyBytes::new(py, &return_schema_bytes).into_any(),
741+
volatility.into_pyobject(py)?.into_any(),
742+
],
743+
)?;
744+
745+
cloudpickle(py)?
746+
.call_method1("dumps", (payload,))?
747+
.extract::<Vec<u8>>()
748+
}
749+
750+
fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionWindowUDF> {
751+
let tuple = cloudpickle(py)?
752+
.call_method1("loads", (PyBytes::new(py, payload),))?
753+
.cast_into::<PyTuple>()?;
754+
755+
let name: String = tuple.get_item(0)?.extract()?;
756+
let evaluator: Py<PyAny> = tuple.get_item(1)?.unbind();
757+
let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
758+
let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?;
759+
let volatility_str: String = tuple.get_item(4)?.extract()?;
760+
761+
let input_types = read_input_dtypes(&input_schema_bytes)?;
762+
let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionWindowUDF")?
763+
.data_type()
764+
.clone();
765+
let volatility = parse_volatility_str(&volatility_str)?;
766+
767+
Ok(PythonFunctionWindowUDF::new(
768+
name,
769+
evaluator,
770+
input_types,
771+
return_type,
772+
volatility,
773+
))
774+
}
775+
776+
// =============================================================================
777+
// Shared Python aggregate UDF encode / decode helpers
778+
//
779+
// Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes,
780+
// return_type_bytes, state_schema_bytes, volatility_str)`. The accumulator
781+
// factory is the Python callable that produces a new accumulator instance
782+
// per partition.
783+
// =============================================================================
784+
785+
pub(crate) fn try_encode_python_agg_udf(node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<bool> {
786+
let Some(py_udf) = node.inner().downcast_ref::<PythonFunctionAggregateUDF>() else {
787+
return Ok(false);
788+
};
789+
790+
Python::attach(|py| -> Result<bool> {
791+
let py_version = current_python_version(py)
792+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
793+
let bytes = encode_python_agg_udf(py, py_udf)
794+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
795+
write_wire_header(buf, PY_AGG_UDF_FAMILY, py_version);
796+
buf.extend_from_slice(&bytes);
797+
Ok(true)
798+
})
799+
}
800+
801+
pub(crate) fn try_decode_python_agg_udf(buf: &[u8]) -> Result<Option<Arc<AggregateUDF>>> {
802+
Python::attach(|py| -> Result<Option<Arc<AggregateUDF>>> {
803+
let py_version = current_python_version(py)
804+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
805+
let Some(payload) = strip_wire_header(buf, PY_AGG_UDF_FAMILY, "aggregate UDF", py_version)?
806+
else {
807+
return Ok(None);
808+
};
809+
let udf = decode_python_agg_udf(py, payload)
810+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
811+
Ok(Some(Arc::new(AggregateUDF::new_from_impl(udf))))
812+
})
813+
}
814+
815+
fn encode_python_agg_udf(py: Python<'_>, udf: &PythonFunctionAggregateUDF) -> PyResult<Vec<u8>> {
816+
let signature = AggregateUDFImpl::signature(udf);
817+
let input_dtypes = signature_input_dtypes(signature, "PythonFunctionAggregateUDF")?;
818+
let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?;
819+
let return_field = Field::new("result", udf.return_type().clone(), true);
820+
let return_schema_bytes = build_single_field_schema_bytes(&return_field)?;
821+
let state_fields: Vec<Field> = udf
822+
.state_fields_ref()
823+
.iter()
824+
.map(|f| f.as_ref().clone())
825+
.collect();
826+
let state_schema_bytes = build_schema_bytes(state_fields)?;
827+
let volatility = volatility_wire_str(signature.volatility);
828+
829+
let payload = PyTuple::new(
830+
py,
831+
[
832+
AggregateUDFImpl::name(udf).into_pyobject(py)?.into_any(),
833+
udf.accumulator().bind(py).clone().into_any(),
834+
PyBytes::new(py, &input_schema_bytes).into_any(),
835+
PyBytes::new(py, &return_schema_bytes).into_any(),
836+
PyBytes::new(py, &state_schema_bytes).into_any(),
837+
volatility.into_pyobject(py)?.into_any(),
838+
],
839+
)?;
840+
841+
cloudpickle(py)?
842+
.call_method1("dumps", (payload,))?
843+
.extract::<Vec<u8>>()
844+
}
845+
846+
fn decode_python_agg_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionAggregateUDF> {
847+
let tuple = cloudpickle(py)?
848+
.call_method1("loads", (PyBytes::new(py, payload),))?
849+
.cast_into::<PyTuple>()?;
850+
851+
let name: String = tuple.get_item(0)?.extract()?;
852+
let accumulator: Py<PyAny> = tuple.get_item(1)?.unbind();
853+
let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
854+
let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?;
855+
let state_schema_bytes: Vec<u8> = tuple.get_item(4)?.extract()?;
856+
let volatility_str: String = tuple.get_item(5)?.extract()?;
857+
858+
let input_types = read_input_dtypes(&input_schema_bytes)?;
859+
let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionAggregateUDF")?
860+
.data_type()
861+
.clone();
862+
// Preserve the encoded state field metadata (names, nullability,
863+
// arbitrary key/value attributes) so the post-decode UDF reports
864+
// the same state schema as the sender's instance — important for
865+
// accumulators whose `StateFieldsArgs` consumers key off names or
866+
// nullability rather than positional `DataType`.
867+
let state_schema = schema_from_ipc_bytes(&state_schema_bytes).map_err(arrow_to_py_err)?;
868+
let state_fields: Vec<arrow::datatypes::FieldRef> =
869+
state_schema.fields().iter().cloned().collect();
870+
let volatility = parse_volatility_str(&volatility_str)?;
871+
872+
Ok(PythonFunctionAggregateUDF::from_parts(
873+
name,
874+
accumulator,
875+
input_types,
876+
return_type,
877+
state_fields,
878+
volatility,
879+
))
880+
}
881+
645882
#[cfg(test)]
646883
mod wire_header_tests {
647884
use super::*;
@@ -729,7 +966,7 @@ mod wire_header_tests {
729966
}
730967

731968
#[test]
732-
fn write_then_strip_round_trips_payload() {
969+
fn write_then_strip_round_trips_scalar_payload() {
733970
let mut buf = Vec::new();
734971
write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, TEST_PY);
735972
buf.extend_from_slice(b"scalar-payload");
@@ -739,4 +976,39 @@ mod wire_header_tests {
739976
.unwrap();
740977
assert_eq!(payload, b"scalar-payload");
741978
}
979+
980+
#[test]
981+
fn write_then_strip_round_trips_agg_payload() {
982+
let mut buf = Vec::new();
983+
write_wire_header(&mut buf, PY_AGG_UDF_FAMILY, TEST_PY);
984+
buf.extend_from_slice(b"agg-payload");
985+
986+
let payload = strip_wire_header(&buf, PY_AGG_UDF_FAMILY, "aggregate UDF", TEST_PY)
987+
.unwrap()
988+
.unwrap();
989+
assert_eq!(payload, b"agg-payload");
990+
}
991+
992+
#[test]
993+
fn write_then_strip_round_trips_window_payload() {
994+
let mut buf = Vec::new();
995+
write_wire_header(&mut buf, PY_WINDOW_UDF_FAMILY, TEST_PY);
996+
buf.extend_from_slice(b"window-payload");
997+
998+
let payload = strip_wire_header(&buf, PY_WINDOW_UDF_FAMILY, "window UDF", TEST_PY)
999+
.unwrap()
1000+
.unwrap();
1001+
assert_eq!(payload, b"window-payload");
1002+
}
1003+
1004+
#[test]
1005+
fn strip_does_not_match_a_different_family() {
1006+
let mut buf = Vec::new();
1007+
write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, TEST_PY);
1008+
buf.extend_from_slice(b"payload");
1009+
assert!(matches!(
1010+
strip_wire_header(&buf, PY_WINDOW_UDF_FAMILY, "window UDF", TEST_PY),
1011+
Ok(None)
1012+
));
1013+
}
7421014
}

0 commit comments

Comments
 (0)