|
25 | 25 | #include "utils/VeloxArrowUtils.h" |
26 | 26 | #include "velox/buffer/Buffer.h" |
27 | 27 | #include "velox/common/base/Nulls.h" |
| 28 | +#include "velox/external/xxhash/xxhash.h" |
| 29 | +#include "velox/row/UnsafeRowFast.h" |
28 | 30 | #include "velox/type/HugeInt.h" |
29 | 31 | #include "velox/type/Timestamp.h" |
30 | 32 | #include "velox/type/Type.h" |
@@ -182,6 +184,11 @@ arrow::Status VeloxHashShuffleWriter::init() { |
182 | 184 |
|
183 | 185 | partitionBufferBase_.resize(numPartitions_); |
184 | 186 |
|
| 187 | + if (rowBasedChecksumEnabled_) { |
| 188 | + checksumXor_.resize(numPartitions_, 0); |
| 189 | + checksumSum_.resize(numPartitions_, 0); |
| 190 | + } |
| 191 | + |
185 | 192 | return arrow::Status::OK(); |
186 | 193 | } |
187 | 194 |
|
@@ -362,6 +369,17 @@ arrow::Status VeloxHashShuffleWriter::stop() { |
362 | 369 |
|
363 | 370 | stat(); |
364 | 371 |
|
| 372 | + // Populate row-based checksums into metrics. |
| 373 | + if (rowBasedChecksumEnabled_) { |
| 374 | + metrics_.rowBasedChecksums.resize(numPartitions_); |
| 375 | + for (auto pid = 0; pid < numPartitions_; ++pid) { |
| 376 | + int64_t xorVal = checksumXor_[pid]; |
| 377 | + int64_t sumVal = checksumSum_[pid]; |
| 378 | + int64_t rotated = (static_cast<uint64_t>(sumVal) << 27) | (static_cast<uint64_t>(sumVal) >> 37); |
| 379 | + metrics_.rowBasedChecksums[pid] = xorVal ^ rotated; |
| 380 | + } |
| 381 | + } |
| 382 | + |
365 | 383 | return arrow::Status::OK(); |
366 | 384 | } |
367 | 385 |
|
@@ -423,6 +441,7 @@ void VeloxHashShuffleWriter::setSplitState(SplitState state) { |
423 | 441 | arrow::Status VeloxHashShuffleWriter::doSplit(const facebook::velox::RowVector& rv, int64_t memLimit) { |
424 | 442 | auto rowNum = rv.size(); |
425 | 443 | RETURN_NOT_OK(buildPartition2Row(rowNum)); |
| 444 | + computeRowBasedChecksums(rv); |
426 | 445 | RETURN_NOT_OK(updateInputHasNull(rv)); |
427 | 446 |
|
428 | 447 | { |
@@ -1616,4 +1635,50 @@ bool VeloxHashShuffleWriter::isExtremelyLargeBatch(facebook::velox::RowVectorPtr |
1616 | 1635 | return (rv->size() > maxBatchSize_ && maxBatchSize_ > 0); |
1617 | 1636 | } |
1618 | 1637 |
|
| 1638 | +void VeloxHashShuffleWriter::computeRowBasedChecksums(const facebook::velox::RowVector& rv) { |
| 1639 | + if (!rowBasedChecksumEnabled_) { |
| 1640 | + return; |
| 1641 | + } |
| 1642 | + |
| 1643 | + auto numRows = rv.size(); |
| 1644 | + VELOX_DCHECK(rv.nulls() == nullptr, "RowVector with top-level nulls not supported for checksum"); |
| 1645 | + // Get the RowVector to serialize (strip pid column if present). |
| 1646 | + facebook::velox::RowVectorPtr dataVector; |
| 1647 | + if (partitioner_->hasPid()) { |
| 1648 | + // Strip the first column (partition id). |
| 1649 | + auto rowType = std::dynamic_pointer_cast<const facebook::velox::RowType>(rv.type()); |
| 1650 | + std::vector<std::string> names(rowType->names().begin() + 1, rowType->names().end()); |
| 1651 | + std::vector<facebook::velox::TypePtr> types(rowType->children().begin() + 1, rowType->children().end()); |
| 1652 | + std::vector<facebook::velox::VectorPtr> children(rv.children().begin() + 1, rv.children().end()); |
| 1653 | + auto dataType = facebook::velox::ROW(std::move(names), std::move(types)); |
| 1654 | + dataVector = |
| 1655 | + std::make_shared<facebook::velox::RowVector>(rv.pool(), dataType, nullptr, numRows, std::move(children)); |
| 1656 | + } else { |
| 1657 | + auto rowType = std::dynamic_pointer_cast<const facebook::velox::RowType>(rv.type()); |
| 1658 | + dataVector = std::make_shared<facebook::velox::RowVector>(rv.pool(), rowType, nullptr, numRows, rv.children()); |
| 1659 | + } |
| 1660 | + |
| 1661 | + facebook::velox::row::UnsafeRowFast fast(dataVector); |
| 1662 | + auto dataType = std::dynamic_pointer_cast<const facebook::velox::RowType>(dataVector->type()); |
| 1663 | + auto fixedSize = facebook::velox::row::UnsafeRowFast::fixedRowSize(dataType); |
| 1664 | + int32_t bufSize = fixedSize.value_or(1024); |
| 1665 | + if (checksumBuffer_.size() < static_cast<size_t>(bufSize)) { |
| 1666 | + checksumBuffer_.resize(bufSize); |
| 1667 | + } |
| 1668 | + |
| 1669 | + for (uint32_t row = 0; row < numRows; ++row) { |
| 1670 | + auto pid = row2Partition_[row]; |
| 1671 | + auto size = fast.rowSize(row); |
| 1672 | + if (size > static_cast<int32_t>(checksumBuffer_.size())) { |
| 1673 | + checksumBuffer_.resize(size); |
| 1674 | + } |
| 1675 | + std::memset(checksumBuffer_.data(), 0, size); |
| 1676 | + fast.serialize(row, checksumBuffer_.data()); |
| 1677 | + |
| 1678 | + auto hash = static_cast<int64_t>(XXH64(checksumBuffer_.data(), size, 0)); |
| 1679 | + checksumXor_[pid] ^= hash; |
| 1680 | + checksumSum_[pid] += hash; |
| 1681 | + } |
| 1682 | +} |
| 1683 | + |
1619 | 1684 | } // namespace gluten |
0 commit comments