Skip to content

Commit 4cf0948

Browse files
authored
Fix compilation with the latest nvcomp. (dmlc#11997)
1 parent c3e5782 commit 4cf0948

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

src/common/device_compression.cu

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,8 @@
1010
#include <cstdint> // for uint8_t, uint32_t, int32_t
1111
#include <memory> // for shared_ptr
1212

13+
#include "cuda_stream.h" // for StreamRef
1314
#include "device_compression.cuh"
14-
#include "cuda_stream.h" // for StreamRef
1515
#include "device_helpers.cuh" // for MemcpyBatchAsync
1616
#include "xgboost/span.h" // for Span
1717

@@ -277,11 +277,17 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
277277
// Fall back to nvcomp. This is only used during tests where we don't have access to DE
278278
// but still want the test coverage.
279279
CHECK(allow_fallback);
280-
CheckAlign(nvcompBatchedSnappyDecompressRequiredAlignments);
280+
nvcompAlignmentRequirements_t decompression_alignment_reqs;
281+
SafeNvComp(nvcompBatchedSnappyDecompressGetRequiredAlignments(
282+
nvcompBatchedSnappyDecompressDefaultOpts, &decompression_alignment_reqs));
283+
CheckAlign(decompression_alignment_reqs);
281284
auto n_chunks = mgr_impl->Chunks();
282285
// Get scratch space
283286
std::size_t n_tmp_bytes = 0;
284-
SafeNvComp(nvcompBatchedSnappyDecompressGetTempSize(n_chunks, /*unused*/ 0, &n_tmp_bytes));
287+
SafeNvComp(nvcompBatchedSnappyDecompressGetTempSizeAsync(
288+
n_chunks, /*max_uncompressed_chunk_bytes=*/0, nvcompBatchedSnappyDecompressDefaultOpts,
289+
&n_tmp_bytes,
290+
/*max_total_uncompressed_bytes=*/0));
285291
dh::device_vector<char> tmp(n_tmp_bytes, 0);
286292

287293
dh::device_vector<nvcompStatus_t> status(n_chunks, nvcompSuccess);
@@ -297,7 +303,8 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
297303
SafeNvComp(nvcompBatchedSnappyDecompressAsync(
298304
mgr_impl->d_in_chunk_ptrs.data().get(), mgr_impl->d_in_chunk_sizes.data().get(),
299305
mgr_impl->d_out_chunk_sizes.data().get(), mgr_impl->act_nbytes.data().get(), n_chunks,
300-
tmp.data().get(), n_tmp_bytes, d_out_ptrs.data().get(), status.data().get(), stream));
306+
tmp.data().get(), n_tmp_bytes, d_out_ptrs.data().get(),
307+
nvcompBatchedSnappyDecompressDefaultOpts, status.data().get(), stream));
301308
}
302309
}
303310

@@ -307,7 +314,7 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
307314
std::size_t chunk_size) {
308315
CHECK_GT(chunk_size, 0);
309316
auto cuctx = ctx->CUDACtx();
310-
auto nvcomp_batched_snappy_opts = nvcompBatchedSnappyDefaultOpts;
317+
auto nvcomp_batched_snappy_opts = nvcompBatchedSnappyCompressDefaultOpts;
311318

312319
nvcompAlignmentRequirements_t compression_alignment_reqs;
313320
SafeNvComp(nvcompBatchedSnappyCompressGetRequiredAlignments(nvcomp_batched_snappy_opts,
@@ -352,8 +359,9 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
352359
* Outputs
353360
*/
354361
std::size_t comp_temp_bytes;
355-
SafeNvComp(nvcompBatchedSnappyCompressGetTempSize(n_chunks, chunk_size,
356-
nvcomp_batched_snappy_opts, &comp_temp_bytes));
362+
SafeNvComp(nvcompBatchedSnappyCompressGetTempSizeAsync(
363+
n_chunks, chunk_size, nvcomp_batched_snappy_opts, &comp_temp_bytes,
364+
/*max_total_uncompressed_bytes=*/in.size()));
357365
CHECK_EQ(comp_temp_bytes, 0);
358366
dh::DeviceUVector<char> comp_tmp(comp_temp_bytes);
359367

@@ -381,7 +389,8 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
381389
*/
382390
SafeNvComp(nvcompBatchedSnappyCompressAsync(
383391
in_ptrs.data(), in_sizes.data(), max_in_nbytes, n_chunks, comp_tmp.data(), comp_temp_bytes,
384-
out_ptrs.data(), out_sizes.data(), nvcomp_batched_snappy_opts, cuctx->Stream()));
392+
out_ptrs.data(), out_sizes.data(), nvcomp_batched_snappy_opts, /*device_statuses=*/nullptr,
393+
cuctx->Stream()));
385394
auto n_bytes = thrust::reduce(cuctx->CTP(), out_sizes.cbegin(), out_sizes.cend());
386395
auto n_total_bytes = p_out->size();
387396
auto ratio = static_cast<double>(n_total_bytes) / in.size_bytes();
@@ -410,9 +419,8 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
410419
}
411420

412421
[[nodiscard]] common::RefResourceView<std::uint8_t> CoalesceCompressedBuffersToHost(
413-
curt::StreamRef stream, std::shared_ptr<HostPinnedMemPool> pool,
414-
CuMemParams const& in_params, dh::DeviceUVector<std::uint8_t> const& in_buf,
415-
CuMemParams* p_out) {
422+
curt::StreamRef stream, std::shared_ptr<HostPinnedMemPool> pool, CuMemParams const& in_params,
423+
dh::DeviceUVector<std::uint8_t> const& in_buf, CuMemParams* p_out) {
416424
std::size_t n_total_act_bytes = in_params.TotalSrcActBytes();
417425
std::size_t n_total_bytes = in_params.TotalSrcBytes();
418426
if (n_total_bytes == 0) {

0 commit comments

Comments
 (0)