Skip to content

Commit ff02516

Browse files
committed
fix: fix string to timestamp cast for UTC timestamps
1 parent 4ba0bcf commit ff02516

7 files changed

Lines changed: 219 additions & 72 deletions

File tree

native/spark-expr/src/conversion_funcs/string.rs

Lines changed: 83 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,29 @@ macro_rules! cast_utf8_to_timestamp {
3737
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
3838
let len = $array.len();
3939
let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
40+
let mut cast_err: Option<SparkError> = None;
4041
for i in 0..len {
4142
if $array.is_null(i) {
4243
cast_array.append_null()
43-
} else if let Ok(Some(cast_value)) =
44-
$cast_method($array.value(i).trim(), $eval_mode, $tz)
45-
{
46-
cast_array.append_value(cast_value);
4744
} else {
48-
cast_array.append_null()
45+
match $cast_method($array.value(i).trim(), $eval_mode, $tz) {
46+
Ok(Some(cast_value)) => cast_array.append_value(cast_value),
47+
Ok(None) => cast_array.append_null(),
48+
Err(e) => {
49+
if $eval_mode == EvalMode::Ansi {
50+
cast_err = Some(e);
51+
break;
52+
}
53+
cast_array.append_null()
54+
}
55+
}
4956
}
5057
}
51-
let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
52-
result
58+
if let Some(e) = cast_err {
59+
Err(e)
60+
} else {
61+
Ok(Arc::new(cast_array.finish()) as ArrayRef)
62+
}
5363
}};
5464
}
5565

