Skip to content

Commit fe887c6

Browse files
committed
[GLUTEN-11915][VL] Support RowBasedChecksum in ColumnarShuffleWriter (SPARK-51756)
Implement order-independent row-based checksum for non-deterministic stage retry detection. - C++ computeRowBasedChecksums(): UnsafeRowFast + XXH64, per-partition XOR+SUM - JNI: pass config, return checksum array - Scala: read SQLConf (OR logic), pass to native, use for MapStatus - Shim: GlutenMapStatusUtil for Spark 3.3-4.1 compatibility - Tests: C++ unit (4/4) + Scala integration (3/3)
1 parent 0b1e511 commit fe887c6

20 files changed

Lines changed: 470 additions & 12 deletions

File tree

backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
151151
nativeBufferSize,
152152
GlutenConfig.get.columnarShuffleReallocThreshold,
153153
GlutenConfig.get.columnarShufflePartitionBufferEvictThreshold,
154-
partitionWriterHandle
154+
partitionWriterHandle,
155+
false
155156
)
156157
case SortShuffleWriterType =>
157158
shuffleWriterJniWrapper.createSortShuffleWriter(

backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ protected void writeImpl(Iterator<Product2<K, V>> records) {
186186
nativeBufferSize,
187187
reallocThreshold,
188188
GlutenConfig.get().columnarShufflePartitionBufferEvictThreshold(),
189-
partitionWriterHandle);
189+
partitionWriterHandle,
190+
false);
190191
}
191192

192193
runtime

backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ class ColumnarShuffleWriter[K, V](
6060

6161
private val blockManager = SparkEnv.get.blockManager
6262

63+
private val rowBasedChecksumEnabled: Boolean = GlutenMapStatusUtil.isRowBasedChecksumEnabled
64+
6365
// Are we in the process of stopping? Because map tasks can call stop() with success = true
6466
// and then call stop() with success = false if they get an exception, we want to make sure
6567
// we don't try deleting files, etc twice.
@@ -193,7 +195,8 @@ class ColumnarShuffleWriter[K, V](
193195
nativeBufferSize,
194196
reallocThreshold,
195197
GlutenConfig.get.columnarShufflePartitionBufferEvictThreshold,
196-
partitionWriterHandle
198+
partitionWriterHandle,
199+
rowBasedChecksumEnabled
197200
)
198201
}
199202

@@ -282,7 +285,15 @@ class ColumnarShuffleWriter[K, V](
282285
// almost 3 times than vanilla spark partitionLengths
283286
// This value is sensitive in rules such as AQE rule OptimizeSkewedJoin DynamicJoinSelection
284287
// May affect the final plan
285-
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
288+
val rowChecksums = splitResult.getRowBasedChecksums
289+
val aggregatedChecksum = if (rowChecksums != null && rowChecksums.nonEmpty) {
290+
rowChecksums.foldLeft(0L)((acc, c) => acc * 31L + c)
291+
} else 0L
292+
mapStatus = GlutenMapStatusUtil.createMapStatus(
293+
blockManager.shuffleServerId,
294+
partitionLengths,
295+
mapId,
296+
aggregatedChecksum)
286297
}
287298

288299
private def handleEmptyInput(): Unit = {
@@ -293,7 +304,11 @@ class ColumnarShuffleWriter[K, V](
293304
partitionLengths,
294305
Array[Long](),
295306
null)
296-
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
307+
mapStatus = GlutenMapStatusUtil.createMapStatus(
308+
blockManager.shuffleServerId,
309+
partitionLengths,
310+
mapId,
311+
0L)
297312
}
298313

299314
@throws[IOException]

cpp/core/jni/JniWrapper.cc

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
269269
jniByteInputStreamClose = getMethodIdOrError(env, jniByteInputStreamClass, "close", "()V");
270270

271271
splitResultClass = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/vectorized/GlutenSplitResult;");
272-
splitResultConstructor = getMethodIdOrError(env, splitResultClass, "<init>", "(JJJJJJJJJJDJ[J[J)V");
272+
splitResultConstructor = getMethodIdOrError(env, splitResultClass, "<init>", "(JJJJJJJJJJDJ[J[J[J)V");
273273

274274
metricsBuilderClass = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/metrics/Metrics;");
275275

