Skip to content

Commit 231aa90

Browse files
Merge branch 'apache:main' into main
2 parents d887555 + 638c2c3 commit 231aa90

5 files changed

Lines changed: 47 additions & 35 deletions

File tree

native/core/src/execution/shuffle/row.rs

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -765,9 +765,6 @@ pub fn process_sorted_row_partition(
765765
initial_checksum: Option<u32>,
766766
codec: &CompressionCodec,
767767
) -> Result<(i64, Option<u32>), CometError> {
768-
// TODO: We can tune this parameter automatically based on row size and cache size.
769-
let row_step = 10;
770-
771768
// The current row number we are reading
772769
let mut current_row = 0;
773770
// Total number of bytes written
@@ -790,26 +787,19 @@ pub fn process_sorted_row_partition(
790787
})?;
791788

792789
// Appends rows to the array builders.
793-
let mut row_start: usize = current_row;
794-
while row_start < current_row + n {
795-
let row_end = std::cmp::min(row_start + row_step, current_row + n);
796-
797-
// For each column, iterating over rows and appending values to corresponding array
798-
// builder.
799-
for (idx, builder) in data_builders.iter_mut().enumerate() {
800-
append_columns(
801-
row_addresses_ptr,
802-
row_sizes_ptr,
803-
row_start,
804-
row_end,
805-
schema,
806-
idx,
807-
builder,
808-
prefer_dictionary_ratio,
809-
)?;
810-
}
811-
812-
row_start = row_end;
790+
// For each column, iterating over rows and appending values to corresponding array
791+
// builder.
792+
for (idx, builder) in data_builders.iter_mut().enumerate() {
793+
append_columns(
794+
row_addresses_ptr,
795+
row_sizes_ptr,
796+
current_row,
797+
current_row + n,
798+
schema,
799+
idx,
800+
builder,
801+
prefer_dictionary_ratio,
802+
)?;
813803
}
814804

815805
// Writes a record batch generated from the array builders to the output file.

native/spark-expr/src/math_funcs/modulo_expr.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,16 @@ pub fn create_modulo_expr(
100100
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
101101
));
102102

103+
// The UDF's return type must match what Arrow's rem function will actually return.
104+
// Since we're operating on Decimal256 inputs, rem will return Decimal256.
105+
let decimal256_return_type = match &data_type {
106+
DataType::Decimal128(p, s) => DataType::Decimal256(*p, *s),
107+
other => other.clone(),
108+
};
103109
let modulo_scalar_func = create_modulo_scalar_function(
104110
left_256,
105111
right_256,
106-
&data_type,
112+
&decimal256_return_type,
107113
registry,
108114
fail_on_error,
109115
)?;

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1772,8 +1772,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
17721772
}
17731773
}
17741774

