Skip to content

Commit 0e94d48

Browse files
committed
address review comments
1 parent 30c1929 commit 0e94d48

12 files changed

Lines changed: 203 additions & 17 deletions

File tree

common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ object Utils extends CometTypeShim with Logging {
148148
}
149149
case TimestampNTZType =>
150150
new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
151-
case dt if dt.getClass.getSimpleName.startsWith("TimeType") =>
151+
case dt if isTimeType(dt) =>
152152
new ArrowType.Time(TimeUnit.NANOSECOND, 64)
153153
case _ =>
154154
throw new UnsupportedOperationException(
@@ -401,7 +401,7 @@ object Utils extends CometTypeShim with Logging {
401401
_: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector |
402402
_: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector |
403403
_: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector | _: ListVector |
404-
_: MapVector | _: NullVector) =>
404+
_: MapVector | _: NullVector | _: TimeNanoVector) =>
405405
v.asInstanceOf[FieldVector]
406406
case _ =>
407407
throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}")

common/src/main/spark-3.x/org/apache/comet/shims/CometTypeShim.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,7 @@ trait CometTypeShim {
3232

3333
@nowarn // Spark 4 feature; Variant shredding doesn't exist in Spark 3.x.
3434
def isVariantStruct(s: StructType): Boolean = false
35+
36+
@nowarn // Spark 4.1 feature; TimeType doesn't exist in Spark 3.x.
37+
def isTimeType(dt: DataType): Boolean = false
3538
}

common/src/main/spark-4.x/org/apache/comet/shims/CometTypeShim.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,7 @@ trait CometTypeShim {
5353
// variant shredding layout, so reading such a struct natively returns nulls. Detect the marker
5454
// and force scan fallback.
5555
def isVariantStruct(s: StructType): Boolean = VariantMetadata.isVariantStruct(s)
56+
57+
def isTimeType(dt: DataType): Boolean =
58+
dt.getClass.getSimpleName.startsWith("TimeType")
5659
}

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,7 @@ harness = false
118118
[[bench]]
119119
name = "map_sort"
120120
harness = false
121+
122+
[[bench]]
123+
name = "to_time"
124+
harness = false
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::StringArray;
19+
use criterion::{criterion_group, criterion_main, Criterion};
20+
use datafusion::physical_plan::ColumnarValue;
21+
use datafusion_comet_spark_expr::spark_to_time;
22+
use std::sync::Arc;
23+
24+
fn criterion_benchmark(c: &mut Criterion) {
25+
let mut group = c.benchmark_group("to_time");
26+
27+
let hh_mm = create_string_array(10000, |i| format!("{}:{:02}", i % 24, i % 60));
28+
group.bench_function("hh_mm", |b| {
29+
b.iter(|| spark_to_time(std::slice::from_ref(&hh_mm), true).unwrap());
30+
});
31+
32+
let hh_mm_ss = create_string_array(10000, |i| {
33+
format!("{}:{:02}:{:02}", i % 24, i % 60, i % 60)
34+
});
35+
group.bench_function("hh_mm_ss", |b| {
36+
b.iter(|| spark_to_time(std::slice::from_ref(&hh_mm_ss), true).unwrap());
37+
});
38+
39+
let fractional = create_string_array(10000, |i| {
40+
format!("{}:{:02}:{:02}.{:06}", i % 24, i % 60, i % 60, i * 7 % 1000000)
41+
});
42+
group.bench_function("fractional", |b| {
43+
b.iter(|| spark_to_time(std::slice::from_ref(&fractional), true).unwrap());
44+
});
45+
46+
let am_pm = create_string_array(10000, |i| {
47+
let hour = (i % 12) + 1;
48+
let suffix = if i % 2 == 0 { "AM" } else { "PM" };
49+
format!("{}:{:02}:{:02} {}", hour, i % 60, i % 60, suffix)
50+
});
51+
group.bench_function("am_pm", |b| {
52+
b.iter(|| spark_to_time(std::slice::from_ref(&am_pm), true).unwrap());
53+
});
54+
55+
group.finish();
56+
}
57+
58+
fn create_string_array(size: usize, f: impl Fn(usize) -> String) -> ColumnarValue {
59+
let values: Vec<String> = (0..size).map(&f).collect();
60+
let array = StringArray::from(values);
61+
ColumnarValue::Array(Arc::new(array))
62+
}
63+
64+
fn config() -> Criterion {
65+
Criterion::default()
66+
}
67+
68+
criterion_group! {
69+
name = benches;
70+
config = config();
71+
targets = criterion_benchmark
72+
}
73+
criterion_main!(benches);

native/spark-expr/src/datetime_funcs/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ pub use make_date::SparkMakeDate;
3838
pub use make_time::SparkMakeTime;
3939
pub use seconds_to_timestamp::SparkSecondsToTimestamp;
4040
pub use timestamp_trunc::TimestampTruncExpr;
41-
pub use to_time::{spark_to_time, to_time_return_type};
41+
pub use to_time::spark_to_time;
4242
pub use unix_timestamp::SparkUnixTimestamp;

native/spark-expr/src/datetime_funcs/to_time.rs

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
// under the License.
1717

