Skip to content

Commit 454f3e0

Browse files
authored
Add block-level CUB JIT support (#1180)
* Add block-level CUB JIT support Use CUB Block* APIs for JIT-able reductions, scans, sorts, and argsorts, including reduced-rank block kernels and fusion-safe scalar handling for surrounding MatX operators. Query CUB temporary storage by compiling an NVRTC probe to PTX and reading the global initializer, then feed that static shared-memory usage into launch capability negotiation. Cap CUB JIT elements-per-thread to powers of two whose total item width is at most 16 bytes, and add coverage for sum/prod/min/max/cumsum/sort/argsort plus fused expressions with fftshift, fft, and linspace-generated inputs. * Address CUB JIT review feedback Use explicit SUM_QUERY identities for shared-memory aggregation, make non-JIT CUB shared-memory queries fail loudly, and make the NVRTC temp-storage probe robust to signed or unsigned PTX declarations. Also align generated binary JIT capability routing with direct block-reduction detection so fused CUB reductions keep block-level thread participation while non-reduction operands use scalar indexing. * Fix SetOp CUB block-reduction evaluation Evaluate direct block-reduction RHS operators with the active CapType so CUB sees the negotiated EPT and block size. Keep ScalarCap only for the output write, where the block aggregate is stored by a single thread.
1 parent 9f09387 commit 454f3e0

26 files changed

Lines changed: 3718 additions & 534 deletions

CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ endif()
339339

340340
# Enable JIT compilation support
341341
if (MATX_EN_JIT OR MATX_EN_MATHDX)
342+
include(GNUInstallDirs)
342343
message(STATUS "Enabling JIT compilation support via NVRTC")
343344
target_compile_definitions(matx INTERFACE MATX_EN_JIT)
344345

@@ -349,6 +350,17 @@ if (MATX_EN_JIT OR MATX_EN_MATHDX)
349350
string(REGEX REPLACE "-virtual$" "" NVRTC_CUDA_ARCH "${NVRTC_CUDA_ARCH}")
350351
target_compile_definitions(matx INTERFACE NVRTC_CUDA_ARCH="${NVRTC_CUDA_ARCH}")
351352
target_compile_definitions(matx INTERFACE NVRTC_CXX_STANDARD="${CMAKE_CXX_STANDARD}")
353+
target_compile_definitions(matx INTERFACE "$<BUILD_INTERFACE:MATX_NVRTC_BUILD_DIR_DEFAULT=\"${CMAKE_BINARY_DIR}\">")
354+
set(MATX_NVRTC_CCCL_BUILD_INCLUDE_DIRS
355+
"${CCCL_SOURCE_DIR}/thrust"
356+
"${CCCL_SOURCE_DIR}/libcudacxx/include"
357+
"${CCCL_SOURCE_DIR}/cub")
358+
list(JOIN MATX_NVRTC_CCCL_BUILD_INCLUDE_DIRS "|" MATX_NVRTC_CCCL_BUILD_INCLUDE_DIRS_DEFAULT)
359+
set(MATX_NVRTC_CCCL_INSTALL_INCLUDE_DIRS_DEFAULT
360+
"$<JOIN:$<TARGET_PROPERTY:CCCL::CCCL,INTERFACE_INCLUDE_DIRECTORIES>,|>")
361+
target_compile_definitions(matx INTERFACE
362+
"$<BUILD_INTERFACE:MATX_NVRTC_CCCL_INCLUDE_DIRS_DEFAULT=\"${MATX_NVRTC_CCCL_BUILD_INCLUDE_DIRS_DEFAULT}\">"
363+
"$<INSTALL_INTERFACE:MATX_NVRTC_CCCL_INCLUDE_DIRS_DEFAULT=\"${MATX_NVRTC_CCCL_INSTALL_INCLUDE_DIRS_DEFAULT}\">")
352364

353365
# Link NVRTC library
354366
target_link_libraries(matx INTERFACE CUDA::nvrtc)

include/matx/core/capabilities.h

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace detail {
5151
struct LTOIRQueryInput {
5252
std::set<std::string> ltoir_symbols;
5353
ElementsPerThread ept;
54-
};
54+
};
5555

5656
// Enum for different operator capabilities
5757
enum class OperatorCapability {
@@ -61,6 +61,7 @@ namespace detail {
6161
SET_ELEMENTS_PER_THREAD, // Set the elements per thread for the operator.
6262
JIT_CLASS_QUERY, // Result is the concatenation of the capabilities of the operator and its children.
6363
DYN_SHM_SIZE, // Result is the dynamic shared memory size required for the operator.
64+
STATIC_SHM_SIZE, // Result is the static shared memory size required for the operator.
6465
BLOCK_DIM, // Result is the block dimensions required for the operator.
6566
GENERATE_LTOIR, // Generate LTOIR code for the operator.
6667
JIT_TYPE_QUERY, // Result is the type of JIT code to generate for the operator.
@@ -72,6 +73,7 @@ namespace detail {
7273
ALIASED_MEMORY, // Whether the operator's input and output pointers alias
7374
GLOBAL_KERNEL, // Kernel operates entirely on a global level per chunk of data. False when at least one operator works on a block level
7475
PASS_THROUGH_THREADS, // All threads must call operator() on nested operators; bounds checking done at tensor level
76+
BLOCK_REDUCES_RANK, // Block-level operator's critical dimension is not part of the output rank
7577
UNIT_STRIDE_LAST, // Whether all leaf tensors have stride[RANK-1] == 1
7678
// Add more capabilities as needed
7779
};
@@ -84,10 +86,11 @@ namespace detail {
8486
// The operator itself AND its children.
8587
MIN_QUERY, // Result is the minimum of the capabilities of the operator and its children.
8688
MAX_QUERY, // Result is the maximum of the capabilities of the operator and its children.
89+
SUM_QUERY, // Result is the sum of the capabilities of the operator and its children.
8790
STR_CAT_QUERY, // Result is the concatenation of the capabilities of the operator and its children.
8891
RANGE_QUERY, // Result is the range of the capabilities of the operator and its children.
8992
};
90-
93+
9194

9295
#if !defined(__CUDACC_RTC__)
9396
template <ElementsPerThread EPT, bool JIT, bool UNIT_STRIDE_LAST = false>
@@ -97,15 +100,17 @@ namespace detail {
97100
static constexpr bool unit_stride_last = UNIT_STRIDE_LAST;
98101
static constexpr int osize = 0;
99102
static constexpr int block_size = 0;
103+
static constexpr bool pass_through_threads = false;
104+
using scalar_cap = CapabilityParams<ElementsPerThread::ONE, JIT, UNIT_STRIDE_LAST>;
100105

101106
// For JIT there will be other capabilties patched in with a string
102107
};
103108

104-
using DefaultCapabilities = CapabilityParams<ElementsPerThread::ONE, false, false>;
105-
109+
using DefaultCapabilities = CapabilityParams<ElementsPerThread::ONE, false, false>;
110+
106111
// Concept to detect scoped enums
107112
template<typename T>
108-
concept is_scoped_enum_c = cuda::std::is_enum_v<T> &&
113+
concept is_scoped_enum_c = cuda::std::is_enum_v<T> &&
109114
!cuda::std::is_convertible_v<T, cuda::std::underlying_type_t<T>>;
110115

111116
// Legacy struct for backwards compatibility
@@ -139,7 +144,7 @@ namespace detail {
139144
static constexpr bool default_value = true;
140145
static constexpr bool or_identity = false;
141146
static constexpr bool and_identity = true;
142-
};
147+
};
143148

144149
template <>
145150
struct capability_attributes<OperatorCapability::ASYNC_LOADS_REQUESTED> {
@@ -148,16 +153,16 @@ namespace detail {
148153
static constexpr bool default_value = false;
149154
static constexpr bool or_identity = false;
150155
static constexpr bool and_identity = true;
151-
};
152-
156+
};
157+
153158
template <>
154159
struct capability_attributes<OperatorCapability::GLOBAL_KERNEL> {
155160
using type = bool;
156161
using input_type = VoidCapabilityType;
157162
static constexpr bool default_value = true;
158163
static constexpr bool or_identity = false;
159164
static constexpr bool and_identity = true;
160-
};
165+
};
161166

162167
template <>
163168
struct capability_attributes<OperatorCapability::ALIASED_MEMORY> {
@@ -166,7 +171,7 @@ namespace detail {
166171
static constexpr bool default_value = false;
167172
static constexpr bool or_identity = false;
168173
static constexpr bool and_identity = true;
169-
};
174+
};
170175

171176
template <>
172177
struct capability_attributes<OperatorCapability::GROUPS_PER_BLOCK> {
@@ -176,7 +181,7 @@ namespace detail {
176181
static constexpr cuda::std::array<int, 2> default_value = {1, 32}; // Example: 1 element per thread by default
177182
static constexpr cuda::std::array<int, 2> min_identity = {32, 1};
178183
static constexpr cuda::std::array<int, 2> max_identity = {1, 32};
179-
};
184+
};
180185

181186
template <>
182187
struct capability_attributes<OperatorCapability::BLOCK_DIM> {
@@ -186,7 +191,7 @@ namespace detail {
186191
static constexpr cuda::std::array<int, 2> default_value = {1, 1024}; // Example: 1 element per thread by default
187192
static constexpr cuda::std::array<int, 2> min_identity = {1024, 1};
188193
static constexpr cuda::std::array<int, 2> max_identity = {1, 1024};
189-
};
194+
};
190195

191196
template <>
192197
struct capability_attributes<OperatorCapability::SET_ELEMENTS_PER_THREAD> {
@@ -195,7 +200,7 @@ namespace detail {
195200
static constexpr bool default_value = true;
196201
static constexpr bool or_identity = false;
197202
static constexpr bool and_identity = true;
198-
};
203+
};
199204

200205
template <>
201206
struct capability_attributes<OperatorCapability::SET_GROUPS_PER_BLOCK> {
@@ -204,7 +209,7 @@ namespace detail {
204209
static constexpr bool default_value = true;
205210
static constexpr bool or_identity = false;
206211
static constexpr bool and_identity = true;
207-
};
212+
};
208213

209214
template <>
210215
struct capability_attributes<OperatorCapability::ELEMENTS_PER_THREAD> {
@@ -223,15 +228,15 @@ namespace detail {
223228
static constexpr bool default_value = true;
224229
static constexpr bool or_identity = false;
225230
static constexpr bool and_identity = true;
226-
};
231+
};
227232

228233
template <>
229234
struct capability_attributes<OperatorCapability::JIT_TYPE_QUERY> {
230235
using type = std::string;
231236
using input_type = VoidCapabilityType;
232237
static inline const std::string default_value = "";
233238
static inline const std::string min_identity = "";
234-
};
239+
};
235240

