Skip to content

Commit a9b9ccd

Browse files
committed
fix_test_failures
1 parent 61987e1 commit a9b9ccd

3 files changed

Lines changed: 40 additions & 37 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2317,8 +2317,8 @@ fn cast_string_to_decimal256_impl(
23172317
}
23182318

23192319
/// Parse a string to decimal following Spark's behavior
2320-
fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
2321-
let string_bytes = s.as_bytes();
2320+
fn parse_string_to_decimal(input_str: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
2321+
let string_bytes = input_str.as_bytes();
23222322
let mut start = 0;
23232323
let mut end = string_bytes.len();
23242324

@@ -2330,7 +2330,7 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
23302330
end -= 1;
23312331
}
23322332

2333-
let trimmed = &s[start..end];
2333+
let trimmed = &input_str[start..end];
23342334

23352335
if trimmed.is_empty() {
23362336
return Ok(None);
@@ -2347,15 +2347,21 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
23472347
return Ok(None);
23482348
}
23492349

2350-
// validate and parse mantissa and exponent
2350+
// Validate and parse the mantissa and exponent, propagating any parse error to the caller
23512351
let (mantissa, exponent) = parse_decimal_str(
23522352
trimmed,
2353+
input_str,
23532354
"STRING",
23542355
&format!("DECIMAL({},{})", precision, scale),
23552356
)?;
23562357

2357-
// return early when mantissa is 0
2358+
// Return early when the mantissa is 0; Spark still checks that the value fits within the supported number of digits and throws an error in ANSI mode
23582359
if mantissa == 0 {
2360+
if exponent < -37 {
2361+
return Err(SparkError::NumericOutOfRange {
2362+
value: input_str.to_string(),
2363+
});
2364+
}
23592365
return Ok(Some(0));
23602366
}
23612367

@@ -2424,10 +2430,15 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
24242430
}
24252431

24262432
/// Parse a decimal string into mantissa and scale
2427-
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
2428-
fn parse_decimal_str(s: &str, from_type: &str, to_type: &str) -> SparkResult<(i128, i32)> {
2433+
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3), "0e50" -> (0, -50)
2434+
fn parse_decimal_str(
2435+
s: &str,
2436+
original_str: &str,
2437+
from_type: &str,
2438+
to_type: &str,
2439+
) -> SparkResult<(i128, i32)> {
24292440
if s.is_empty() {
2430-
return Err(invalid_value(s, from_type, to_type));
2441+
return Err(invalid_value(original_str, from_type, to_type));
24312442
}
24322443

24332444
let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
@@ -2436,7 +2447,7 @@ fn parse_decimal_str(s: &str, from_type: &str, to_type: &str) -> SparkResult<(i1
24362447
// Parse exponent
24372448
let exp: i32 = exponent_part
24382449
.parse()
2439-
.map_err(|_| invalid_value(s, from_type, to_type))?;
2450+
.map_err(|_| invalid_value(original_str, from_type, to_type))?;
24402451

24412452
(mantissa_part, exp)
24422453
} else {
@@ -2451,29 +2462,29 @@ fn parse_decimal_str(s: &str, from_type: &str, to_type: &str) -> SparkResult<(i1
24512462
};
24522463

24532464
if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
2454-
return Err(invalid_value(s, from_type, to_type));
2465+
return Err(invalid_value(original_str, from_type, to_type));
24552466
}
24562467

24572468
let (integral_part, fractional_part) = match mantissa_str.find('.') {
24582469
Some(dot_pos) => {
24592470
if mantissa_str[dot_pos + 1..].contains('.') {
2460-
return Err(invalid_value(s, from_type, to_type));
2471+
return Err(invalid_value(original_str, from_type, to_type));
24612472
}
24622473
(&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
24632474
}
24642475
None => (mantissa_str, ""),
24652476
};
24662477

24672478
if integral_part.is_empty() && fractional_part.is_empty() {
2468-
return Err(invalid_value(s, from_type, to_type));
2479+
return Err(invalid_value(original_str, from_type, to_type));
24692480
}
24702481

24712482
if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
2472-
return Err(invalid_value(s, from_type, to_type));
2483+
return Err(invalid_value(original_str, from_type, to_type));
24732484
}
24742485

24752486
if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
2476-
return Err(invalid_value(s, from_type, to_type));
2487+
return Err(invalid_value(original_str, from_type, to_type));
24772488
}
24782489