1818
use arrow::array::{Array, StringArray, Time64NanosecondArray};
19-
use arrow::datatypes::{DataType, TimeUnit};
2019
use datafusion::common::{DataFusionError, Result};
2120
use datafusion::physical_plan::ColumnarValue;
2221
use std::sync::Arc;
@@ -81,11 +80,18 @@ pub fn spark_to_time(args: &[ColumnarValue], fail_on_error: bool) -> Result<Colu
8180
/// Parse a time string to nanoseconds from midnight, matching Spark's stringToTime behavior.
8281
/// Returns None for invalid input.
8382
fn string_to_time(s: &str) -> Option<i64> {
84-
let trimmed = s.trim_end();
83+
let trimmed = s.trim();
8584
if trimmed.is_empty() {
8685
return None;
8786
}
8887

88+
// Spark's parseTimestampString gates the T-prefix branch on j == 0 (start of
89+
// the trimmed string), so " T12:30" is rejected even though leading whitespace
90+
// is trimmed: the original segment start differs from the trimmed position.
91+
if trimmed.as_bytes()[0] == b'T' && s.as_bytes()[0].is_ascii_whitespace() {
92+
return None;
93+
}
94+
8995
let bytes = trimmed.as_bytes();
9096
let num_chars = bytes.len();
9197

@@ -286,11 +292,6 @@ fn parse_fractional(bytes: &[u8], start: usize) -> Option<(i32, usize)> {
286292
Some((value, pos))
287293
}
288294

289-
/// Return type for to_time
290-
pub fn to_time_return_type() -> DataType {
291-
DataType::Time64(TimeUnit::Nanosecond)
292-
}
293-
294295
#[cfg(test)]
295296
mod tests {
296297
use super::*;
@@ -334,6 +335,11 @@ mod tests {
334335
);
335336
// 6 digits
336337
assert_eq!(string_to_time("00:00:00.000001"), Some(NANOS_PER_MICRO));
338+
// >6 digits truncated to microseconds
339+
assert_eq!(
340+
string_to_time("00:00:00.1234567"),
341+
Some(123_456 * NANOS_PER_MICRO)
342+
);
337343
// Full precision
338344
assert_eq!(
339345
string_to_time("23:59:59.999999"),
@@ -459,8 +465,18 @@ mod tests {
459465
}
460466

461467
#[test]
462-
fn test_leading_space_with_t_prefix() {
463-
// Leading space before T should be rejected (Spark only right-trims)
468+
fn test_leading_whitespace() {
469+
assert_eq!(string_to_time(" 12:30"), string_to_time("12:30"));
470+
assert_eq!(string_to_time(" 12:30:45"), string_to_time("12:30:45"));
471+
assert_eq!(string_to_time(" 12:30:45 "), string_to_time("12:30:45"));
472+
assert_eq!(string_to_time(" 1:00:00 AM"), string_to_time("1:00:00 AM"));
473+
// Tabs and newlines are also trimmed (Spark's isWhitespaceOrISOControl)
474+
assert_eq!(string_to_time("\t12:30:45"), string_to_time("12:30:45"));
475+
assert_eq!(string_to_time("\n12:30:45"), string_to_time("12:30:45"));
476+
// T-prefix is rejected when preceded by whitespace because Spark's
477+
// parseTimestampString gates the T-prefix branch on j == 0 (start of
478+
// the already-trimmed segment), so leading whitespace moves j past 0.
479+
assert_eq!(string_to_time(" T12:30:45"), None);
464480
assert_eq!(string_to_time(" T12:30"), None);
465481
}
466482
}

native/spark-expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ pub use comet_scalar_funcs::{
7575
};
7676
pub use csv_funcs::*;
7777
pub use datetime_funcs::{
78-
spark_to_time, to_time_return_type, SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc,
78+
spark_to_time, SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc,
7979
SparkHour, SparkHoursTransform, SparkMakeDate, SparkMakeTime, SparkMinute, SparkSecond,
8080
SparkSecondsToTimestamp, SparkUnixTimestamp, TimestampTruncExpr,
8181
};

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
363363
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType |
364364
_: DecimalType | _: DateType | _: BooleanType | _: NullType =>
365365
true
366-
case dt if dt.getClass.getSimpleName.startsWith("TimeType") =>
366+
case dt if isTimeType(dt) =>
367367
true
368368
case s: StructType if allowComplex =>
369369
s.fields.nonEmpty && s.fields.map(_.dataType).forall(supportedDataType(_, allowComplex))
@@ -399,7 +399,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
399399
case _: ArrayType => 14
400400
case _: MapType => 15
401401
case _: StructType => 16
402-
case dt if dt.getClass.getSimpleName.startsWith("TimeType") => 17
402+
case dt if isTimeType(dt) => 17
403403
case dt =>
404404
logWarning(s"Cannot serialize Spark data type: $dt")
405405
return None

spark/src/test/resources/sql-tests/expressions/datetime/make_time.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ INSERT INTO test_make_time VALUES
3333
(12, 30, NULL),
3434
(NULL, NULL, NULL)
3535

36-
-- column arguments (spark_answer_only: shuffle does not support TimeType yet)
36+
-- column arguments (spark_answer_only: shuffle does not support TimeType yet; TODO: promote to
37+
-- full native-verification once SPARK-51779 lands)
3738
query spark_answer_only
3839
SELECT hours, minutes, secs, make_time(hours, minutes, secs) FROM test_make_time ORDER BY hours, minutes, secs
3940

0 commit comments

Comments
 (0)