@@ -991,7 +991,8 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
991991
jint splitBufferSize,
992992
jdouble splitBufferReallocThreshold,
993993
jint partitionBufferEvictThreshold,
994-
jlong partitionWriterHandle) {
994+
jlong partitionWriterHandle,
995+
jboolean rowBasedChecksumEnabled) {
995996
JNI_METHOD_START
996997
const auto ctx = getRuntime(env, wrapper);
997998

@@ -1007,6 +1008,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
10071008
splitBufferSize,
10081009
splitBufferReallocThreshold,
10091010
partitionBufferEvictThreshold);
1011+
shuffleWriterOptions->rowBasedChecksumEnabled = rowBasedChecksumEnabled;
10101012

10111013
return ctx->saveObject(ctx->createShuffleWriter(numPartitions, partitionWriter, shuffleWriterOptions));
10121014
JNI_METHOD_END(kInvalidObjectHandle)
@@ -1161,6 +1163,13 @@ JNIEXPORT jobject JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrap
11611163
auto rawSrc = reinterpret_cast<const jlong*>(rawPartitionLengths.data());
11621164
env->SetLongArrayRegion(rawPartitionLengthArr, 0, rawPartitionLengths.size(), rawSrc);
11631165

1166+
const auto& rowBasedChecksums = shuffleWriter->rowBasedChecksums();
1167+
auto rowBasedChecksumArr = env->NewLongArray(rowBasedChecksums.size());
1168+
if (!rowBasedChecksums.empty()) {
1169+
auto checksumSrc = reinterpret_cast<const jlong*>(rowBasedChecksums.data());
1170+
env->SetLongArrayRegion(rowBasedChecksumArr, 0, rowBasedChecksums.size(), checksumSrc);
1171+
}
1172+
11641173
jobject splitResult = env->NewObject(
11651174
splitResultClass,
11661175
splitResultConstructor,
@@ -1177,7 +1186,8 @@ JNIEXPORT jobject JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrap
11771186
shuffleWriter->avgDictionaryFields(),
11781187
shuffleWriter->dictionarySize(),
11791188
partitionLengthArr,
1180-
rawPartitionLengthArr);
1189+
rawPartitionLengthArr,
1190+
rowBasedChecksumArr);
11811191

11821192
return splitResult;
11831193
JNI_METHOD_END(nullptr)

cpp/core/shuffle/Options.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct ShuffleWriterOptions {
7474
ShuffleWriterType shuffleWriterType;
7575
Partitioning partitioning = Partitioning::kRoundRobin;
7676
int32_t startPartitionId = 0;
77+
bool rowBasedChecksumEnabled = false;
7778

7879
ShuffleWriterOptions(ShuffleWriterType shuffleWriterType) : shuffleWriterType(shuffleWriterType) {}
7980

@@ -234,5 +235,6 @@ struct ShuffleWriterMetrics {
234235
int64_t dictionarySize{0};
235236
std::vector<int64_t> partitionLengths{};
236237
std::vector<int64_t> rawPartitionLengths{}; // Uncompressed size.
238+
std::vector<int64_t> rowBasedChecksums{}; // Per-partition row-based checksums.
237239
};
238240
} // namespace gluten

cpp/core/shuffle/ShuffleWriter.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ const std::vector<int64_t>& ShuffleWriter::rawPartitionLengths() const {
109109
return metrics_.rawPartitionLengths;
110110
}
111111

112+
const std::vector<int64_t>& ShuffleWriter::rowBasedChecksums() const {
113+
return metrics_.rowBasedChecksums;
114+
}
115+
112116
ShuffleWriter::ShuffleWriter(int32_t numPartitions, Partitioning partitioning)
113117
: numPartitions_(numPartitions), partitioning_(partitioning) {}
114118
} // namespace gluten

