
[CUB] Makes the batched top-k selection direction compile-time only#9286

Open
elstehle wants to merge 3 commits into
NVIDIA:mainfrom
elstehle:fix/topk-direction-static-only

Conversation

@elstehle
Contributor

@elstehle elstehle commented Jun 7, 2026

Description

This PR limits the top-k direction (top-k min versus max) to just a compile-time option, dropping the possibility of passing in a runtime value for the direction.

Previously, we considered letting users specify, at compile time, which directions they want to be able to invoke the algorithm with. For example, users could compile the algorithm with support for both directions and then choose one of the two at runtime. Alternatively, they could compile a single direction, in which case the runtime(!) option would have had to match the compile-time option(!). That matching requirement is exactly where a dangerous footgun existed.
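The mismatch risk described above can be illustrated with a minimal, standalone C++17 sketch. Everything in it is hypothetical: `sketch::select`, `top1_static`, and `top1_runtime_checked` are illustration-only names, not CUB APIs, and top-1 stands in for top-k to keep it short. The PR does not state what happened on a mismatch in the considered design; the sketch simply reports it through an `ok` flag.

```cpp
#include <array>
#include <cstddef>

namespace sketch
{
// Hypothetical stand-in for the selection direction; not the CUB type.
enum class select { min, max };

// Result of the illustrative old-style call: `ok` is false on a mismatch.
struct result
{
  bool ok;
  int value;
};

// Compile-time-only shape: the direction exists solely as a template argument,
// so there is no second value that could disagree with it.
template <select Dir, std::size_t N>
constexpr int top1_static(const std::array<int, N>& keys)
{
  static_assert(N > 0, "need at least one key");
  int best = keys[0];
  for (std::size_t i = 1; i < N; ++i)
  {
    const bool better = (Dir == select::max) ? (keys[i] > best) : (keys[i] < best);
    if (better)
    {
      best = keys[i];
    }
  }
  return best;
}

// Considered shape (illustrative): compiled for one direction, but the caller
// also passes a runtime direction that is merely expected to match.
template <select CompiledDir, std::size_t N>
constexpr result top1_runtime_checked(const std::array<int, N>& keys, select runtime_dir)
{
  // Nothing at compile time ties runtime_dir to CompiledDir, so a mismatch
  // can only be noticed here, at run time (assumed behavior for this sketch).
  if (runtime_dir != CompiledDir)
  {
    return result{false, 0};
  }
  return result{true, top1_static<CompiledDir>(keys)};
}
} // namespace sketch
```

A call such as `sketch::top1_runtime_checked<sketch::select::max>(keys, sketch::select::min)` compiles without complaint, which is the hazard; `sketch::top1_static<sketch::select::max>(keys)` leaves nothing to get out of sync.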

@elstehle elstehle requested a review from a team as a code owner June 7, 2026 11:00
@elstehle elstehle requested a review from srinivasyadav18 June 7, 2026 11:00
@github-project-automation github-project-automation Bot moved this to Todo in CCCL Jun 7, 2026
@cccl-authenticator-app cccl-authenticator-app Bot moved this from Todo to In Review in CCCL Jun 7, 2026
@elstehle elstehle requested review from gevtushenko and pauleonix June 7, 2026 11:01
@coderabbitai
Contributor

coderabbitai Bot commented Jun 7, 2026

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 95c62b9f-89e7-425e-827d-33d6dedf3215

📥 Commits

Reviewing files that changed from the base of the PR and between 87fdfe2 and 749bfa1.

📒 Files selected for processing (1)
  • cub/cub/device/dispatch/dispatch_batched_topk.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
  • cub/cub/device/dispatch/dispatch_batched_topk.cuh

Note: CodeRabbit is enabled on this repository as a convenience for maintainers
and contributors. Use your best judgment when considering its review comments and
suggestions — a suggested change may be inadequate, unnecessary, or safe to ignore.
Contributors are not expected to address every comment. Human reviews are what
ultimately matter for merging.

Overview

This PR restricts the batched top-k selection direction (top-k min vs max) to a compile-time-only option and removes runtime direction forms. It eliminates a prior design that allowed runtime or dual-direction compilation paths (which could lead to accidental mismatches) and surfaces direction selection via dedicated DeviceBatchedTopK entry points (DeviceBatchedTopK::Min*/::Max*).

The change also adds an explicit compile-time rejection and documentation for any non-compile-time direction forms so callers receive a clear diagnostic instead of noisy template errors. An attempted dedicated compile-fail test was considered then removed as redundant because the dispatch path is only invoked internally with compile-time constants.

Changes

Core Infrastructure (cub/cub/detail/segmented_params.cuh)

  • Removed runtime-backed discrete-parameter templates:
    • uniform_discrete_param<T, Options...>
    • per_segment_discrete_param<IteratorT, T, Options...>
  • Added static_discrete_param<T, Value> that encodes a compile-time-fixed discrete parameter and returns Value from get_param.
  • Updated dispatcher documentation to describe resolving generic discrete parameters to compile-time constants before invoking functors.
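As a rough illustration of the `static_discrete_param<T, Value>` idea summarized above, the following standalone C++17 sketch shows a parameter type whose value is fixed at compile time. Only the type name and the fact that `get_param` returns `Value` come from the summary; the member-function form, its `segment_id` argument, and the `sketch::select` enum are assumptions made for this example and may differ from the definition in `segmented_params.cuh`.

```cpp
#include <cstddef>

namespace sketch
{
// Hypothetical stand-in for cub::detail::topk::select.
enum class select { min, max };

// Sketch of a compile-time-fixed discrete parameter. The signature of
// get_param is an assumption for illustration, not the CUB definition.
template <typename T, T Value>
struct static_discrete_param
{
  using value_type = T;
  static constexpr T value = Value;

  // Every segment sees the same compile-time value, so the index is ignored.
  constexpr T get_param(std::size_t /*segment_id*/) const noexcept
  {
    return Value;
  }
};

// The value is part of the type, so it can be checked without running anything.
static_assert(static_discrete_param<select, select::max>{}.get_param(0) == select::max,
              "get_param returns the compile-time value");
} // namespace sketch
```

Because the value lives in the type rather than in a data member, such a parameter carries no runtime state that could diverge from what the kernel was compiled for.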

Batched Top-K Dispatch (cub/cub/device/dispatch/dispatch_batched_topk.cuh)

  • Constrained detail::batched_topk::wrap_select_direction to accept only compile-time directions passed as ::cuda::__argument::__constant.
  • Changed the constant-direction path to wrap into params::static_discrete_param instead of a runtime-backed param.
  • Removed overloads that accepted a runtime detail::topk::select value or per-segment iterator directions.
  • Added a catch-all wrap_select_direction overload that triggers a static_assert for non-__constant inputs, documenting the intentional compile-time-only limitation and providing a clearer diagnostic for misuse.
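The overload structure described in this list can be sketched in standalone C++17 as follows. This is an approximation under stated assumptions: `sketch::constant` stands in for `::cuda::__argument::__constant`, `sketch::static_discrete_param` for `params::static_discrete_param`, and the diagnostic text is paraphrased rather than copied from the PR. The actual signatures in `dispatch_batched_topk.cuh` may differ.

```cpp
#include <type_traits>

namespace sketch
{
// Hypothetical stand-in for cub::detail::topk::select.
enum class select { min, max };

// Hypothetical stand-in for ::cuda::__argument::__constant.
template <auto Value>
struct constant
{};

// Hypothetical stand-in for params::static_discrete_param.
template <typename T, T Value>
struct static_discrete_param
{
  static constexpr T value = Value;
};

// Always false, but dependent on T, so the static_assert below only fires
// when the catch-all overload is actually instantiated.
template <typename>
inline constexpr bool always_false_v = false;

// Accepted form: a compile-time constant direction.
template <select Dir>
constexpr auto wrap_select_direction(constant<Dir>)
{
  return static_discrete_param<select, Dir>{};
}

// Catch-all: any other argument type yields one readable diagnostic
// instead of an overload-resolution failure deep inside dispatch.
template <typename T>
constexpr void wrap_select_direction(T)
{
  static_assert(always_false_v<T>,
                "the batched top-k selection direction must be a compile-time constant");
}

// wrap_select_direction(select::max); // would fail to compile with the message above
} // namespace sketch
```

Partial ordering prefers the `constant<Dir>` overload whenever it matches, so the catch-all is selected only for arguments that are not compile-time constants.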

Tests

  • cub/test/catch2_test_device_segmented_topk_keys.cu and cub/test/catch2_test_device_segmented_topk_pairs.cu:
    • Converted selection direction to a compile-time test axis (select_direction_list covering min and max).
    • Replaced runtime GENERATE_COPY(...) direction generation with constexpr directions derived from the compile-time test type.
    • Updated batched_topk_keys() / batched_topk_pairs() invocations to pass ::cuda::__argument::__constant immediates instead of runtime arguments.
    • Regression tests updated to use compile-time direction immediates (e.g., __constant<cub::detail::topk::select::min>{}).
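To make the "compile-time test axis" idea concrete, here is a small standalone C++17 sketch. It deliberately does not use Catch2 or the CUB test helpers: `sketch::direction_list`, `test_case`, and `count_passing` are hypothetical names, and only `select_direction_list` echoes a name from the summary. The point is that each direction becomes a separate template instantiation with a `constexpr` direction, rather than a value generated at run time.

```cpp
namespace sketch
{
// Hypothetical stand-in for cub::detail::topk::select.
enum class select { min, max };

// Hypothetical compile-time axis: the directions to instantiate tests for.
template <select... Dirs>
struct direction_list
{};

using select_direction_list = direction_list<select::min, select::max>;

// Hypothetical test body, instantiated once per direction.
template <select Dir>
constexpr bool test_case()
{
  constexpr select direction = Dir; // constexpr, derived from the test type
  constexpr int keys[]       = {3, 1, 4, 1, 5};

  int best = keys[0];
  for (int key : keys)
  {
    const bool better = (direction == select::max) ? (key > best) : (key < best);
    if (better)
    {
      best = key;
    }
  }
  const int expected = (direction == select::max) ? 5 : 1;
  return best == expected;
}

// Runs the test body for every direction on the axis and counts the passes.
template <select... Dirs>
constexpr int count_passing(direction_list<Dirs...>)
{
  return (0 + ... + static_cast<int>(test_case<Dirs>()));
}
} // namespace sketch
```

Adding a direction to the list adds an instantiation; there is no runtime value that could be passed to a path compiled for a different direction.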

Review / Rationale / Notes

  • Reviewer-requested safety checks: reviewer suggested adding negative compile-time tests to prevent reintroduction of a runtime detail::topk::select overload. The author added an explicit static_assert rejection in the wrap_select_direction catch-all to document and enforce compile-time-only usage and initially added a compile-fail test, but later removed the dedicated compile-fail TU as redundant because the DeviceBatchedTopK entry points only invoke dispatch with ::cuda::__argument::__constant.
  • A nit recommended a non-__constant rejection path to provide a clear diagnostic; the PR implemented an explicit rejection (static_assert) so callers who attempt to pass a runtime detail::topk::select receive a clear, documented diagnostic rather than template noise.
  • Public API surface: entry points expose direction via separate overloads (DeviceBatchedTopK::{Min,Max}{Keys,Pairs}), so there is no user-facing path to pass a runtime detail::topk::select.

Impact

  • Enforces compile-time-only top-k selection direction, removing a class of runtime mismatch footguns.
  • Tests and internal dispatch are updated to reflect compile-time direction selection.
  • Slight API simplification: runtime/per-segment direction forms are removed for this dispatch path; callers should use the provided compile-time overloads.

Walkthrough

important: Discrete dispatch types now use a compile-time static_discrete_param; batched-topk dispatch accepts only compile-time __constant<Dir>. Tests updated to treat selection direction as a compile-time axis and pass __constant<direction> to batched-topk calls.

Changes

important:

Discrete parameters and top-k direction dispatch

| Layer / File(s) | Summary |
| --- | --- |
| Static discrete parameter type definition (cub/cub/detail/segmented_params.cuh) | static_discrete_param<T, Value> introduced, replacing uniform_discrete_param and per_segment_discrete_param; get_param returns compile-time Value. Dispatcher docs updated to describe generic discrete parameter resolution. |
| Dispatch wrapper integration (cub/cub/device/dispatch/dispatch_batched_topk.cuh) | wrap_select_direction restricted to ::cuda::__argument::__constant<Dir> and maps to params::static_discrete_param; removed runtime and per-segment overloads and added catch-all static_assert overload. |
| Top-k keys test compile-time direction (cub/test/catch2_test_device_segmented_topk_keys.cu) | Added select_direction_list compile-time axis; tests use constexpr direction from test type and pass ::cuda::__argument::__constant<direction>{} to batched_topk_keys. Regression test updated similarly. |
| Top-k pairs test compile-time direction (cub/test/catch2_test_device_segmented_topk_pairs.cu) | Fixed-size and variable-size segment tests add select_direction_list axis; direction becomes constexpr from test type and is passed as ::cuda::__argument::__constant<direction> to batched_topk_pairs. |

suggestion:

Possibly related PRs

  • NVIDIA/cccl#9074: Foundation refactor migrating discrete dispatch to cuda::argument/compile-time parameter model that this PR completes.

suggestion:

Suggested reviewers

  • wmaxey
  • pauleonix
  • shwina


Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (2)
cub/cub/device/dispatch/dispatch_batched_topk.cuh (1)

52-59: ⚡ Quick win

suggestion: Add an explicit non-__constant rejection path here. The new contract is enforced only by missing-overload resolution in wrap_select_direction, so a caller that still passes a runtime detail::topk::select gets template spew from inside dispatch instead of a direct diagnostic. A deleted fallback overload or a targeted static_assert would keep the compile-time-only behavior but make the failure mode obvious.

cub/test/catch2_test_device_segmented_topk_keys.cu (1)

88-90: ⚡ Quick win

suggestion: Add one negative compile-time check for runtime directions. These cases only cover __constant<direction>, so accidentally reintroducing an overload that accepts a runtime detail::topk::select would still keep this suite green even though it reopens the footgun this PR is removing. A small static_assert(!requires { ... }) or a compile-fail test would lock the new contract in.

Also applies to: 108-109, 161-162, 190-191, 260-260, 298-298
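One way the suggested negative compile-time check could look is sketched below in standalone C++17, using the detection idiom instead of a C++20 `requires` expression. It assumes the design in which only the `__constant` overload exists and nothing else is declared; `sketch::constant`, `sketch::select`, and `is_wrappable` are hypothetical stand-ins, and the `int` return type is a simplification.

```cpp
#include <type_traits>
#include <utility>

namespace sketch
{
// Hypothetical stand-in for cub::detail::topk::select.
enum class select { min, max };

// Hypothetical stand-in for ::cuda::__argument::__constant.
template <auto Value>
struct constant
{};

// Only the compile-time form is declared; there is no runtime overload.
template <select Dir>
constexpr int wrap_select_direction(constant<Dir>)
{
  return static_cast<int>(Dir);
}

// Detection idiom: true when wrap_select_direction(T) is a well-formed call.
template <typename T, typename = void>
struct is_wrappable : std::false_type
{};

template <typename T>
struct is_wrappable<T, std::void_t<decltype(wrap_select_direction(std::declval<T>()))>>
    : std::true_type
{};

// Positive case: the compile-time direction is accepted.
static_assert(is_wrappable<constant<select::max>>::value,
              "compile-time direction must be accepted");
// Negative case: reintroducing a runtime overload would flip this assertion.
static_assert(!is_wrappable<select>::value, "runtime direction must be rejected");
} // namespace sketch
```

Note that this style of check relies on the runtime call being ill-formed during overload resolution. A catch-all overload that rejects via a `static_assert` in its body is still a viable candidate, so it would make `is_wrappable<select>` report `true`; that design needs a compile-fail test instead.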


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 935c0385-fbfe-4f24-a854-038054b3bd7e

📥 Commits

Reviewing files that changed from the base of the PR and between cea7dcd and 87fdfe2.

📒 Files selected for processing (4)
  • cub/cub/detail/segmented_params.cuh
  • cub/cub/device/dispatch/dispatch_batched_topk.cuh
  • cub/test/catch2_test_device_segmented_topk_keys.cu
  • cub/test/catch2_test_device_segmented_topk_pairs.cu

elstehle added 2 commits June 7, 2026 04:35
Address review feedback:
  * wrap_select_direction now emits a clear static_assert ("the batched
    top-k selection direction must be a compile-time option: pass
    ::cuda::__argument::__constant<...>") for any non-__constant direction,
    instead of relying on overload-resolution failure. This gives a direct,
    self-explaining diagnostic at the call site.
  * Add a CUB compile-fail test (test/test_device_segmented_topk_direction_fail.cu)
    that passes a runtime direction and asserts, via the // expected-error
    machinery, that compilation fails with that message -- locking in the
    contract so reintroducing a runtime-direction overload fails the build.
dispatch is only ever invoked with a selection direction we construct
internally: the min/max device entry points
(DeviceBatchedTopK::{Max,Min}{Keys,Pairs}) create the
::cuda::__argument::__constant<Dir>, so there is no user-facing path that
could pass a runtime direction to wrap_select_direction. The dedicated
compile-fail test is therefore redundant and is removed.

The wrap_select_direction catch-all overload is kept, with its static_assert
reworded to document the intentional, current limitation to compile-time
constant directions selected by the {Max,Min}{Keys,Pairs} interfaces.
@elstehle
Contributor Author

elstehle commented Jun 7, 2026

cub/test/catch2_test_device_segmented_topk_keys.cu (1)

88-90: ⚡ Quick win

suggestion: Add one negative compile-time check for runtime directions. These cases only cover __constant<direction>, so accidentally reintroducing an overload that accepts a runtime detail::topk::select would still keep this suite green even though it reopens the footgun this PR is removing. A small static_assert(!requires { ... }) or a compile-fail test would lock the new contract in.
Also applies to: 108-109, 161-162, 190-191, 260-260, 298-298

We plan to expose the direction via separate overloads via DeviceBatchedTopK::{Min,Max}{Keys,Pairs}. There is neither an existing nor a planned user-facing path that can hand a runtime detail::topk::select to dispatch, so a dedicated fail-to-compile translation unit would only test that our own internal code keeps doing the right thing. That is not really a value add.

🧹 Nitpick comments (2)

cub/cub/device/dispatch/dispatch_batched_topk.cuh (1)

52-59: ⚡ Quick win

suggestion: Add an explicit non-__constant rejection path here. The new contract is enforced only by missing-overload resolution in wrap_select_direction, so a caller that still passes a runtime detail::topk::select gets template spew from inside dispatch instead of a direct diagnostic. A deleted fallback overload or a targeted static_assert would keep the compile-time-only behavior but make the failure mode obvious.

Done. It is more for documentation purposes, though, for the reason mentioned above:

We plan to expose the direction via separate overloads via DeviceBatchedTopK::{Min,Max}{Keys,Pairs}.

@github-actions
Contributor

github-actions Bot commented Jun 7, 2026

🥳 CI Workflow Results

🟩 Finished in 2h 57m: Pass: 100%/284 | Total: 2d 09h | Max: 39m 36s | Hits: 93%/209094

See results here.

@elstehle elstehle changed the title from "makes the batched top-k selection direction compile-time only" to "[CUB] Makes the batched top-k selection direction compile-time only" Jun 7, 2026

Labels

None yet

Projects

Status: In Review


1 participant