1775+
test("Decimal modulus with Decimal256 intermediate type") {
1776+
// regression test for https://github.com/apache/datafusion-comet/issues/2911
1777+
withTable("test") {
1778+
sql("create table test(a decimal(33, 29), b decimal(28, 17)) using parquet")
1779+
sql(
1780+
"insert into test values (-6788.53035340376888409034576923353, " +
1781+
"70948216565.90127985418365471)")
1782+
withSQLConf(
1783+
"spark.comet.enabled" -> "true",
1784+
"spark.sql.decimalOperations.allowPrecisionLoss" -> "true") {
1785+
val df = sql("select a, b, a % b from test")
1786+
df.collect()
1787+
}
1788+
}
1789+
}
1790+
17751791
test("Decimal random number tests") {
1776-
val rand = scala.util.Random
1792+
val rand = new scala.util.Random(42)
17771793
def makeNum(p: Int, s: Int): String = {
17781794
val int1 = rand.nextLong()
17791795
val int2 = rand.nextLong().abs

spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ abstract class ParquetReadSuite extends CometTestBase {
348348
}
349349

350350
test("mixed nulls and non-nulls") {
351-
val rand = scala.util.Random
351+
val rand = new scala.util.Random(42)
352352
val data = (0 to 100).map { i =>
353353
val row: (Boolean, Integer, java.lang.Long, java.lang.Float, java.lang.Double, String) = {
354354
if (rand.nextBoolean()) {
@@ -403,7 +403,7 @@ abstract class ParquetReadSuite extends CometTestBase {
403403
pageSize = pageSize,
404404
dictionaryPageSize = pageSize)
405405

406-
val rand = scala.util.Random
406+
val rand = new scala.util.Random(42)
407407
val expected = (0 until n).map { i =>
408408
if (rand.nextBoolean()) {
409409
None
@@ -626,7 +626,7 @@ abstract class ParquetReadSuite extends CometTestBase {
626626
dictionaryPageSize = dictionaryPageSize,
627627
pageRowCountLimit = pageRowCount)
628628

629-
val rand = scala.util.Random
629+
val rand = new scala.util.Random(42)
630630
val expected = (0 until n).map { i =>
631631
// use a single value for the first page, to make sure dictionary encoding kicks in
632632
val value = if (i < pageRowCount) i % 8 else i
@@ -814,7 +814,7 @@ abstract class ParquetReadSuite extends CometTestBase {
814814
dictionaryPageSize = pageSize,
815815
rowGroupSize = 1024 * 128)
816816

817-
val rand = scala.util.Random
817+
val rand = new scala.util.Random(42)
818818
val expected = (0 until n).map { i =>
819819
if (rand.nextBoolean()) {
820820
None
@@ -1564,7 +1564,7 @@ abstract class ParquetReadSuite extends CometTestBase {
15641564
pageSize = pageSize,
15651565
dictionaryPageSize = pageSize)
15661566

1567-
val rand = scala.util.Random
1567+
val rand = new scala.util.Random(42)
15681568
val expected = (0 until n).map { i =>
15691569
if (rand.nextBoolean()) {
15701570
None
@@ -1662,7 +1662,7 @@ abstract class ParquetReadSuite extends CometTestBase {
16621662
dictionaryPageSize = dictionaryPageSize,
16631663
pageRowCountLimit = pageRowCount)
16641664

1665-
val rand = scala.util.Random
1665+
val rand = new scala.util.Random(42)
16661666
val expected = (0 until n).map { i =>
16671667
// use a single value for the first page, to make sure dictionary encoding kicks in
16681668
val value = if (i < pageRowCount) i % 8 else i

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ abstract class CometTestBase
694694

695695
val idGenerator = new AtomicInteger(0)
696696

697-
val rand = scala.util.Random
697+
val rand = new scala.util.Random(42)
698698
val data = (begin until end).map { i =>
699699
if (nullEnabled && rand.nextBoolean()) {
700700
None
@@ -788,7 +788,7 @@ abstract class CometTestBase
788788
rowGroupSize = rowGroupSize)
789789
val div = if (dictionaryEnabled) 10 else n // maps value to a small range for dict to kick in
790790

791-
val rand = scala.util.Random
791+
val rand = new scala.util.Random(42)
792792
val expected = (0 until n).map { i =>
793793
if (rand.nextBoolean()) {
794794
None
@@ -842,7 +842,7 @@ abstract class CometTestBase
842842
rowGroupSize = rowGroupSize)
843843
val div = if (dictionaryEnabled) 10 else n // maps value to a small range for dict to kick in
844844

845-
val rand = scala.util.Random
845+
val rand = new scala.util.Random(42)
846846
val expected = (0 until n).map { i =>
847847
if (rand.nextBoolean()) {
848848
None
@@ -1240,7 +1240,7 @@ abstract class CometTestBase
12401240
val schema = MessageTypeParser.parseMessageType(schemaStr)
12411241
val writer = createParquetWriter(schema, path, dictionaryEnabled = true)
12421242

1243-
val rand = scala.util.Random
1243+
val rand = new scala.util.Random(42)
12441244
val expected = (0 until total).map { i =>
12451245
// use a single value for the first page, to make sure dictionary encoding kicks in
12461246
if (rand.nextBoolean()) None

0 commit comments

Comments
 (0)