1010#include < cstdint> // for uint8_t, uint32_t, int32_t
1111#include < memory> // for shared_ptr
1212
13+ #include " cuda_stream.h" // for StreamRef
1314#include " device_compression.cuh"
14- #include " cuda_stream.h" // for StreamRef
1515#include " device_helpers.cuh" // for MemcpyBatchAsync
1616#include " xgboost/span.h" // for Span
1717
@@ -277,11 +277,17 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
277277 // Fallback to nvcomp. This is only used during tests where we don't have access to DE
278278 // but still want the test coverage.
279279 CHECK (allow_fallback);
280- CheckAlign (nvcompBatchedSnappyDecompressRequiredAlignments);
280+ nvcompAlignmentRequirements_t decompression_alignment_reqs;
281+ SafeNvComp (nvcompBatchedSnappyDecompressGetRequiredAlignments (
282+ nvcompBatchedSnappyDecompressDefaultOpts, &decompression_alignment_reqs));
283+ CheckAlign (decompression_alignment_reqs);
281284 auto n_chunks = mgr_impl->Chunks ();
282285 // Get sketch space
283286 std::size_t n_tmp_bytes = 0 ;
284- SafeNvComp (nvcompBatchedSnappyDecompressGetTempSize (n_chunks, /* unused*/ 0 , &n_tmp_bytes));
287+ SafeNvComp (nvcompBatchedSnappyDecompressGetTempSizeAsync (
288+ n_chunks, /* max_uncompressed_chunk_bytes=*/ 0 , nvcompBatchedSnappyDecompressDefaultOpts,
289+ &n_tmp_bytes,
290+ /* max_total_uncompressed_bytes=*/ 0 ));
285291 dh::device_vector<char > tmp (n_tmp_bytes, 0 );
286292
287293 dh::device_vector<nvcompStatus_t> status (n_chunks, nvcompSuccess);
@@ -297,7 +303,8 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
297303 SafeNvComp (nvcompBatchedSnappyDecompressAsync (
298304 mgr_impl->d_in_chunk_ptrs .data ().get (), mgr_impl->d_in_chunk_sizes .data ().get (),
299305 mgr_impl->d_out_chunk_sizes .data ().get (), mgr_impl->act_nbytes .data ().get (), n_chunks,
300- tmp.data ().get (), n_tmp_bytes, d_out_ptrs.data ().get (), status.data ().get (), stream));
306+ tmp.data ().get (), n_tmp_bytes, d_out_ptrs.data ().get (),
307+ nvcompBatchedSnappyDecompressDefaultOpts, status.data ().get (), stream));
301308 }
302309}
303310
@@ -307,7 +314,7 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
307314 std::size_t chunk_size) {
308315 CHECK_GT (chunk_size, 0 );
309316 auto cuctx = ctx->CUDACtx ();
310- auto nvcomp_batched_snappy_opts = nvcompBatchedSnappyDefaultOpts ;
317+ auto nvcomp_batched_snappy_opts = nvcompBatchedSnappyCompressDefaultOpts ;
311318
312319 nvcompAlignmentRequirements_t compression_alignment_reqs;
313320 SafeNvComp (nvcompBatchedSnappyCompressGetRequiredAlignments (nvcomp_batched_snappy_opts,
@@ -352,8 +359,9 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
352359 * Outputs
353360 */
354361 std::size_t comp_temp_bytes;
355- SafeNvComp (nvcompBatchedSnappyCompressGetTempSize (n_chunks, chunk_size,
356- nvcomp_batched_snappy_opts, &comp_temp_bytes));
362+ SafeNvComp (nvcompBatchedSnappyCompressGetTempSizeAsync (
363+ n_chunks, chunk_size, nvcomp_batched_snappy_opts, &comp_temp_bytes,
364+ /* max_total_uncompressed_bytes=*/ in.size ()));
357365 CHECK_EQ (comp_temp_bytes, 0 );
358366 dh::DeviceUVector<char > comp_tmp (comp_temp_bytes);
359367
@@ -381,7 +389,8 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
381389 */
382390 SafeNvComp (nvcompBatchedSnappyCompressAsync (
383391 in_ptrs.data (), in_sizes.data (), max_in_nbytes, n_chunks, comp_tmp.data (), comp_temp_bytes,
384- out_ptrs.data (), out_sizes.data (), nvcomp_batched_snappy_opts, cuctx->Stream ()));
392+ out_ptrs.data (), out_sizes.data (), nvcomp_batched_snappy_opts, /* device_statuses=*/ nullptr ,
393+ cuctx->Stream ()));
385394 auto n_bytes = thrust::reduce (cuctx->CTP (), out_sizes.cbegin (), out_sizes.cend ());
386395 auto n_total_bytes = p_out->size ();
387396 auto ratio = static_cast <double >(n_total_bytes) / in.size_bytes ();
@@ -410,9 +419,8 @@ void DecompressSnappy(curt::StreamRef stream, SnappyDecomprMgr const& mgr,
410419}
411420
412421[[nodiscard]] common::RefResourceView<std::uint8_t > CoalesceCompressedBuffersToHost (
413- curt::StreamRef stream, std::shared_ptr<HostPinnedMemPool> pool,
414- CuMemParams const & in_params, dh::DeviceUVector<std::uint8_t > const & in_buf,
415- CuMemParams* p_out) {
422+ curt::StreamRef stream, std::shared_ptr<HostPinnedMemPool> pool, CuMemParams const & in_params,
423+ dh::DeviceUVector<std::uint8_t > const & in_buf, CuMemParams* p_out) {
416424 std::size_t n_total_act_bytes = in_params.TotalSrcActBytes ();
417425 std::size_t n_total_bytes = in_params.TotalSrcBytes ();
418426 if (n_total_bytes == 0 ) {
0 commit comments