Skip to content

Commit 6a770aa

Browse files
adriangbclaudemartin-g
authored
feat: add cast_to_type UDF for type-based casting (#21322)
## Which issue does this PR close? N/A — new feature ## Rationale for this change DuckDB provides a [`cast_to_type(expression, reference)`](https://duckdb.org/docs/current/sql/expressions/cast#cast_to_type-function) function that casts the first argument to the data type of the second argument. This is useful in macros and generic SQL where types need to be preserved or matched dynamically. This PR adds the equivalent function to DataFusion, along with a fallible `try_cast_to_type` variant. ## What changes are included in this PR? - New `cast_to_type` scalar UDF in `datafusion/functions/src/core/cast_to_type.rs` - Takes two arguments: the expression to cast, and a reference expression whose **type** (not value) determines the target cast type - Uses `return_field_from_args` to infer return type from the second argument's data type - `simplify()` rewrites to `Expr::Cast` (or no-op if types match), so there is zero runtime overhead - New `try_cast_to_type` scalar UDF in `datafusion/functions/src/core/try_cast_to_type.rs` - Same as `cast_to_type` but returns NULL on cast failure instead of erroring - `simplify()` rewrites to `Expr::TryCast` - Output is always nullable - Registration of both functions in `datafusion/functions/src/core/mod.rs` ## Are these changes tested? Yes. New sqllogictest file `cast_to_type.slt` covering both functions: - Basic casts (string→int, string→double, int→string, int→double) - NULL handling - Same-type no-op - CASE expression as first argument - Arithmetic expression as first argument - Nested calls - Subquery as second argument - Column references as second argument - Boolean and date casts - Error on invalid cast (`cast_to_type`) vs NULL on invalid cast (`try_cast_to_type`) - Cross-column type matching ## Are there any user-facing changes? Two new SQL functions: - `cast_to_type(expression, reference)` — casts expression to the type of reference - `try_cast_to_type(expression, reference)` — same, but returns NULL on failure 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Martin Grigorov <martin-g@users.noreply.github.com>
1 parent 6c106ba commit 6a770aa

File tree

7 files changed

+711
-22
lines changed

7 files changed

+711
-22
lines changed

datafusion/functions/src/core/arrow_cast.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -154,23 +154,20 @@ impl ScalarUDFImpl for ArrowCastFunc {
154154

155155
fn simplify(
156156
&self,
157-
mut args: Vec<Expr>,
157+
args: Vec<Expr>,
158158
info: &SimplifyContext,
159159
) -> Result<ExprSimplifyResult> {
160160
// convert this into a real cast
161-
let target_type = data_type_from_args(self.name(), &args)?;
162-
// remove second (type) argument
163-
args.pop().unwrap();
164-
let arg = args.pop().unwrap();
165-
166-
let source_type = info.get_data_type(&arg)?;
161+
let [source_arg, type_arg] = take_function_args(self.name(), args)?;
162+
let target_type = data_type_from_type_arg(self.name(), &type_arg)?;
163+
let source_type = info.get_data_type(&source_arg)?;
167164
let new_expr = if source_type == target_type {
168165
// the argument's data type is already the correct type
169-
arg
166+
source_arg
170167
} else {
171168
// Use an actual cast to get the correct type
172169
Expr::Cast(datafusion_expr::Cast {
173-
expr: Box::new(arg),
170+
expr: Box::new(source_arg),
174171
field: target_type.into_nullable_field_ref(),
175172
})
176173
};
@@ -183,10 +180,8 @@ impl ScalarUDFImpl for ArrowCastFunc {
183180
}
184181
}
185182

186-
/// Returns the requested type from the arguments
187-
pub(crate) fn data_type_from_args(name: &str, args: &[Expr]) -> Result<DataType> {
188-
let [_, type_arg] = take_function_args(name, args)?;
189-
183+
/// Returns the requested type from the type argument
184+
pub(crate) fn data_type_from_type_arg(name: &str, type_arg: &Expr) -> Result<DataType> {
190185
let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else {
191186
return exec_err!(
192187
"{name} requires its second argument to be a constant string, got {:?}",

datafusion/functions/src/core/arrow_try_cast.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use datafusion_expr::{
3131
};
3232
use datafusion_macros::user_doc;
3333

34-
use super::arrow_cast::data_type_from_args;
34+
use super::arrow_cast::data_type_from_type_arg;
3535

3636
/// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring.
3737
///
@@ -127,20 +127,18 @@ impl ScalarUDFImpl for ArrowTryCastFunc {
127127

128128
fn simplify(
129129
&self,
130-
mut args: Vec<Expr>,
130+
args: Vec<Expr>,
131131
info: &SimplifyContext,
132132
) -> Result<ExprSimplifyResult> {
133-
let target_type = data_type_from_args(self.name(), &args)?;
134-
// remove second (type) argument
135-
args.pop().unwrap();
136-
let arg = args.pop().unwrap();
133+
let [source_arg, type_arg] = take_function_args(self.name(), args)?;
134+
let target_type = data_type_from_type_arg(self.name(), &type_arg)?;
137135

138-
let source_type = info.get_data_type(&arg)?;
136+
let source_type = info.get_data_type(&source_arg)?;
139137
let new_expr = if source_type == target_type {
140-
arg
138+
source_arg
141139
} else {
142140
Expr::TryCast(datafusion_expr::TryCast {
143-
expr: Box::new(arg),
141+
expr: Box::new(source_arg),
144142
field: target_type.into_nullable_field_ref(),
145143
})
146144
};
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`CastToTypeFunc`]: Implementation of the `cast_to_type` function
19+
20+
use arrow::datatypes::{DataType, Field, FieldRef};
21+
use datafusion_common::{Result, internal_err, utils::take_function_args};
22+
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
23+
use datafusion_expr::{
24+
Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
25+
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
26+
};
27+
use datafusion_macros::user_doc;
28+
29+
/// Casts the first argument to the data type of the second argument.
30+
///
31+
/// Only the type of the second argument is used; its value is ignored.
32+
/// This is useful in macros or generic SQL where you need to preserve
33+
/// or match types dynamically.
34+
///
35+
/// For example:
36+
/// ```sql
37+
/// select cast_to_type('42', NULL::INTEGER);
38+
/// ```
39+
#[user_doc(
40+
doc_section(label = "Other Functions"),
41+
description = "Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored.",
42+
syntax_example = "cast_to_type(expression, reference)",
43+
sql_example = r#"```sql
44+
> select cast_to_type('42', NULL::INTEGER) as a;
45+
+----+
46+
| a |
47+
+----+
48+
| 42 |
49+
+----+
50+
51+
> select cast_to_type(1 + 2, NULL::DOUBLE) as b;
52+
+-----+
53+
| b |
54+
+-----+
55+
| 3.0 |
56+
+-----+
57+
```"#,
58+
argument(
59+
name = "expression",
60+
description = "The expression to cast. It can be a constant, column, or function, and any combination of operators."
61+
),
62+
argument(
63+
name = "reference",
64+
description = "Reference expression whose data type determines the target cast type. The value is ignored."
65+
)
66+
)]
67+
#[derive(Debug, PartialEq, Eq, Hash)]
68+
pub struct CastToTypeFunc {
69+
signature: Signature,
70+
}
71+
72+
impl Default for CastToTypeFunc {
73+
fn default() -> Self {
74+
Self::new()
75+
}
76+
}
77+
78+
impl CastToTypeFunc {
79+
pub fn new() -> Self {
80+
Self {
81+
signature: Signature::coercible(
82+
vec![
83+
Coercion::new_exact(TypeSignatureClass::Any),
84+
Coercion::new_exact(TypeSignatureClass::Any),
85+
],
86+
Volatility::Immutable,
87+
),
88+
}
89+
}
90+
}
91+
92+
impl ScalarUDFImpl for CastToTypeFunc {
93+
fn name(&self) -> &str {
94+
"cast_to_type"
95+
}
96+
97+
fn signature(&self) -> &Signature {
98+
&self.signature
99+
}
100+
101+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
102+
internal_err!("return_field_from_args should be called instead")
103+
}
104+
105+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
106+
let [source_field, reference_field] =
107+
take_function_args(self.name(), args.arg_fields)?;
108+
let target_type = reference_field.data_type().clone();
109+
// Nullability is inherited only from the first argument (the value
110+
// being cast). The second argument is used solely for its type, so
111+
// its own nullability is irrelevant. The one exception is when the
112+
// target type is Null – that type is inherently nullable.
113+
let nullable = source_field.is_nullable() || target_type == DataType::Null;
114+
Ok(Field::new(self.name(), target_type, nullable).into())
115+
}
116+
117+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
118+
internal_err!("cast_to_type should have been simplified to cast")
119+
}
120+
121+
fn simplify(
122+
&self,
123+
args: Vec<Expr>,
124+
info: &SimplifyContext,
125+
) -> Result<ExprSimplifyResult> {
126+
let [source_arg, type_arg] = take_function_args(self.name(), args)?;
127+
let target_type = info.get_data_type(&type_arg)?;
128+
let source_type = info.get_data_type(&source_arg)?;
129+
let new_expr = if source_type == target_type {
130+
// the argument's data type is already the correct type
131+
source_arg
132+
} else {
133+
let nullable = info.nullable(&source_arg)? || target_type == DataType::Null;
134+
// Use an actual cast to get the correct type
135+
Expr::Cast(datafusion_expr::Cast {
136+
expr: Box::new(source_arg),
137+
field: Field::new("", target_type, nullable).into(),
138+
})
139+
};
140+
Ok(ExprSimplifyResult::Simplified(new_expr))
141+
}
142+
143+
fn documentation(&self) -> Option<&Documentation> {
144+
self.doc()
145+
}
146+
}

datafusion/functions/src/core/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub mod arrow_cast;
2424
pub mod arrow_metadata;
2525
pub mod arrow_try_cast;
2626
pub mod arrowtypeof;
27+
pub mod cast_to_type;
2728
pub mod coalesce;
2829
pub mod expr_ext;
2930
pub mod getfield;
@@ -37,13 +38,16 @@ pub mod nvl2;
3738
pub mod overlay;
3839
pub mod planner;
3940
pub mod r#struct;
41+
pub mod try_cast_to_type;
4042
pub mod union_extract;
4143
pub mod union_tag;
4244
pub mod version;
4345

4446
// create UDFs
4547
make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast);
4648
make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast);
49+
make_udf_function!(cast_to_type::CastToTypeFunc, cast_to_type);
50+
make_udf_function!(try_cast_to_type::TryCastToTypeFunc, try_cast_to_type);
4751
make_udf_function!(nullif::NullIfFunc, nullif);
4852
make_udf_function!(nvl::NVLFunc, nvl);
4953
make_udf_function!(nvl2::NVL2Func, nvl2);
@@ -75,6 +79,14 @@ pub mod expr_fn {
7579
arrow_try_cast,
7680
"Casts a value to a specific Arrow data type, returning NULL if the cast fails",
7781
arg1 arg2
82+
),(
83+
cast_to_type,
84+
"Casts the first argument to the data type of the second argument",
85+
arg1 arg2
86+
),(
87+
try_cast_to_type,
88+
"Casts the first argument to the data type of the second argument, returning NULL on failure",
89+
arg1 arg2
7890
),(
7991
nvl,
8092
"Returns value2 if value1 is NULL; otherwise it returns value1",
@@ -147,6 +159,8 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
147159
nullif(),
148160
arrow_cast(),
149161
arrow_try_cast(),
162+
cast_to_type(),
163+
try_cast_to_type(),
150164
arrow_metadata(),
151165
nvl(),
152166
nvl2(),

0 commit comments

Comments
 (0)