@@ -51,7 +51,7 @@ namespace detail {
5151 struct LTOIRQueryInput {
5252 std::set<std::string> ltoir_symbols;
5353 ElementsPerThread ept;
54- };
54+ };
5555
5656 // Enum for different operator capabilities
5757 enum class OperatorCapability {
@@ -61,6 +61,7 @@ namespace detail {
6161 SET_ELEMENTS_PER_THREAD, // Set the elements per thread for the operator.
6262 JIT_CLASS_QUERY, // Result is the concatenation of the capabilities of the operator and its children.
6363 DYN_SHM_SIZE, // Result is the dynamic shared memory size required for the operator.
64+ STATIC_SHM_SIZE, // Result is the static shared memory size required for the operator.
6465 BLOCK_DIM, // Result is the block dimensions required for the operator.
6566 GENERATE_LTOIR, // Generate LTOIR code for the operator.
6667 JIT_TYPE_QUERY, // Result is the type of JIT code to generate for the operator.
@@ -72,6 +73,7 @@ namespace detail {
7273 ALIASED_MEMORY, // Whether the operator's input and output pointers alias
7374 GLOBAL_KERNEL, // Kernel operates entirely on a global level per chunk of data. False when at least one operator works on a block level
7475 PASS_THROUGH_THREADS, // All threads must call operator() on nested operators; bounds checking done at tensor level
76+ BLOCK_REDUCES_RANK, // Block-level operator's critical dimension is not part of the output rank
7577 UNIT_STRIDE_LAST, // Whether all leaf tensors have stride[RANK-1] == 1
7678 // Add more capabilities as needed
7779 };
@@ -84,10 +86,11 @@ namespace detail {
8486 // The operator itself AND its children.
8587 MIN_QUERY, // Result is the minimum of the capabilities of the operator and its children.
8688 MAX_QUERY, // Result is the maximum of the capabilities of the operator and its children.
89+ SUM_QUERY, // Result is the sum of the capabilities of the operator and its children.
8790 STR_CAT_QUERY, // Result is the concatenation of the capabilities of the operator and its children.
8891 RANGE_QUERY, // Result is the range of the capabilities of the operator and its children.
8992 };
90-
93+
9194
9295#if !defined(__CUDACC_RTC__)
9396 template <ElementsPerThread EPT, bool JIT, bool UNIT_STRIDE_LAST = false >
@@ -97,15 +100,17 @@ namespace detail {
97100 static constexpr bool unit_stride_last = UNIT_STRIDE_LAST;
98101 static constexpr int osize = 0 ;
99102 static constexpr int block_size = 0 ;
103+ static constexpr bool pass_through_threads = false ;
104+ using scalar_cap = CapabilityParams<ElementsPerThread::ONE, JIT, UNIT_STRIDE_LAST>;
100105
101106 // For JIT there will be other capabilties patched in with a string
102107 };
103108
104- using DefaultCapabilities = CapabilityParams<ElementsPerThread::ONE, false , false >;
105-
109+ using DefaultCapabilities = CapabilityParams<ElementsPerThread::ONE, false , false >;
110+
106111 // Concept to detect scoped enums
107112 template <typename T>
108- concept is_scoped_enum_c = cuda::std::is_enum_v<T> &&
113+ concept is_scoped_enum_c = cuda::std::is_enum_v<T> &&
109114 !cuda::std::is_convertible_v<T, cuda::std::underlying_type_t <T>>;
110115
111116 // Legacy struct for backwards compatibility
@@ -139,7 +144,7 @@ namespace detail {
139144 static constexpr bool default_value = true ;
140145 static constexpr bool or_identity = false ;
141146 static constexpr bool and_identity = true ;
142- };
147+ };
143148
144149 template <>
145150 struct capability_attributes <OperatorCapability::ASYNC_LOADS_REQUESTED> {
@@ -148,16 +153,16 @@ namespace detail {
148153 static constexpr bool default_value = false ;
149154 static constexpr bool or_identity = false ;
150155 static constexpr bool and_identity = true ;
151- };
152-
156+ };
157+
153158 template <>
154159 struct capability_attributes <OperatorCapability::GLOBAL_KERNEL> {
155160 using type = bool ;
156161 using input_type = VoidCapabilityType;
157162 static constexpr bool default_value = true ;
158163 static constexpr bool or_identity = false ;
159164 static constexpr bool and_identity = true ;
160- };
165+ };
161166
162167 template <>
163168 struct capability_attributes <OperatorCapability::ALIASED_MEMORY> {
@@ -166,7 +171,7 @@ namespace detail {
166171 static constexpr bool default_value = false ;
167172 static constexpr bool or_identity = false ;
168173 static constexpr bool and_identity = true ;
169- };
174+ };
170175
171176 template <>
172177 struct capability_attributes <OperatorCapability::GROUPS_PER_BLOCK> {
@@ -176,7 +181,7 @@ namespace detail {
176181 static constexpr cuda::std::array<int , 2 > default_value = {1 , 32 }; // Example: 1 element per thread by default
177182 static constexpr cuda::std::array<int , 2 > min_identity = {32 , 1 };
178183 static constexpr cuda::std::array<int , 2 > max_identity = {1 , 32 };
179- };
184+ };
180185
181186 template <>
182187 struct capability_attributes <OperatorCapability::BLOCK_DIM> {
@@ -186,7 +191,7 @@ namespace detail {
186191 static constexpr cuda::std::array<int , 2 > default_value = {1 , 1024 }; // Example: 1 element per thread by default
187192 static constexpr cuda::std::array<int , 2 > min_identity = {1024 , 1 };
188193 static constexpr cuda::std::array<int , 2 > max_identity = {1 , 1024 };
189- };
194+ };
190195
191196 template <>
192197 struct capability_attributes <OperatorCapability::SET_ELEMENTS_PER_THREAD> {
@@ -195,7 +200,7 @@ namespace detail {
195200 static constexpr bool default_value = true ;
196201 static constexpr bool or_identity = false ;
197202 static constexpr bool and_identity = true ;
198- };
203+ };
199204
200205 template <>
201206 struct capability_attributes <OperatorCapability::SET_GROUPS_PER_BLOCK> {
@@ -204,7 +209,7 @@ namespace detail {
204209 static constexpr bool default_value = true ;
205210 static constexpr bool or_identity = false ;
206211 static constexpr bool and_identity = true ;
207- };
212+ };
208213
209214 template <>
210215 struct capability_attributes <OperatorCapability::ELEMENTS_PER_THREAD> {
@@ -223,15 +228,15 @@ namespace detail {
223228 static constexpr bool default_value = true ;
224229 static constexpr bool or_identity = false ;
225230 static constexpr bool and_identity = true ;
226- };
231+ };
227232
228233 template <>
229234 struct capability_attributes <OperatorCapability::JIT_TYPE_QUERY> {
230235 using type = std::string;
231236 using input_type = VoidCapabilityType;
232237 static inline const std::string default_value = " " ;
233238 static inline const std::string min_identity = " " ;
234- };
239+ };
235240
236241 template <>
237242 struct capability_attributes <OperatorCapability::DYN_SHM_SIZE> {
@@ -240,7 +245,18 @@ namespace detail {
240245 static constexpr int default_value = 0 ;
241246 static constexpr int min_identity = cuda::std::numeric_limits<int >::max();
242247 static constexpr int max_identity = 0 ;
243- };
248+ static constexpr int sum_identity = 0 ;
249+ };
250+
251+ template <>
252+ struct capability_attributes <OperatorCapability::STATIC_SHM_SIZE> {
253+ using type = int ;
254+ using input_type = VoidCapabilityType;
255+ static constexpr int default_value = 0 ;
256+ static constexpr int min_identity = cuda::std::numeric_limits<int >::max();
257+ static constexpr int max_identity = 0 ;
258+ static constexpr int sum_identity = 0 ;
259+ };
244260
245261 template <>
246262 struct capability_attributes <OperatorCapability::MAX_EPT_VEC_LOAD> {
@@ -249,6 +265,7 @@ namespace detail {
249265 static constexpr int default_value = 32 ;
250266 static constexpr int min_identity = 32 ;
251267 static constexpr int max_identity = 1 ;
268+ static constexpr int sum_identity = 0 ;
252269 };
253270
254271 template <>
@@ -260,6 +277,15 @@ namespace detail {
260277 static constexpr bool and_identity = true ;
261278 };
262279
280+ template <>
281+ struct capability_attributes <OperatorCapability::BLOCK_REDUCES_RANK> {
282+ using type = bool ;
283+ using input_type = VoidCapabilityType;
284+ static constexpr bool default_value = false ;
285+ static constexpr bool or_identity = false ;
286+ static constexpr bool and_identity = true ;
287+ };
288+
263289 template <>
264290 struct capability_attributes <OperatorCapability::UNIT_STRIDE_LAST> {
265291 using type = bool ;
@@ -270,7 +296,7 @@ namespace detail {
270296 static constexpr bool default_value = true ;
271297 static constexpr bool or_identity = false ;
272298 static constexpr bool and_identity = true ;
273- };
299+ };
274300
275301
276302 template <OperatorCapability Cap, typename OperatorType, typename InType>
@@ -292,7 +318,7 @@ namespace detail {
292318 return capability_attributes<Cap>::default_value;
293319 }
294320 }
295- }
321+ }
296322
297323 // Helper to safely get capability from an operator.
298324 // OperandType is likely base_type_t<ActualOpType> or a raw scalar/functor type.
@@ -301,7 +327,7 @@ namespace detail {
301327 get_operator_capability (const OperatorType& op) {
302328 VoidCapabilityType void_type{};
303329 return get_operator_capability<Cap>(op, void_type);
304- }
330+ }
305331
306332
307333 // Helper function to get the query type associated with a capability
@@ -332,17 +358,21 @@ namespace detail {
332358 return CapabilityQueryType::STR_CAT_QUERY; // The expression should use the concatenation of the capabilities of its children.
333359 case OperatorCapability::DYN_SHM_SIZE:
334360 return CapabilityQueryType::MAX_QUERY; // The expression should use the maximum dynamic shared memory size of its children.
361+ case OperatorCapability::STATIC_SHM_SIZE:
362+ return CapabilityQueryType::SUM_QUERY; // Static shared memory declarations are additive in fused kernels.
335363 case OperatorCapability::BLOCK_DIM:
336364 return CapabilityQueryType::RANGE_QUERY; // The expression should use the minimum block size supported by all operators.
337365 case OperatorCapability::GENERATE_LTOIR:
338366 return CapabilityQueryType::AND_QUERY; // The expression should generate LTOIR code if all its children generate it.
339367 case OperatorCapability::PASS_THROUGH_THREADS:
340368 return CapabilityQueryType::OR_QUERY; // If ANY operator needs pass-through, all threads must call operator()
369+ case OperatorCapability::BLOCK_REDUCES_RANK:
370+ return CapabilityQueryType::OR_QUERY; // If ANY operator reduces rank, use the reduced-rank block kernel.
341371 case OperatorCapability::UNIT_STRIDE_LAST:
342372 return CapabilityQueryType::AND_QUERY; // All leaf tensors must have stride[RANK-1] == 1
343373 default :
344374 // Default to OR_QUERY or handle as an error/assertion if a capability isn't mapped.
345- return CapabilityQueryType::OR_QUERY;
375+ return CapabilityQueryType::OR_QUERY;
346376 }
347377 }
348378
@@ -375,6 +405,8 @@ namespace detail {
375405 children_aggregated_val = capability_attributes<Cap>::min_identity;
376406 } else if (query_type == CapabilityQueryType::MAX_QUERY) {
377407 children_aggregated_val = capability_attributes<Cap>::max_identity;
408+ } else if (query_type == CapabilityQueryType::SUM_QUERY) {
409+ children_aggregated_val = capability_attributes<Cap>::sum_identity;
378410 } else {
379411 // Default identity for int if not MIN_QUERY or MAX_QUERY (e.g. if it was SUM_QUERY, identity would be 0)
380412 // This path needs clear definition if other query types are used for int.
@@ -391,7 +423,7 @@ namespace detail {
391423 if constexpr (std::is_same_v<CapType, bool >) {
392424 if (query_type == CapabilityQueryType::OR_QUERY) {
393425 children_aggregated_val = capability_attributes<Cap>::or_identity;
394- ((children_aggregated_val = children_aggregated_val || child_vals), ...);
426+ ((children_aggregated_val = children_aggregated_val || child_vals), ...);
395427 } else { // AND_QUERY
396428 children_aggregated_val = capability_attributes<Cap>::and_identity;
397429 ((children_aggregated_val = children_aggregated_val && child_vals), ...);
@@ -411,6 +443,9 @@ namespace detail {
411443 for (CapType val : values) {
412444 children_aggregated_val = static_cast <CapType>(cuda::std::max (static_cast <int >(children_aggregated_val), static_cast <int >(val)));
413445 }
446+ } else if (query_type == CapabilityQueryType::SUM_QUERY) {
447+ children_aggregated_val = capability_attributes<Cap>::sum_identity;
448+ ((children_aggregated_val += child_vals), ...);
414449 } else {
415450 // Not implemented for other query types.
416451 MATX_ASSERT_STR (false , matxInvalidParameter, " Not implemented for other query types." );
@@ -430,14 +465,14 @@ namespace detail {
430465 auto it = values.begin ();
431466 children_aggregated_val = *it;
432467 ++it;
433-
468+
434469 // Apply range intersection logic for remaining children
435470 for (; it != values.end (); ++it) {
436471 const auto & child_range = *it;
437472 // Minimum is the maximum of the two range's minimums
438473 // Maximum is the minimum of the two range's maximums
439474 // Check that the maximum (second element) is not smaller than the minimum on the other value
440- if (static_cast <int >(child_range[1 ]) < static_cast <int >(children_aggregated_val[0 ]) ||
475+ if (static_cast <int >(child_range[1 ]) < static_cast <int >(children_aggregated_val[0 ]) ||
441476 static_cast <int >(children_aggregated_val[1 ]) < static_cast <int >(child_range[0 ])) {
442477 // If the max of the new range is less than the min of the current, clamp to empty/invalid range
443478 children_aggregated_val[0 ] = capability_attributes<Cap>::invalid;
@@ -480,6 +515,8 @@ namespace detail {
480515 return static_cast <CapType>(cuda::std::min (static_cast <int >(self_val), static_cast <int >(children_aggregated_val)));
481516 } else if (query_type == CapabilityQueryType::MAX_QUERY) {
482517 return static_cast <CapType>(cuda::std::max (static_cast <int >(self_val), static_cast <int >(children_aggregated_val)));
518+ } else if (query_type == CapabilityQueryType::SUM_QUERY) {
519+ return static_cast <CapType>(self_val + children_aggregated_val);
483520 } else {
484521 MATX_ASSERT_STR (false , matxInvalidParameter, " Not implemented for other query types." );
485522 return self_val;
@@ -495,15 +532,15 @@ namespace detail {
495532 // Handle RANGE_QUERY for cuda::std::array<T, 2> types
496533 if (query_type == CapabilityQueryType::RANGE_QUERY) {
497534 CapType result = self_val;
498- // Apply range intersection logic:
535+ // Apply range intersection logic:
499536 // Minimum is the maximum of the two range's minimums
500537 // Maximum is the minimum of the two range's maximums
501538 // Check that the maximum (second element) is not smaller than the minimum on the other value
502- if (static_cast <int >(children_aggregated_val[1 ]) < static_cast <int >(self_val[0 ]) ||
539+ if (static_cast <int >(children_aggregated_val[1 ]) < static_cast <int >(self_val[0 ]) ||
503540 static_cast <int >(self_val[1 ]) < static_cast <int >(children_aggregated_val[0 ])) {
504541 // If the max of the new range is less than the min of the current, clamp to empty/invalid range
505542 result[0 ] = capability_attributes<Cap>::invalid;
506- result[1 ] = capability_attributes<Cap>::invalid;
543+ result[1 ] = capability_attributes<Cap>::invalid;
507544 }
508545 else {
509546 result[0 ] = static_cast <typename CapType::value_type>(
@@ -531,7 +568,7 @@ namespace detail {
531568 return cuda::std::apply ([&in](const auto &... ops) {
532569 return combine_capabilities<Cap>(detail::get_operator_capability<Cap>(ops, in)...);
533570 }, ops_tuple);
534- }
571+ }
535572
536573#endif
537574
@@ -541,5 +578,5 @@ namespace detail {
541578 template <typename Op>
542579 __MATX_INLINE__ __MATX_HOST__ bool jit_supported (const Op &op) {
543580 return detail::get_operator_capability<detail::OperatorCapability::SUPPORTS_JIT>(op);
544- }
545- } // namespace matx
581+ }
582+ } // namespace matx
0 commit comments