24792490
// Parse integral part
@@ -2483,7 +2494,7 @@ fn parse_decimal_str(s: &str, from_type: &str, to_type: &str) -> SparkResult<(i1
24832494
} else {
24842495
integral_part
24852496
.parse()
2486-
.map_err(|_| invalid_value(s, from_type, to_type))?
2497+
.map_err(|_| invalid_value(original_str, from_type, to_type))?
24872498
};
24882499

24892500
// Parse fractional part
@@ -2493,14 +2504,14 @@ fn parse_decimal_str(s: &str, from_type: &str, to_type: &str) -> SparkResult<(i1
24932504
} else {
24942505
fractional_part
24952506
.parse()
2496-
.map_err(|_| invalid_value(s, from_type, to_type))?
2507+
.map_err(|_| invalid_value(original_str, from_type, to_type))?
24972508
};
24982509

24992510
// Combine: value = integral * 10^fractional_scale + fractional
25002511
let mantissa = integral_value
25012512
.checked_mul(10_i128.pow(fractional_scale as u32))
25022513
.and_then(|v| v.checked_add(fractional_value))
2503-
.ok_or_else(|| invalid_value(s, from_type, to_type))?;
2514+
.ok_or_else(|| invalid_value(original_str, from_type, to_type))?;
25042515

25052516
let final_mantissa = if negative { -mantissa } else { mantissa };
25062517
// final scale = fractional_scale - exponent

native/spark-expr/src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ pub enum SparkError {
3939
scale: i8,
4040
},
4141

42+
#[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")]
43+
NumericOutOfRange { value: String },
44+
4245
#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
4346
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
4447
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions.col
3333
import org.apache.spark.sql.internal.SQLConf
3434
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
3535

36-
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
3736
import org.apache.comet.expressions.{CometCast, CometEvalMode}
3837
import org.apache.comet.rules.CometScanTypeChecker
3938
import org.apache.comet.serde.Compatible
@@ -709,8 +708,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
709708

710709
test("cast StringType to DecimalType(10,2) (does not support fullwidth unicode digits)") {
711710
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
712-
// TODO fix for Spark 4.0.0
713-
assume(!isSpark40Plus)
714711
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
715712
Seq(true, false).foreach(ansiEnabled =>
716713
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
@@ -719,52 +716,46 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
719716

720717
test("cast StringType to DecimalType(2,2)") {
721718
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
722-
println("testing with simple input")
723719
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
724720
Seq(true, false).foreach(ansiEnabled =>
725721
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
726722
}
727723
}
728724

729-
test("cast StringType to DecimalType(2,2) check if right exception is being thrown") {
725+
test("cast StringType to DecimalType check if right exception message is thrown") {
730726
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
731-
println("testing with simple input")
732-
val values = Seq(" 3").toDF("a")
727+
val values = Seq("d11307\n").toDF("a")
733728
Seq(true, false).foreach(ansiEnabled =>
734729
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
735730
}
736731
}
737732

738-
test("cast StringType to DecimalType(38,10) high precision - check 0 mantissa") {
733+
test("cast StringType to DecimalType(2,2) check if right exception is being thrown") {
739734
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
740-
val values = Seq("0e31").toDF("a")
735+
val values = gen.generateInts(10000).map(" " + _).toDF("a")
741736
Seq(true, false).foreach(ansiEnabled =>
742-
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
737+
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
743738
}
744739
}
745740

746-
test("cast StringType to DecimalType(38,10) high precision") {
741+
test("cast StringType to DecimalType(38,10) high precision - check 0 mantissa") {
747742
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
748-
val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
743+
val values = Seq("0e31", "000e3375", "0e40", "0E+695", "0e5887677").toDF("a")
749744
Seq(true, false).foreach(ansiEnabled =>
750745
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
751746
}
752747
}
753748

754-
test("cast StringType to DecimalType(38,10) high precision - 0 mantissa") {
749+
test("cast StringType to DecimalType(38,10) high precision") {
755750
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
756-
// TODO fix for Spark 4.0.0
757-
assume(!isSpark40Plus)
758-
val values = Seq("0e31").toDF("a")
751+
val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
759752
Seq(true, false).foreach(ansiEnabled =>
760753
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
761754
}
762755
}
763756

764757
test("cast StringType to DecimalType(10,2) basic values") {
765758
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
766-
// TODO fix for Spark 4.0.0
767-
assume(!isSpark40Plus)
768759
val values = Seq(
769760
"123.45",
770761
"-67.89",
@@ -790,8 +781,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
790781

791782
test("cast StringType to Decimal type scientific notation") {
792783
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
793-
// TODO fix for Spark 4.0.0
794-
assume(!isSpark40Plus)
795784
val values = Seq(
796785
"1.23E-5",
797786
"1.23e10",

0 commit comments

Comments
 (0)