Skip to content

Commit 1913208

Browse files
committed
fix aggregation wrapping now that we don't have an extra CometExecIterator.
1 parent cdbb0e6 commit 1913208

2 files changed

Lines changed: 41 additions & 17 deletions

File tree

native/shuffle/src/shuffle_writer.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
2929
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
3030
use datafusion::physical_plan::EmptyRecordBatchStream;
3131
use datafusion::{
32-
arrow::{datatypes::SchemaRef, error::ArrowError},
32+
arrow::datatypes::SchemaRef,
3333
error::Result,
3434
execution::context::TaskContext,
3535
physical_plan::{
@@ -38,7 +38,7 @@ use datafusion::{
3838
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
3939
},
4040
};
41-
use futures::{StreamExt, TryFutureExt, TryStreamExt};
41+
use futures::{StreamExt, TryStreamExt};
4242
use std::{
4343
any::Any,
4444
fmt,
@@ -171,23 +171,23 @@ impl ExecutionPlan for ShuffleWriterExec {
171171
let input = self.input.execute(partition, Arc::clone(&context))?;
172172
let metrics = ShufflePartitionerMetrics::new(&self.metrics, 0);
173173

174+
// Propagate DataFusionError unchanged: the JNI bridge only downcasts a single
175+
// `DataFusionError::External(SparkError)` layer, so any extra wrap here loses the
176+
// typed exception (e.g. SparkArithmeticException on decimal overflow).
174177
Ok(Box::pin(RecordBatchStreamAdapter::new(
175178
self.schema(),
176-
futures::stream::once(
177-
external_shuffle(
178-
input,
179-
partition,
180-
self.output_data_file.clone(),
181-
self.output_index_file.clone(),
182-
self.partitioning.clone(),
183-
metrics,
184-
context,
185-
self.codec.clone(),
186-
self.tracing_enabled,
187-
self.write_buffer_size,
188-
)
189-
.map_err(|e| ArrowError::ExternalError(Box::new(e))),
190-
)
179+
futures::stream::once(external_shuffle(
180+
input,
181+
partition,
182+
self.output_data_file.clone(),
183+
self.output_index_file.clone(),
184+
self.partitioning.clone(),
185+
metrics,
186+
context,
187+
self.codec.clone(),
188+
self.tracing_enabled,
189+
self.write_buffer_size,
190+
))
191191
.try_flatten(),
192192
)))
193193
}

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
18171817
// make sure that the error message throws overflow exception only
18181818
assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
18191819
assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
1820+
assert(
1821+
cometExc.isInstanceOf[ArithmeticException],
1822+
"expected ArithmeticException, got " +
1823+
s"${cometExc.getClass.getName}: ${cometExc.getMessage}")
18201824
case _ => fail("Exception should be thrown for Long overflow in ANSI mode")
18211825
}
18221826
} else {
@@ -1831,6 +1835,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
18311835
case (Some(sparkExc), Some(cometExc)) =>
18321836
assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
18331837
assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
1838+
assert(
1839+
cometExc.isInstanceOf[ArithmeticException],
1840+
"expected ArithmeticException, got " +
1841+
s"${cometExc.getClass.getName}: ${cometExc.getMessage}")
18341842
case _ => fail("Exception should be thrown for Long underflow in ANSI mode")
18351843
}
18361844
} else {
@@ -1870,6 +1878,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
18701878
case (Some(sparkExc), Some(cometExc)) =>
18711879
assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
18721880
assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
1881+
assert(
1882+
cometExc.isInstanceOf[ArithmeticException],
1883+
"expected ArithmeticException, got " +
1884+
s"${cometExc.getClass.getName}: ${cometExc.getMessage}")
18731885
case _ =>
18741886
fail("Exception should be thrown for decimal overflow in ANSI mode")
18751887
}
@@ -1893,6 +1905,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
18931905
case (Some(sparkExc), Some(cometExc)) =>
18941906
assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
18951907
assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
1908+
assert(
1909+
cometExc.isInstanceOf[ArithmeticException],
1910+
"expected ArithmeticException, got " +
1911+
s"${cometExc.getClass.getName}: ${cometExc.getMessage}")
18961912
case _ =>
18971913
fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode")
18981914
}
@@ -1910,6 +1926,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
19101926
case (Some(sparkExc), Some(cometExc)) =>
19111927
assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
19121928
assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
1929+
assert(
1930+
cometExc.isInstanceOf[ArithmeticException],
1931+
"expected ArithmeticException, got " +
1932+
s"${cometExc.getClass.getName}: ${cometExc.getMessage}")
19131933
case _ =>
19141934
fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode")
19151935
}
@@ -1951,6 +1971,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
19511971
case (Some(sparkExc), Some(cometExc)) =>
19521972
assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
19531973
assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
1974+
assert(
1975+
cometExc.isInstanceOf[ArithmeticException],
1976+
"expected ArithmeticException, got " +
1977+
s"${cometExc.getClass.getName}: ${cometExc.getMessage}")
19541978
case _ =>
19551979
fail("Exception should be thrown for decimal overflow with GROUP BY in ANSI mode")
19561980
}

0 commit comments

Comments
 (0)