[CUB] Makes the batched top-k selection direction compile-time only #9286
elstehle wants to merge 3 commits into
Overview
This PR restricts the batched top-k selection direction (top-k min vs max) to a compile-time-only option and removes runtime direction forms. It eliminates a prior design that allowed runtime or dual-direction compilation paths (which could lead to accidental mismatches) and surfaces direction selection via dedicated DeviceBatchedTopK entry points (DeviceBatchedTopK::Min*/::Max*). The change also adds an explicit compile-time rejection and documentation for any non-compile-time direction forms so callers receive a clear diagnostic instead of noisy template errors. A dedicated compile-fail test was considered, then removed as redundant because the dispatch path is only invoked internally with compile-time constants.
Changes
Core Infrastructure (cub/cub/detail/segmented_params.cuh)
Batched Top-K Dispatch (cub/cub/device/dispatch/dispatch_batched_topk.cuh)
Tests
Review / Rationale / Notes
Impact
Walkthrough
important: Discrete dispatch types now use a compile-time
Changes
important: Discrete parameters and top-k direction dispatch
🧹 Nitpick comments (2)
cub/cub/device/dispatch/dispatch_batched_topk.cuh (1)
52-59: ⚡ Quick win
suggestion: Add an explicit non-__constant rejection path here. The new contract is enforced only by missing-overload resolution in wrap_select_direction, so a caller that still passes a runtime detail::topk::select gets template spew from inside dispatch instead of a direct diagnostic. A deleted fallback overload or a targeted static_assert would keep the compile-time-only behavior but make the failure mode obvious.
cub/test/catch2_test_device_segmented_topk_keys.cu (1)
88-90: ⚡ Quick win
suggestion: Add one negative compile-time check for runtime directions. These cases only cover __constant<direction>, so accidentally reintroducing an overload that accepts a runtime detail::topk::select would still keep this suite green even though it reopens the footgun this PR is removing. A small static_assert(!requires { ... }) or a compile-fail test would lock the new contract in.
Also applies to: 108-109, 161-162, 190-191, 260-260, 298-298
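The two suggestions above can be combined in a small host-only sketch: a deleted fallback overload rejects runtime directions during overload resolution, which also makes the rejection checkable with a plain static_assert. Everything in namespace sketch (select, constant, wrap_fn, and the simplified wrap_select_direction) is a hypothetical stand-in written for illustration, not the CUB or libcu++ definitions, and the sketch assumes C++17.

```cpp
#include <type_traits>

namespace sketch {
// Hypothetical stand-in for the selection direction enum.
enum class select { min, max };

// Hypothetical stand-in for a compile-time constant argument wrapper.
template <select Dir>
struct constant {
  static constexpr select value = Dir;
};

// Accepted form: the direction is carried in the type.
template <select Dir>
constexpr select wrap_select_direction(constant<Dir>) {
  return Dir;
}

// Deleted fallback: any other argument (for example a runtime select value)
// is rejected during overload resolution, in a SFINAE-friendly way.
template <class T>
void wrap_select_direction(T) = delete;

// Small function object so that std::is_invocable can probe the overload set.
struct wrap_fn {
  template <class T>
  constexpr auto operator()(T t) const -> decltype(wrap_select_direction(t)) {
    return wrap_select_direction(t);
  }
};

// Positive case: compile-time directions are accepted.
static_assert(std::is_invocable_v<wrap_fn, constant<select::max>>);
// Negative case: a runtime direction is not accepted.
static_assert(!std::is_invocable_v<wrap_fn, select>);
} // namespace sketch
```

With a deleted overload the negative check lives next to the positive tests and needs no separate compile-fail target. The trade-off is that the diagnostic a user sees is the compiler's generic "use of deleted function" message rather than a custom one.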
📒 Files selected for processing (4)
cub/cub/detail/segmented_params.cuh
cub/cub/device/dispatch/dispatch_batched_topk.cuh
cub/test/catch2_test_device_segmented_topk_keys.cu
cub/test/catch2_test_device_segmented_topk_pairs.cu
Address review feedback:
* wrap_select_direction now emits a clear static_assert ("the batched
top-k selection direction must be a compile-time option: pass
::cuda::__argument::__constant<...>") for any non-__constant direction,
instead of relying on overload-resolution failure. This gives a direct,
self-explaining diagnostic at the call site.
* Add a CUB compile-fail test (test/test_device_segmented_topk_direction_fail.cu)
that passes a runtime direction and asserts, via the // expected-error
machinery, that compilation fails with that message -- locking in the
contract so reintroducing a runtime-direction overload fails the build.
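As a rough illustration of the first bullet, a catch-all overload with a dependent static_assert might look like the sketch below. The names select, constant, and always_false_v, and the simplified signature of wrap_select_direction, are assumptions made for this example; only the wording of the diagnostic is taken from the commit message. The sketch is host-only and assumes C++17.

```cpp
#include <type_traits>

namespace sketch {
// Hypothetical stand-ins, not the CUB/libcu++ types.
enum class select { min, max };

template <select Dir>
struct constant {
  static constexpr select value = Dir;
};

// Dependent false, so the static_assert below only fires on instantiation.
template <class>
inline constexpr bool always_false_v = false;

// Accepted form: the direction is a compile-time constant.
template <select Dir>
constexpr select wrap_select_direction(constant<Dir>) {
  return Dir;
}

// Catch-all: chosen only when the argument is not a constant<Dir>. Calling it
// is a hard compile error that carries a self-explaining message.
template <class T>
constexpr select wrap_select_direction(T) {
  static_assert(always_false_v<T>,
                "the batched top-k selection direction must be a compile-time "
                "option: pass ::cuda::__argument::__constant<...>");
  return select::min; // unreachable; keeps the return type consistent
}
} // namespace sketch
```

Because the static_assert fires inside the function body, the failure is a hard error rather than a substitution failure. That is why it cannot be probed with a trait such as std::is_invocable and would need a compile-fail test like the one described in the second bullet.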
dispatch is only ever invoked with a selection direction we construct
internally: the min/max device entry points
(DeviceBatchedTopK::{Max,Min}{Keys,Pairs}) create the
::cuda::__argument::__constant<Dir>, so there is no user-facing path that
could pass a runtime direction to wrap_select_direction. The dedicated
compile-fail test is therefore redundant and is removed.
The wrap_select_direction catch-all overload is kept, with its static_assert
reworded to document the intentional, current limitation to compile-time
constant directions selected by the {Max,Min}{Keys,Pairs} interfaces.
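The argument that no user-facing path can pass a runtime direction can be sketched on the host as follows. This is a simplified model under stated assumptions: batched_top_k, max_keys, min_keys, and detail::dispatch are hypothetical names, the "algorithm" is a std::partial_sort over a single std::vector<int>, and none of it reflects the actual DeviceBatchedTopK signatures.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

namespace sketch {
// Hypothetical stand-ins, not the CUB/libcu++ types.
enum class select { min, max };

template <select Dir>
struct constant {};

namespace detail {
// Internal dispatch: the direction is only ever visible as a type.
template <select Dir>
std::vector<int> dispatch(std::vector<int> keys, std::size_t k, constant<Dir>) {
  k = std::min(k, keys.size());
  const auto mid = keys.begin() + static_cast<std::ptrdiff_t>(k);
  if constexpr (Dir == select::max) {
    std::partial_sort(keys.begin(), mid, keys.end(), std::greater<int>{});
  } else {
    std::partial_sort(keys.begin(), mid, keys.end(), std::less<int>{});
  }
  keys.resize(k);
  return keys;
}
} // namespace detail

// Public entry points: each one constructs the compile-time direction itself,
// so callers never get the chance to supply a direction value.
struct batched_top_k {
  static std::vector<int> max_keys(std::vector<int> keys, std::size_t k) {
    return detail::dispatch(std::move(keys), k, constant<select::max>{});
  }
  static std::vector<int> min_keys(std::vector<int> keys, std::size_t k) {
    return detail::dispatch(std::move(keys), k, constant<select::min>{});
  }
};
} // namespace sketch
```

In this model, calling sketch::batched_top_k::max_keys with keys {5, 1, 9, 3} and k = 2 yields the two largest keys in descending order. The direction never appears in the public signature, which is the property the commit message relies on.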
We plan to expose the direction via separate overloads via
Done. More for documentation purposes, though (for the reason mentioned above:
🥳 CI Workflow Results
🟩 Finished in 2h 57m: Pass: 100%/284 | Total: 2d 09h | Max: 39m 36s | Hits: 93%/209094
Description
This PR limits the top-k direction (top-k min versus max) to a compile-time option, dropping the possibility of passing a runtime value for the direction.
Previously, we considered letting users specify, at compile time, which directions they want to be able to invoke the algorithm with. For example, users could compile the algorithm with support for both directions and choose one of the two at runtime. Alternatively, they could compile a single direction, in which case the runtime(!) option would have had to match the compile-time option(!). This is exactly where a dangerous footgun existed.
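To make the footgun concrete, here is a host-only sketch of what a "single compile-time direction plus a matching runtime value" interface could look like. It is a guess at the shape of the abandoned design for illustration only: top_k_runtime_checked and its signature are hypothetical and not taken from any CUB revision.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <optional>
#include <vector>

namespace sketch {
// Hypothetical stand-in for the selection direction enum.
enum class select { min, max };

// Hypothetical interface: the code is compiled for exactly one direction, yet
// the caller must also pass a runtime direction that has to agree with it.
template <select CompiledDir>
std::optional<std::vector<int>>
top_k_runtime_checked(std::vector<int> keys, std::size_t k, select runtime_dir) {
  // The mismatch is invisible to the compiler and only detectable at runtime.
  if (runtime_dir != CompiledDir) {
    return std::nullopt;
  }
  k = std::min(k, keys.size());
  const auto mid = keys.begin() + static_cast<std::ptrdiff_t>(k);
  if constexpr (CompiledDir == select::max) {
    std::partial_sort(keys.begin(), mid, keys.end(), std::greater<int>{});
  } else {
    std::partial_sort(keys.begin(), mid, keys.end(), std::less<int>{});
  }
  keys.resize(k);
  return keys;
}
} // namespace sketch
```

A call such as top_k_runtime_checked<select::max>(keys, k, select::min) compiles without complaint and fails only when it runs. Carrying the direction solely in the type removes that second, redundant source of truth.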