Skip to content

Commit 900634b

Browse files
committed
fix copilot
1 parent 361e165 commit 900634b

2 files changed

Lines changed: 57 additions & 16 deletions

File tree

datafusion/functions/src/math/ceil.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,11 @@ impl ScalarUDFImpl for CeilFunc {
192192

193193
fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
194194
let [input] = inputs else {
195-
return Interval::make_unbounded(&DataType::Float64);
195+
let data_type = inputs
196+
.first()
197+
.map(|i| i.data_type())
198+
.unwrap_or(DataType::Float64);
199+
return Interval::make_unbounded(&data_type);
196200
};
197201
let data_type = input.data_type();
198202
match (ceil_scalar(input.lower()), ceil_scalar(input.upper())) {
@@ -208,7 +212,7 @@ impl ScalarUDFImpl for CeilFunc {
208212
inputs: &[&Interval],
209213
) -> Result<Option<Vec<Interval>>> {
210214
let [input_interval] = inputs else {
211-
return Ok(Some(vec![]));
215+
return Ok(Some(inputs.iter().map(|i| (*i).clone()).collect()));
212216
};
213217
// ceil(x) ∈ [N, M] → x ∈ (N−1, M] — conservative closed: [N−1, M]
214218
let lo = match interval.lower() {
@@ -234,7 +238,7 @@ impl ScalarUDFImpl for CeilFunc {
234238
let constraint = Interval::try_new(lo, hi)?;
235239
Ok(input_interval.intersect(constraint)?.map(|r| vec![r]))
236240
}
237-
_ => Ok(Some(vec![])),
241+
_ => Ok(Some(vec![(*input_interval).clone()])),
238242
}
239243
}
240244

datafusion/physical-expr/src/intervals/utils.rs

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,11 @@ mod tests {
207207
use crate::expressions::{Column, Literal};
208208
use crate::scalar_function::ScalarFunctionExpr;
209209
use arrow::datatypes::{Field, Schema};
210-
use datafusion_common::ScalarValue;
211210
use datafusion_common::config::ConfigOptions;
211+
use datafusion_common::{Result, ScalarValue};
212+
use datafusion_expr::{
213+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
214+
};
212215

213216
fn f64_schema() -> SchemaRef {
214217
Arc::new(Schema::new(vec![Field::new("x", DataType::Float64, false)]))
@@ -241,6 +244,48 @@ mod tests {
241244
))
242245
}
243246

247+
/// A minimal UDF whose declared return type is Utf8, used to test that
248+
/// check_support rejects functions with unsupported return types without
249+
/// relying on an invalid ceil-returns-Utf8 combination.
250+
#[derive(Debug, PartialEq, Eq, Hash)]
251+
struct Utf8UDF {
252+
signature: Signature,
253+
}
254+
255+
impl Utf8UDF {
256+
fn new() -> Self {
257+
Self {
258+
signature: Signature::uniform(
259+
1,
260+
vec![DataType::Float64],
261+
Volatility::Immutable,
262+
),
263+
}
264+
}
265+
}
266+
267+
impl ScalarUDFImpl for Utf8UDF {
268+
fn as_any(&self) -> &dyn std::any::Any {
269+
self
270+
}
271+
fn name(&self) -> &str {
272+
"utf8_udf"
273+
}
274+
fn signature(&self) -> &Signature {
275+
&self.signature
276+
}
277+
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
278+
Ok(DataType::Utf8)
279+
}
280+
fn invoke_with_args(&self, _: ScalarFunctionArgs) -> Result<ColumnarValue> {
281+
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
282+
}
283+
}
284+
285+
fn utf8_udf() -> Arc<datafusion_expr::ScalarUDF> {
286+
Arc::new(datafusion_expr::ScalarUDF::from(Utf8UDF::new()))
287+
}
288+
244289
#[test]
245290
fn test_check_support_scalar_fn_supported_return_type() {
246291
// ceil(x) returns Float64 — both return type and child are supported
@@ -255,13 +300,9 @@ mod tests {
255300

256301
#[test]
257302
fn test_check_support_scalar_fn_unsupported_return_type() {
258-
// A UDF that returns Utf8 — not in is_datatype_supported
303+
// utf8_udf(x) returns Utf8 — not in is_datatype_supported
259304
let schema = f64_schema();
260-
let expr = scalar_fn_expr(
261-
datafusion_functions::math::ceil(),
262-
vec![col_x()],
263-
DataType::Utf8,
264-
);
305+
let expr = scalar_fn_expr(utf8_udf(), vec![col_x()], DataType::Utf8);
265306
assert!(!check_support(&expr, &schema));
266307
}
267308

@@ -294,13 +335,9 @@ mod tests {
294335

295336
#[test]
296337
fn test_check_support_scalar_fn_in_binary_expr_unsupported_return() {
297-
// f(x) > 5.0 where f returns Utf8 — should be false
338+
// utf8_udf(x) > 5.0 where f returns Utf8 — should be false
298339
let schema = f64_schema();
299-
let fn_expr = scalar_fn_expr(
300-
datafusion_functions::math::ceil(),
301-
vec![col_x()],
302-
DataType::Utf8,
303-
);
340+
let fn_expr = scalar_fn_expr(utf8_udf(), vec![col_x()], DataType::Utf8);
304341
let expr: Arc<dyn PhysicalExpr> =
305342
Arc::new(BinaryExpr::new(fn_expr, Operator::Gt, lit_f64(5.0)));
306343
assert!(!check_support(&expr, &schema));

0 commit comments

Comments
 (0)