236241
template <>
237242
struct capability_attributes<OperatorCapability::DYN_SHM_SIZE> {
@@ -240,7 +245,18 @@ namespace detail {
240245
static constexpr int default_value = 0;
241246
static constexpr int min_identity = cuda::std::numeric_limits<int>::max();
242247
static constexpr int max_identity = 0;
243-
};
248+
static constexpr int sum_identity = 0;
249+
};
250+
251+
template <>
252+
struct capability_attributes<OperatorCapability::STATIC_SHM_SIZE> {
253+
using type = int;
254+
using input_type = VoidCapabilityType;
255+
static constexpr int default_value = 0;
256+
static constexpr int min_identity = cuda::std::numeric_limits<int>::max();
257+
static constexpr int max_identity = 0;
258+
static constexpr int sum_identity = 0;
259+
};
244260

245261
template <>
246262
struct capability_attributes<OperatorCapability::MAX_EPT_VEC_LOAD> {
@@ -249,6 +265,7 @@ namespace detail {
249265
static constexpr int default_value = 32;
250266
static constexpr int min_identity = 32;
251267
static constexpr int max_identity = 1;
268+
static constexpr int sum_identity = 0;
252269
};
253270

254271
template <>
@@ -260,6 +277,15 @@ namespace detail {
260277
static constexpr bool and_identity = true;
261278
};
262279

280+
template <>
281+
struct capability_attributes<OperatorCapability::BLOCK_REDUCES_RANK> {
282+
using type = bool;
283+
using input_type = VoidCapabilityType;
284+
static constexpr bool default_value = false;
285+
static constexpr bool or_identity = false;
286+
static constexpr bool and_identity = true;
287+
};
288+
263289
template <>
264290
struct capability_attributes<OperatorCapability::UNIT_STRIDE_LAST> {
265291
using type = bool;
@@ -270,7 +296,7 @@ namespace detail {
270296
static constexpr bool default_value = true;
271297
static constexpr bool or_identity = false;
272298
static constexpr bool and_identity = true;
273-
};
299+
};
274300

275301

276302
template <OperatorCapability Cap, typename OperatorType, typename InType>
@@ -292,7 +318,7 @@ namespace detail {
292318
return capability_attributes<Cap>::default_value;
293319
}
294320
}
295-
}
321+
}
296322

