-
Notifications
You must be signed in to change notification settings - Fork 402
Run-to-run deterministic scan: warpspeed implementation for SM100+ #9263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |||||
| #include <cuda/__memory/is_aligned.h> | ||||||
| #include <cuda/__ptx/instructions/get_sreg.h> | ||||||
| #include <cuda/__type_traits/is_trivially_copyable.h> | ||||||
| #include <cuda/std/__algorithm/min.h> | ||||||
| #include <cuda/std/__bit/popcount.h> | ||||||
| #include <cuda/std/__type_traits/underlying_type.h> | ||||||
|
|
||||||
|
|
@@ -261,6 +262,88 @@ template <int numTileStatesPerThread, typename AccumT, typename ScanOpT> | |||||
| return aggrExclusiveCtaCur; // must only be valid in lane_0 | ||||||
| } | ||||||
|
|
||||||
| // Deterministic version of warpIncrementalLookahead that returns the same aggrExclusiveCta. The difference is that it | ||||||
| // always starts the lookahead from a tile index that is a multiple of 32: it shifts the left pointer (idxTilePrev) down | ||||||
| // to the nearest multiple of 32 and reduces from there. Because every reduction begins at the same fixed tiles, no | ||||||
| // matter which tiles happened to finish first, the order in which values are summed is always the same and the result | ||||||
| // is identical on every run. idxTilePrev/aggrExclusiveCtaPrev are updated by reference to the last multiple of 32. | ||||||
| template <int numTileStatesPerThread, typename AccumT, typename ScanOpT> | ||||||
| [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE AccumT warpIncrementalLookaheadStable( | ||||||
| SpecialRegisters specialRegisters, | ||||||
| tile_state_t<AccumT>* ptrTileStates, | ||||||
| int& idxTilePrev, | ||||||
| AccumT& aggrExclusiveCtaPrev, | ||||||
| const int idxTileNext, | ||||||
| ScanOpT& scan_op) | ||||||
| { | ||||||
| const int laneIdx = specialRegisters.laneIdx; | ||||||
| const ::cuda::std::uint32_t lanemaskEq = ::cuda::ptx::get_sreg_lanemask_eq(); | ||||||
|
|
||||||
| // Adjust the left pointer down to the nearest 32-multiple so we do batched sums | ||||||
| int idxTileCur = (idxTilePrev / 32) * 32; | ||||||
| AccumT aggrExclusiveCtaCur = aggrExclusiveCtaPrev; | ||||||
|
|
||||||
| using warp_reduce_t = WarpReduce<AccumT>; | ||||||
| static_assert(sizeof(typename warp_reduce_t::TempStorage) <= 4, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why 4? I assume this is
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please put this as a comment then in the src? 4 is quite a magic value to capture this, I would have expected 1 or something like that then
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Not sure about that. I think it is just inheriting from Therefore the check that I came up with is static_assert(::cuda::std::is_base_of_v<cub::Uninitialized<cub::NullType>, TempStorage>, "Code assumes empty TempStorage");Pretty verbose/not super readable, but at least no magic number and a bit clearer in its motivation once one gets to the bottom of it? And no chance for this one to not trigger if we would start requiring temporary storage.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would strongly suggest an inline variable of the form: template<class>
inline constexpr bool __requires_temp_storage = true;
template<>
inline constexpr bool __requires_temp_storage<cub::Uninitialized<cub::NullType>> = false;
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, here is my attempt: #9294 |
||||||
| "WarpReduce with non-trivial temporary storage is not supported yet in this kernel."); | ||||||
| [[maybe_unused]] typename warp_reduce_t::TempStorage temp_storage; | ||||||
|
|
||||||
| using warp_reduce_or_t = WarpReduce<::cuda::std::uint32_t>; | ||||||
| typename warp_reduce_or_t::TempStorage temp_storage_or; | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: |
||||||
| warp_reduce_or_t warp_reduce_or{temp_storage_or}; | ||||||
| constexpr ::cuda::std::bit_or<::cuda::std::uint32_t> or_op{}; | ||||||
|
|
||||||
| while (idxTileCur < idxTileNext) | ||||||
| { | ||||||
| tile_state_t<AccumT> regTmpStates[numTileStatesPerThread]; | ||||||
| warpLoadLookahead(laneIdx, regTmpStates, ptrTileStates, idxTileCur, idxTileNext); | ||||||
|
|
||||||
| for (int idx = 0; idx < numTileStatesPerThread; ++idx) | ||||||
| { | ||||||
| // Bitmask with a 1 bit in the position of the current lane if current lane has a tile aggregate | ||||||
| const ::cuda::std::uint32_t lane_has_aggregate = | ||||||
| lanemaskEq * (regTmpStates[idx].state == scan_state::tile_aggregate); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have you benchmarked this multiplication to be an improvement over predication? Otherwise I would stay with
Suggested change
My (possibly wrong) intuition is that the multiplication will result in either the same output or still generate a predicated move in addition to the multiplication since it needs to transform a predicate register into an integer. |
||||||
|
|
||||||
| // Bitmask with 1 bits indicating which lane has a tile aggregate | ||||||
| const ::cuda::std::uint32_t warp_has_aggregate_mask = warp_reduce_or.Reduce(lane_has_aggregate, or_op); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An even easier (and faster?) way of getting this mask would be a call to |
||||||
|
|
||||||
| // Bitmask with 1 bits for all rightmost lanes having a tile aggregate | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| const ::cuda::std::uint32_t warp_right_aggregates_mask = warp_has_aggregate_mask & (~warp_has_aggregate_mask - 1); | ||||||
|
|
||||||
| const ::cuda::std::uint32_t warp_right_aggregates_count = ::cuda::std::popcount(warp_right_aggregates_mask); | ||||||
|
|
||||||
| // Only reduce once a fixed number of contiguous tile aggregates are available, so the reduction order is fixed. | ||||||
| const ::cuda::std::uint32_t expected_count = | ||||||
| static_cast<::cuda::std::uint32_t>(::cuda::std::min(32, idxTileNext - idxTileCur)); | ||||||
| if (warp_right_aggregates_count < expected_count) | ||||||
| { | ||||||
| break; | ||||||
| } | ||||||
|
|
||||||
| const bool use_value = lanemaskEq & warp_right_aggregates_mask; | ||||||
| const AccumT value = use_value ? regTmpStates[idx].value : cuda::identity_element<ScanOpT, AccumT>(); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In case there is no identity element, you could use the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think the deterministic path is only taken for FP32 and FP64. |
||||||
| const AccumT local_aggr = warp_reduce_t{temp_storage}.Reduce(value, scan_op); | ||||||
|
|
||||||
| if (expected_count == 32) | ||||||
| { | ||||||
| aggrExclusiveCtaCur = idxTileCur == 0 ? local_aggr : scan_op(aggrExclusiveCtaCur, local_aggr); | ||||||
| idxTileCur += 32; | ||||||
| } | ||||||
| else | ||||||
| { | ||||||
| const AccumT full_aggr = idxTileCur == 0 ? local_aggr : scan_op(aggrExclusiveCtaCur, local_aggr); | ||||||
| idxTilePrev = idxTileCur; | ||||||
| aggrExclusiveCtaPrev = aggrExclusiveCtaCur; | ||||||
| return full_aggr; | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| idxTilePrev = idxTileNext; | ||||||
| aggrExclusiveCtaPrev = aggrExclusiveCtaCur; | ||||||
| return aggrExclusiveCtaCur; // must only be valid in lane_0 | ||||||
| } | ||||||
|
|
||||||
| #endif // __cccl_ptx_isa >= 860 | ||||||
| } // namespace detail::warpspeed | ||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: Use
cuda::round_down.