@@ -668,15 +678,13 @@ pub(crate) fn cast_string_to_timestamp(
668678
let tz = &timezone::Tz::from_str(timezone_str).unwrap();
669679

670680
let cast_array: ArrayRef = match to_type {
671-
DataType::Timestamp(_, _) => {
672-
cast_utf8_to_timestamp!(
673-
string_array,
674-
eval_mode,
675-
TimestampMicrosecondType,
676-
timestamp_parser,
677-
tz
678-
)
679-
}
681+
DataType::Timestamp(_, _) => cast_utf8_to_timestamp!(
682+
string_array,
683+
eval_mode,
684+
TimestampMicrosecondType,
685+
timestamp_parser,
686+
tz
687+
)?,
680688
_ => unreachable!("Invalid data type {:?} in cast from string", to_type),
681689
};
682690
Ok(cast_array)
@@ -961,6 +969,12 @@ fn get_timestamp_values<T: TimeZone>(
961969
) -> SparkResult<Option<i64>> {
962970
let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
963971
let year = values[0].parse::<i32>().unwrap_or_default();
972+
973+
// NaiveDate (used internally by chrono's with_ymd_and_hms) only supports years in [-262143, 262142].
974+
if !(-262143..=262142).contains(&year) {
975+
return Ok(None);
976+
}
977+
964978
let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
965979
let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
966980
let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
@@ -1004,7 +1018,7 @@ fn get_timestamp_values<T: TimeZone>(
10041018
.with_second(second)
10051019
.with_microsecond(microsecond),
10061020
_ => {
1007-
return Err(SparkError::CastInvalidValue {
1021+
return Err(SparkError::InvalidInputInCastToDatetime {
10081022
value: value.to_string(),
10091023
from_type: "STRING".to_string(),
10101024
to_type: "TIMESTAMP".to_string(),
@@ -1082,7 +1096,6 @@ fn parse_str_to_microsecond_timestamp<T: TimeZone>(
10821096
get_timestamp_values(value, "microsecond", tz)
10831097
}
10841098

1085-
// used in tests only
10861099
fn timestamp_parser<T: TimeZone>(
10871100
value: &str,
10881101
eval_mode: EvalMode,
@@ -1095,31 +1108,31 @@ fn timestamp_parser<T: TimeZone>(
10951108
// Define regex patterns and corresponding parsing functions
10961109
let patterns = &[
10971110
(
1098-
Regex::new(r"^\d{4,5}$").unwrap(),
1111+
Regex::new(r"^\d{4,7}$").unwrap(),
10991112
parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
11001113
),
11011114
(
1102-
Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
1115+
Regex::new(r"^\d{4,7}-\d{2}$").unwrap(),
11031116
parse_str_to_month_timestamp,
11041117
),
11051118
(
1106-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
1119+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}$").unwrap(),
11071120
parse_str_to_day_timestamp,
11081121
),
11091122
(
1110-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
1123+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
11111124
parse_str_to_hour_timestamp,
11121125
),
11131126
(
1114-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
1127+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
11151128
parse_str_to_minute_timestamp,
11161129
),
11171130
(
1118-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
1131+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
11191132
parse_str_to_second_timestamp,
11201133
),
11211134
(
1122-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
1135+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
11231136
parse_str_to_microsecond_timestamp,
11241137
),
11251138
(
@@ -1140,7 +1153,7 @@ fn timestamp_parser<T: TimeZone>(
11401153

11411154
if timestamp.is_none() {
11421155
return if eval_mode == EvalMode::Ansi {
1143-
Err(SparkError::CastInvalidValue {
1156+
Err(SparkError::InvalidInputInCastToDatetime {
11441157
value: value.to_string(),
11451158
from_type: "STRING".to_string(),
11461159
to_type: "TIMESTAMP".to_string(),
@@ -1150,12 +1163,7 @@ fn timestamp_parser<T: TimeZone>(
11501163
};
11511164
}
11521165

1153-
match timestamp {
1154-
Some(ts) => Ok(Some(ts)),
1155-
None => Err(SparkError::Internal(
1156-
"Failed to parse timestamp".to_string(),
1157-
)),
1158-
}
1166+
Ok(timestamp)
11591167
}
11601168

11611169
fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
@@ -1202,17 +1210,20 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
12021210
}
12031211

12041212
fn is_valid_digits(segment: i32, digits: usize) -> bool {
1205-
// An integer is able to represent a date within [+-]5 million years.
1213+
// NaiveDate years are bounded to [-262143, 262142] (6 digits). We allow up to 7 digits to support
1214+
// leading-zero year strings like "0002020" (= year 2020), matching Spark's
1215+
// isValidDigits. Values outside the bounds are caught by an explicit bounds
1216+
// check below.
12061217
let max_digits_year = 7;
1207-
//year (segment 0) can be between 4 to 7 digits,
1208-
//month and day (segment 1 and 2) can be between 1 to 2 digits
1218+
// year (segment 0) can be between 4 and 7 digits,
1219+
// month and day (segment 1 and 2) can be between 1 and 2 digits
12091220
(segment == 0 && digits >= 4 && digits <= max_digits_year)
12101221
|| (segment != 0 && digits > 0 && digits <= 2)
12111222
}
12121223

12131224
fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
12141225
if eval_mode == EvalMode::Ansi {
1215-
Err(SparkError::CastInvalidValue {
1226+
Err(SparkError::InvalidInputInCastToDatetime {
12161227
value: date_str.to_string(),
12171228
from_type: "STRING".to_string(),
12181229
to_type: "DATE".to_string(),
@@ -1285,11 +1296,13 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
12851296

12861297
date_segments[current_segment as usize] = current_segment_value.0;
12871298

1288-
match NaiveDate::from_ymd_opt(
1289-
sign * date_segments[0],
1290-
date_segments[1] as u32,
1291-
date_segments[2] as u32,
1292-
) {
1299+
// Reject out-of-range years explicitly
1300+
let year = sign * date_segments[0];
1301+
if !(-262143..=262142).contains(&year) {
1302+
return Ok(None);
1303+
}
1304+
1305+
match NaiveDate::from_ymd_opt(year, date_segments[1] as u32, date_segments[2] as u32) {
12931306
Some(date) => {
12941307
let duration_since_epoch = date
12951308
.signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date())
@@ -1341,7 +1354,8 @@ mod tests {
13411354
TimestampMicrosecondType,
13421355
timestamp_parser,
13431356
tz
1344-
);
1357+
)
1358+
.unwrap();
13451359

13461360
assert_eq!(
13471361
result.data_type(),
@@ -1350,6 +1364,33 @@ mod tests {
13501364
assert_eq!(result.len(), 4);
13511365
}
13521366

1367+
#[test]
1368+
fn test_cast_string_to_timestamp_ansi_error() {
1369+
// In ANSI mode, an invalid timestamp string must produce an error rather than null.
1370+
let array: ArrayRef = Arc::new(StringArray::from(vec![
1371+
Some("2020-01-01T12:34:56.123456"),
1372+
Some("not_a_timestamp"),
1373+
]));
1374+
let tz = &timezone::Tz::from_str("UTC").unwrap();
1375+
let string_array = array
1376+
.as_any()
1377+
.downcast_ref::<GenericStringArray<i32>>()
1378+
.expect("Expected a string array");
1379+
1380+
let eval_mode = EvalMode::Ansi;
1381+
let result = cast_utf8_to_timestamp!(
1382+
&string_array,
1383+
eval_mode,
1384+
TimestampMicrosecondType,
1385+
timestamp_parser,
1386+
tz
1387+
);
1388+
assert!(
1389+
result.is_err(),
1390+
"ANSI mode should return Err for an invalid timestamp string"
1391+
);
1392+
}
1393+
13531394
#[test]
13541395
fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
13551396
// prepare input data

spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ object SparkErrorConverter extends ShimSparkErrorConverter {
100100
case None => Array.empty[QueryContext] // No context
101101
}
102102

103-
val summary: String = errorJson.summary.orNull
103+
val summary: String = errorJson.summary.getOrElse("")
104104

105105
// Delegate to version-specific shim - let conversion exceptions propagate
106106
val optEx = convertErrorType(errorJson.errorType, errorClass, params, sparkContext, summary)

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
217217
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
218218
case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
219219
Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
220-
case DataTypes.TimestampType if evalMode == CometEvalMode.ANSI =>
221-
Incompatible(Some("ANSI mode not supported"))
222220
case DataTypes.TimestampType =>
223-
// https://github.com/apache/datafusion-comet/issues/328
224221
Incompatible(Some("Not all valid formats are supported"))
225222
case _ =>
226223
unsupported(DataTypes.StringType, toType)

spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import java.io.FileNotFoundException
2323

2424
import scala.util.matching.Regex
2525

26-
import org.apache.spark.{QueryContext, SparkException}
26+
import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
2727
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
2828
import org.apache.spark.sql.errors.QueryExecutionErrors
2929
import org.apache.spark.sql.types._
@@ -172,6 +172,22 @@ trait ShimSparkErrorConverter {
172172
QueryExecutionErrors
173173
.invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
174174

175+
case "InvalidInputInCastToDatetime" =>
176+
val expression =
177+
s"'${params("value").toString.replace("\\", "\\\\").replace("'", "\\'")}'"
178+
val sourceType = s""""${params("fromType").toString}""""
179+
val targetType = s""""${params("toType").toString}""""
180+
Some(
181+
new SparkDateTimeException(
182+
errorClass = "CAST_INVALID_INPUT",
183+
messageParameters = Map(
184+
"expression" -> expression,
185+
"sourceType" -> sourceType,
186+
"targetType" -> targetType,
187+
"ansiConfig" -> "\"spark.sql.ansi.enabled\""),
188+
context = context,
189+
summary = summary))
190+
175191
case "CastOverFlow" =>
176192
val fromType = getDataType(params("fromType").toString)
177193
val toType = getDataType(params("toType").toString)

spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import java.io.FileNotFoundException
2323

2424
import scala.util.matching.Regex
2525

26-
import org.apache.spark.{QueryContext, SparkException}
26+
import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
2727
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
2828
import org.apache.spark.sql.errors.QueryExecutionErrors
2929
import org.apache.spark.sql.types._
@@ -170,6 +170,22 @@ trait ShimSparkErrorConverter {
170170
QueryExecutionErrors
171171
.invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
172172

173+
case "InvalidInputInCastToDatetime" =>
174+
val expression =
175+
s"'${params("value").toString.replace("\\", "\\\\").replace("'", "\\'")}'"
176+
val sourceType = s""""${params("fromType").toString}""""
177+
val targetType = s""""${params("toType").toString}""""
178+
Some(
179+
new SparkDateTimeException(
180+
errorClass = "CAST_INVALID_INPUT",
181+
messageParameters = Map(
182+
"expression" -> expression,
183+
"sourceType" -> sourceType,
184+
"targetType" -> targetType,
185+
"ansiConfig" -> "\"spark.sql.ansi.enabled\""),
186+
context = context,
187+
summary = summary))
188+
173189
case "CastOverFlow" =>
174190
val fromType = getDataType(params("fromType").toString)
175191
val toType = getDataType(params("toType").toString)

spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,13 @@ trait ShimSparkErrorConverter {
182182
QueryExecutionErrors
183183
.invalidInputInCastToNumberError(targetType, str, context.headOption.orNull))
184184

185+
case "InvalidInputInCastToDatetime" =>
186+
val str = UTF8String.fromString(params("value").toString)
187+
val targetType = getDataType(params("toType").toString)
188+
Some(
189+
QueryExecutionErrors
190+
.invalidInputInCastToDatetimeError(str, targetType, context.headOption.orNull))
191+
185192
case "CastOverFlow" =>
186193
val fromType = getDataType(params("fromType").toString)
187194
val toType = getDataType(params("toType").toString)

0 commit comments

Comments
 (0)