297323
// Helper to safely get capability from an operator.
298324
// OperandType is likely base_type_t<ActualOpType> or a raw scalar/functor type.
@@ -301,7 +327,7 @@ namespace detail {
301327
get_operator_capability(const OperatorType& op) {
302328
VoidCapabilityType void_type{};
303329
return get_operator_capability<Cap>(op, void_type);
304-
}
330+
}
305331

306332

307333
// Helper function to get the query type associated with a capability
@@ -332,17 +358,21 @@ namespace detail {
332358
return CapabilityQueryType::STR_CAT_QUERY; // The expression should use the concatenation of the capabilities of its children.
333359
case OperatorCapability::DYN_SHM_SIZE:
334360
return CapabilityQueryType::MAX_QUERY; // The expression should use the maximum dynamic shared memory size of its children.
361+
case OperatorCapability::STATIC_SHM_SIZE:
362+
return CapabilityQueryType::SUM_QUERY; // Static shared memory declarations are additive in fused kernels.
335363
case OperatorCapability::BLOCK_DIM:
336364
return CapabilityQueryType::RANGE_QUERY; // The expression should use the minimum block size supported by all operators.
337365
case OperatorCapability::GENERATE_LTOIR:
338366
return CapabilityQueryType::AND_QUERY; // The expression should generate LTOIR code if all its children generate it.
339367
case OperatorCapability::PASS_THROUGH_THREADS:
340368
return CapabilityQueryType::OR_QUERY; // If ANY operator needs pass-through, all threads must call operator()
369+
case OperatorCapability::BLOCK_REDUCES_RANK:
370+
return CapabilityQueryType::OR_QUERY; // If ANY operator reduces rank, use the reduced-rank block kernel.
341371
case OperatorCapability::UNIT_STRIDE_LAST:
342372
return CapabilityQueryType::AND_QUERY; // All leaf tensors must have stride[RANK-1] == 1
343373
default:
344374
// Default to OR_QUERY or handle as an error/assertion if a capability isn't mapped.
345-
return CapabilityQueryType::OR_QUERY;
375+
return CapabilityQueryType::OR_QUERY;
346376
}
347377
}
348378