cpp/core/shuffle/ShuffleWriter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class ShuffleWriter : public Reclaimable {
6767

6868
const std::vector<int64_t>& rawPartitionLengths() const;
6969

70+
const std::vector<int64_t>& rowBasedChecksums() const;
71+
7072
protected:
7173
ShuffleWriter(int32_t numPartitions, Partitioning partitioning);
7274

cpp/velox/shuffle/VeloxHashShuffleWriter.cc

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include "utils/VeloxArrowUtils.h"
2626
#include "velox/buffer/Buffer.h"
2727
#include "velox/common/base/Nulls.h"
28+
#include "velox/external/xxhash/xxhash.h"
29+
#include "velox/row/UnsafeRowFast.h"
2830
#include "velox/type/HugeInt.h"
2931
#include "velox/type/Timestamp.h"
3032
#include "velox/type/Type.h"
@@ -182,6 +184,11 @@ arrow::Status VeloxHashShuffleWriter::init() {
182184

183185
partitionBufferBase_.resize(numPartitions_);
184186

187+
if (rowBasedChecksumEnabled_) {
188+
checksumXor_.resize(numPartitions_, 0);
189+
checksumSum_.resize(numPartitions_, 0);
190+
}
191+
185192
return arrow::Status::OK();
186193
}
187194

@@ -362,6 +369,17 @@ arrow::Status VeloxHashShuffleWriter::stop() {
362369

363370
stat();
364371

372+
// Populate row-based checksums into metrics.
373+
if (rowBasedChecksumEnabled_) {
374+
metrics_.rowBasedChecksums.resize(numPartitions_);
375+
for (auto pid = 0; pid < numPartitions_; ++pid) {
376+
int64_t xorVal = checksumXor_[pid];
377+
int64_t sumVal = checksumSum_[pid];
378+
int64_t rotated = (static_cast<uint64_t>(sumVal) << 27) | (static_cast<uint64_t>(sumVal) >> 37);
379+
metrics_.rowBasedChecksums[pid] = xorVal ^ rotated;
380+
}
381+
}
382+
365383
return arrow::Status::OK();
366384
}
367385

@@ -423,6 +441,7 @@ void VeloxHashShuffleWriter::setSplitState(SplitState state) {
423441
arrow::Status VeloxHashShuffleWriter::doSplit(const facebook::velox::RowVector& rv, int64_t memLimit) {
424442
auto rowNum = rv.size();
425443
RETURN_NOT_OK(buildPartition2Row(rowNum));
444+
computeRowBasedChecksums(rv);
426445
RETURN_NOT_OK(updateInputHasNull(rv));
427446

428447
{
@@ -1616,4 +1635,50 @@ bool VeloxHashShuffleWriter::isExtremelyLargeBatch(facebook::velox::RowVectorPtr
16161635
return (rv->size() > maxBatchSize_ && maxBatchSize_ > 0);
16171636
}
16181637

1638+
// Folds a hash of every row in `rv` into the per-partition checksum accumulators.
//
// Each row is serialized through UnsafeRowFast (without the leading partition-id column when
// the partitioner reports one), hashed with XXH64 using seed 0, and combined into
// checksumXor_[pid] (XOR) and checksumSum_[pid] (wrapping addition), where
// pid = row2Partition_[row]. Both operations are commutative, so the accumulated values do not
// depend on the order in which rows are processed.
//
// Does nothing when rowBasedChecksumEnabled_ is false or when the batch has no rows.
// doSplit() invokes this right after buildPartition2Row(), which is what row2Partition_ is
// read against here.
void VeloxHashShuffleWriter::computeRowBasedChecksums(const facebook::velox::RowVector& rv) {
  if (!rowBasedChecksumEnabled_) {
    return;
  }

  const auto numRows = rv.size();
  if (numRows == 0) {
    // Nothing to hash; skip building the serializer for an empty batch.
    return;
  }
  VELOX_DCHECK(rv.nulls() == nullptr, "RowVector with top-level nulls not supported for checksum");

  // Resolve the row type once; it is reused for both the data vector and the size estimate.
  std::shared_ptr<const facebook::velox::RowType> dataType =
      std::dynamic_pointer_cast<const facebook::velox::RowType>(rv.type());
  VELOX_DCHECK(dataType != nullptr, "Checksum input must have a ROW type");

  // Build the vector that is actually serialized (strip the pid column if present).
  facebook::velox::RowVectorPtr dataVector;
  if (partitioner_->hasPid()) {
    // Column 0 carries the partition id and must not contribute to the data checksum.
    std::vector<std::string> names(dataType->names().begin() + 1, dataType->names().end());
    std::vector<facebook::velox::TypePtr> types(dataType->children().begin() + 1, dataType->children().end());
    std::vector<facebook::velox::VectorPtr> children(rv.children().begin() + 1, rv.children().end());
    dataType = facebook::velox::ROW(std::move(names), std::move(types));
    dataVector =
        std::make_shared<facebook::velox::RowVector>(rv.pool(), dataType, nullptr, numRows, std::move(children));
  } else {
    dataVector = std::make_shared<facebook::velox::RowVector>(rv.pool(), dataType, nullptr, numRows, rv.children());
  }

  facebook::velox::row::UnsafeRowFast fast(dataVector);

  // Pre-size the reusable scratch buffer: exact for fixed-width rows, otherwise a 1 KiB
  // starting guess that the loop below grows on demand.
  const auto fixedSize = facebook::velox::row::UnsafeRowFast::fixedRowSize(dataType);
  const int32_t bufSize = fixedSize.value_or(1024);
  if (checksumBuffer_.size() < static_cast<size_t>(bufSize)) {
    checksumBuffer_.resize(bufSize);
  }

  for (facebook::velox::vector_size_t row = 0; row < numRows; ++row) {
    const auto pid = row2Partition_[row];
    const auto size = fast.rowSize(row);
    if (static_cast<size_t>(size) > checksumBuffer_.size()) {
      checksumBuffer_.resize(size);
    }
    // Zero the scratch area first so any bytes serialize() leaves unwritten hash as zeros
    // rather than as leftovers from the previous row.
    std::memset(checksumBuffer_.data(), 0, size);
    fast.serialize(row, checksumBuffer_.data());

    const uint64_t hash = XXH64(checksumBuffer_.data(), size, 0);
    checksumXor_[pid] ^= static_cast<int64_t>(hash);
    // Accumulate in unsigned arithmetic: hashes span the full 64-bit range, so the running sum
    // is expected to wrap, and signed overflow would be undefined behaviour.
    checksumSum_[pid] = static_cast<int64_t>(static_cast<uint64_t>(checksumSum_[pid]) + hash);
  }
}
1683+
16191684
} // namespace gluten

cpp/velox/shuffle/VeloxHashShuffleWriter.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,8 @@ class VeloxHashShuffleWriter : public VeloxShuffleWriter {
279279
: VeloxShuffleWriter(numPartitions, partitionWriter, options, memoryManager),
280280
splitBufferSize_(options->splitBufferSize),
281281
splitBufferReallocThreshold_(options->splitBufferReallocThreshold),
282-
partitionBufferEvictThreshold_(options->partitionBufferEvictThreshold) {
282+
partitionBufferEvictThreshold_(options->partitionBufferEvictThreshold),
283+
rowBasedChecksumEnabled_(options->rowBasedChecksumEnabled) {
283284
arenas_.resize(numPartitions);
284285
}
285286

@@ -516,6 +517,14 @@ class VeloxHashShuffleWriter : public VeloxShuffleWriter {
516517

517518
// See inputEncodingSkippedBatches() above.
518519
int64_t inputEncodingSkippedBatches_{0};
520+
521+
// Row-based checksum state (per-partition XOR + SUM aggregation).
522+
bool rowBasedChecksumEnabled_{false};
523+
std::vector<int64_t> checksumXor_;
524+
std::vector<int64_t> checksumSum_;
525+
std::vector<char> checksumBuffer_;
526+
527+
void computeRowBasedChecksums(const facebook::velox::RowVector& rv);
519528
}; // class VeloxHashBasedShuffleWriter
520529

521530
} // namespace gluten

cpp/velox/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ add_velox_test(runtime_test SOURCES RuntimeTest.cc)
138138
add_velox_test(velox_memory_test SOURCES MemoryManagerTest.cc)
139139
add_velox_test(buffer_outputstream_test SOURCES BufferOutputStreamTest.cc)
140140
add_velox_test(scoped_timer_test SOURCES ScopedTimerTest.cc)
141+
add_velox_test(row_based_checksum_test SOURCES RowBasedChecksumTest.cc)
141142
if(BUILD_EXAMPLES)
142143
add_velox_test(my_udf_test SOURCES MyUdfTest.cc)
143144
endif()

0 commit comments

Comments
 (0)