Skip to content

Commit 398b388

Browse files
timsaucerclaude
andcommitted
feat: accept pyarrow DataType in arrow_try_cast
Mirrors arrow_cast: arrow_try_cast now accepts `pa.DataType` in addition to `str` and `Expr`. Adds `Expr.try_cast(pa.DataType)` PyO3 binding for the pyarrow-type routing path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 04979ea commit 398b388

3 files changed

Lines changed: 40 additions & 2 deletions

File tree

crates/core/src/expr.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,11 @@ impl PyExpr {
358358
expr.into()
359359
}
360360

361+
pub fn try_cast(&self, to: PyArrowType<DataType>) -> PyExpr {
362+
let expr = Expr::TryCast(TryCast::new(Box::new(self.expr.clone()), to.0));
363+
expr.into()
364+
}
365+
361366
#[pyo3(signature = (low, high, negated=false))]
362367
pub fn between(&self, low: PyExpr, high: PyExpr, negated: bool) -> PyExpr {
363368
let expr = Expr::Between(Between::new(

python/datafusion/expr.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,28 @@ def cast(self, to: pa.DataType[Any] | type) -> Expr:
894894

895895
return Expr(self.expr.cast(to))
896896

897+
def try_cast(self, to: pa.DataType[Any] | type) -> Expr:
898+
"""Cast to a new data type, returning NULL on failure.
899+
900+
Like :py:meth:`cast` but produces NULL instead of erroring when the
901+
cast cannot be performed for a given row.
902+
903+
Examples:
904+
>>> ctx = dfn.SessionContext()
905+
>>> df = ctx.from_pydict({"a": ["oops"]})
906+
>>> result = df.select(col("a").try_cast(pa.float64()).alias("c"))
907+
>>> result.collect_column("c")[0].as_py() is None
908+
True
909+
"""
910+
if not isinstance(to, pa.DataType):
911+
try:
912+
to = self._to_pyarrow_types[to]
913+
except KeyError as err:
914+
error_msg = "Expected instance of pyarrow.DataType or builtins.type"
915+
raise TypeError(error_msg) from err
916+
917+
return Expr(self.expr.try_cast(to))
918+
897919
def between(self, low: Any, high: Any, negated: bool = False) -> Expr:
898920
"""Returns ``True`` if this expression is between a given range.
899921

python/datafusion/functions.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2934,12 +2934,13 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
29342934
return Expr(f.arrow_cast(expr.expr, data_type.expr))
29352935

29362936

2937-
def arrow_try_cast(expr: Expr, data_type: Expr | str) -> Expr:
2937+
def arrow_try_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
29382938
"""Casts an expression to a specified data type, returning NULL on failure.
29392939
29402940
Like :py:func:`arrow_cast` but produces NULL instead of erroring when the
29412941
cast cannot be performed. The ``data_type`` may be a string in DataFusion
2942-
type syntax (for example ``"Float64"``) or an ``Expr`` of string type.
2942+
type syntax (for example ``"Float64"``), a ``pyarrow.DataType``, or an
2943+
``Expr`` of string type.
29432944
29442945
Examples:
29452946
>>> ctx = dfn.SessionContext()
@@ -2949,7 +2950,17 @@ def arrow_try_cast(expr: Expr, data_type: Expr | str) -> Expr:
29492950
... )
29502951
>>> result.collect_column("c")[0].as_py() is None
29512952
True
2953+
2954+
>>> result = df.select(
2955+
... dfn.functions.arrow_try_cast(
2956+
... dfn.col("a"), data_type=pa.float64()
2957+
... ).alias("c")
2958+
... )
2959+
>>> result.collect_column("c")[0].as_py() is None
2960+
True
29522961
"""
2962+
if isinstance(data_type, pa.DataType):
2963+
return expr.try_cast(data_type)
29532964
if isinstance(data_type, str):
29542965
data_type = Expr.string_literal(data_type)
29552966
return Expr(f.arrow_try_cast(expr.expr, data_type.expr))

0 commit comments

Comments
 (0)