@@ -375,6 +405,8 @@ namespace detail {
375405
children_aggregated_val = capability_attributes<Cap>::min_identity;
376406
} else if (query_type == CapabilityQueryType::MAX_QUERY) {
377407
children_aggregated_val = capability_attributes<Cap>::max_identity;
408+
} else if (query_type == CapabilityQueryType::SUM_QUERY) {
409+
children_aggregated_val = capability_attributes<Cap>::sum_identity;
378410
} else {
379411
// Default identity for int if not MIN_QUERY or MAX_QUERY (e.g. if it was SUM_QUERY, identity would be 0)
380412
// This path needs clear definition if other query types are used for int.
@@ -391,7 +423,7 @@ namespace detail {
391423
if constexpr (std::is_same_v<CapType, bool>) {
392424
if (query_type == CapabilityQueryType::OR_QUERY) {
393425
children_aggregated_val = capability_attributes<Cap>::or_identity;
394-
((children_aggregated_val = children_aggregated_val || child_vals), ...);
426+
((children_aggregated_val = children_aggregated_val || child_vals), ...);
395427
} else { // AND_QUERY
396428
children_aggregated_val = capability_attributes<Cap>::and_identity;
397429
((children_aggregated_val = children_aggregated_val && child_vals), ...);
@@ -411,6 +443,9 @@ namespace detail {
411443
for (CapType val : values) {
412444
children_aggregated_val = static_cast<CapType>(cuda::std::max(static_cast<int>(children_aggregated_val), static_cast<int>(val)));
413445
}
446+
} else if (query_type == CapabilityQueryType::SUM_QUERY) {
447+
children_aggregated_val = capability_attributes<Cap>::sum_identity;
448+
((children_aggregated_val += child_vals), ...);
414449
} else {
415450
// Not implemented for other query types.
416451
MATX_ASSERT_STR(false, matxInvalidParameter, "Not implemented for other query types.");
@@ -430,14 +465,14 @@ namespace detail {
430465
auto it = values.begin();
431466
children_aggregated_val = *it;
432467
++it;
433-
468+
434469
// Apply range intersection logic for remaining children
435470
for (; it != values.end(); ++it) {
436471
const auto& child_range = *it;
437472
// Minimum is the maximum of the two range's minimums
438473
// Maximum is the minimum of the two range's maximums
439474
// Check that the maximum (second element) is not smaller than the minimum on the other value
440-
if (static_cast<int>(child_range[1]) < static_cast<int>(children_aggregated_val[0]) ||
475+
if (static_cast<int>(child_range[1]) < static_cast<int>(children_aggregated_val[0]) ||
441476
static_cast<int>(children_aggregated_val[1]) < static_cast<int>(child_range[0])) {
442477
// If the max of the new range is less than the min of the current, clamp to empty/invalid range
443478
children_aggregated_val[0] = capability_attributes<Cap>::invalid;
@@ -480,6 +515,8 @@ namespace detail {
480515
return static_cast<CapType>(cuda::std::min(static_cast<int>(self_val), static_cast<int>(children_aggregated_val)));
481516
} else if (query_type == CapabilityQueryType::MAX_QUERY) {
482517
return static_cast<CapType>(cuda::std::max(static_cast<int>(self_val), static_cast<int>(children_aggregated_val)));
518+
} else if (query_type == CapabilityQueryType::SUM_QUERY) {
519+
return static_cast<CapType>(self_val + children_aggregated_val);
483520
} else {
484521
MATX_ASSERT_STR(false, matxInvalidParameter, "Not implemented for other query types.");
485522
return self_val;
@@ -495,15 +532,15 @@ namespace detail {
495532
// Handle RANGE_QUERY for cuda::std::array<T, 2> types
496533
if (query_type == CapabilityQueryType::RANGE_QUERY) {
497534
CapType result = self_val;
498-
// Apply range intersection logic:
535+
// Apply range intersection logic:
499536
// Minimum is the maximum of the two range's minimums
500537
// Maximum is the minimum of the two range's maximums
501538
// Check that the maximum (second element) is not smaller than the minimum on the other value
502-
if (static_cast<int>(children_aggregated_val[1]) < static_cast<int>(self_val[0]) ||
539+
if (static_cast<int>(children_aggregated_val[1]) < static_cast<int>(self_val[0]) ||
503540
static_cast<int>(self_val[1]) < static_cast<int>(children_aggregated_val[0])) {
504541
// If the max of the new range is less than the min of the current, clamp to empty/invalid range
505542
result[0] = capability_attributes<Cap>::invalid;
506-
result[1] = capability_attributes<Cap>::invalid;
543+
result[1] = capability_attributes<Cap>::invalid;
507544
}
508545
else {
509546
result[0] = static_cast<typename CapType::value_type>(
@@ -531,7 +568,7 @@ namespace detail {
531568
return cuda::std::apply([&in](const auto&... ops) {
532569
return combine_capabilities<Cap>(detail::get_operator_capability<Cap>(ops, in)...);
533570
}, ops_tuple);
534-
}
571+
}
535572

536573
#endif
537574

@@ -541,5 +578,5 @@ namespace detail {
541578
template <typename Op>
542579
__MATX_INLINE__ __MATX_HOST__ bool jit_supported(const Op &op) {
543580
return detail::get_operator_capability<detail::OperatorCapability::SUPPORTS_JIT>(op);
544-
}
545-
} // namespace matx
581+
}
582+
} // namespace matx

0 commit comments

Comments
 (0)