Skip to content

Commit d2f550c

Browse files
authored
Add a unit_stride_last capability (#1153)
* Add a unit_stride_last capability. Add a capability to indicate that a tensor has unit stride in the last dimension. When it is true, we can elide loading the last stride and multiplying by it. A unit-stride last dimension is the nominal case, since that is what make_tensor() creates when no user-provided strides are given. This capability approach applies to the matxOpT*Kernel dispatch; it does not apply to custom kernels in MatX. Because the optimization is applied at the kernel level, it activates only when all nodes in the expression tree have unit_stride_last. Signed-off-by: Thomas Benson <tbenson@nvidia.com>
1 parent 013c856 commit d2f550c

6 files changed

Lines changed: 143 additions & 85 deletions

File tree

include/matx/core/capabilities.h

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ namespace detail {
7272
ALIASED_MEMORY, // Whether the operator's input and output pointers alias
7373
GLOBAL_KERNEL, // Kernel operates entirely on a global level per chunk of data. False when at least one operator works on a block level
7474
PASS_THROUGH_THREADS, // All threads must call operator() on nested operators; bounds checking done at tensor level
75+
UNIT_STRIDE_LAST, // Whether all leaf tensors have stride[RANK-1] == 1
7576
// Add more capabilities as needed
7677
};
7778

@@ -89,17 +90,18 @@ namespace detail {
8990

9091

9192
#if !defined(__CUDACC_RTC__)
92-
template <ElementsPerThread EPT, bool JIT>
93+
template <ElementsPerThread EPT, bool JIT, bool UNIT_STRIDE_LAST = false>
9394
struct CapabilityParams {
9495
static constexpr ElementsPerThread ept = EPT;
9596
static constexpr bool jit = JIT;
97+
static constexpr bool unit_stride_last = UNIT_STRIDE_LAST;
9698
static constexpr int osize = 0;
9799
static constexpr int block_size = 0;
98100

99101
// For JIT there will be other capabilties patched in with a string
100-
};
102+
};
101103

102-
using DefaultCapabilities = CapabilityParams<ElementsPerThread::ONE, false>;
104+
using DefaultCapabilities = CapabilityParams<ElementsPerThread::ONE, false, false>;
103105

104106
// Concept to detect scoped enums
105107
template<typename T>
@@ -256,6 +258,18 @@ namespace detail {
256258
static constexpr bool default_value = false; // Default: operators do their own bounds checking
257259
static constexpr bool or_identity = false;
258260
static constexpr bool and_identity = true;
261+
};
262+
263+
template <>
264+
struct capability_attributes<OperatorCapability::UNIT_STRIDE_LAST> {
265+
using type = bool;
266+
using input_type = VoidCapabilityType;
267+
// Non-tensor ops (scalars, generators) are trivially unit-stride, so default those to true
268+
// since we will AND all unit stride capabilities in the expression tree. Tensor-like
269+
// types should handle the unit stride query in their get_capability() method.
270+
static constexpr bool default_value = true;
271+
static constexpr bool or_identity = false;
272+
static constexpr bool and_identity = true;
259273
};
260274

261275

@@ -324,6 +338,8 @@ namespace detail {
324338
return CapabilityQueryType::AND_QUERY; // The expression should generate LTOIR code if all its children generate it.
325339
case OperatorCapability::PASS_THROUGH_THREADS:
326340
return CapabilityQueryType::OR_QUERY; // If ANY operator needs pass-through, all threads must call operator()
341+
case OperatorCapability::UNIT_STRIDE_LAST:
342+
return CapabilityQueryType::AND_QUERY; // All leaf tensors must have stride[RANK-1] == 1
327343
default:
328344
// Default to OR_QUERY or handle as an error/assertion if a capability isn't mapped.
329345
return CapabilityQueryType::OR_QUERY;

include/matx/core/nvrtc_helper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ std::string generate_capability_params_string([[maybe_unused]] const Op &op, Ele
245245
"struct CapabilityParams {\n"
246246
" static constexpr ElementsPerThread ept = EPT;\n"
247247
" static constexpr bool jit = JIT;\n"
248+
// Note: no unit_stride_last here. JIT bakes strides as constexpr
249+
// values, so the compiler already eliminates multiply-by-1.
248250
" static constexpr int osize = " + std::to_string(osize) + ";\n"
249251
" static constexpr int block_size = " + std::to_string(block_size) + ";\n"
250252
" static constexpr bool pass_through_threads = " + pass_through_str + ";\n"

include/matx/core/tensor_impl.h

Lines changed: 75 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -154,71 +154,71 @@ class tensor_impl_t {
154154
" T *ldata_;\n" +
155155
" constexpr static cuda::std::array<index_t, " + std::to_string(Rank()) + "> strides_ = { " + detail::array_to_string(desc_.Strides()) + " };\n" +
156156
" constexpr static cuda::std::array<index_t, " + std::to_string(Rank()) + "> sizes_ = { " + detail::array_to_string(desc_.Shape()) + " };\n" +
157-
" template <detail::ElementsPerThread EPT, int I = 0, typename ...Is>\n" +
157+
" template <typename CapType, int I = 0, typename ...Is>\n" +
158158
" __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetVal([[maybe_unused]] cuda::std::tuple<Is...> tup) {\n" +
159159
" if constexpr (I < sizeof...(Is)) {\n" +
160-
" if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" +
161-
" return GetVal<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(EPT));\n" +
160+
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" +
161+
" return GetVal<CapType, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(CapType::ept));\n" +
162162
" }\n" +
163163
" else {\n" +
164-
" return GetVal<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n" +
164+
" return GetVal<CapType, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n" +
165165
" }\n" +
166166
" }\n" +
167167
" else {\n" +
168168
" return 0;\n" +
169169
" }\n" +
170170
" }\n" +
171-
" template <detail::ElementsPerThread EPT, int I = 0, typename ...Is>\n" +
171+
" template <typename CapType, int I = 0, typename ...Is>\n" +
172172
" __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetValC([[maybe_unused]] const cuda::std::tuple<Is...> tup) const {\n" +
173173
" if constexpr (I < sizeof...(Is)) {\n" +
174-
" if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" +
175-
" return GetValC<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(EPT));\n" +
174+
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n" +
175+
" return GetValC<CapType, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(CapType::ept));\n" +
176176
" }\n" +
177177
" else {\n" +
178-
" return GetValC<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n" +
178+
" return GetValC<CapType, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n" +
179179
" }\n" +
180180
" }\n" +
181181
" else {\n" +
182182
" return 0;\n" +
183183
" }\n" +
184-
" }\n" +
185-
" template <detail::ElementsPerThread EPT, typename... Is>\n" +
184+
" }\n" +
185+
" template <typename CapType, typename... Is>\n" +
186186
" __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetOffsetOptimized(Is... indices) const {\n" +
187187
" constexpr size_t rank = sizeof...(Is);\n" +
188-
" constexpr int EPT_int = static_cast<int>(EPT);\n" +
188+
" constexpr int EPT_int = static_cast<int>(CapType::ept);\n" +
189189
" const cuda::std::array<index_t, rank> idx{indices...};\n" +
190190
" \n" +
191191
" if constexpr (rank == 1) {\n" +
192-
" if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" +
192+
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" +
193193
" return idx[0] * (strides_[0] * EPT_int);\n" +
194194
" } else {\n" +
195195
" return idx[0] * strides_[0];\n" +
196196
" }\n" +
197197
" }\n" +
198198
" else if constexpr (rank == 2) {\n" +
199-
" if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" +
199+
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" +
200200
" return idx[0] * strides_[0] + idx[1] * (strides_[1] * EPT_int);\n" +
201201
" } else {\n" +
202202
" return idx[0] * strides_[0] + idx[1] * strides_[1];\n" +
203203
" }\n" +
204204
" }\n" +
205205
" else if constexpr (rank == 3) {\n" +
206-
" if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" +
206+
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" +
207207
" return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * (strides_[2] * EPT_int);\n" +
208208
" } else {\n" +
209209
" return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2];\n" +
210210
" }\n" +
211211
" }\n" +
212212
" else if constexpr (rank == 4) {\n" +
213-
" if constexpr (EPT != detail::ElementsPerThread::ONE) {\n" +
213+
" if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n" +
214214
" return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2] + idx[3] * (strides_[3] * EPT_int);\n" +
215215
" } else {\n" +
216216
" return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2] + idx[3] * strides_[3];\n" +
217217
" }\n" +
218218
" }\n" +
219219
" else {\n" +
220220
" // For rank > 4, fall back to the recursive implementation\n" +
221-
" return GetValC<EPT, 0, Is...>(cuda::std::make_tuple(indices...));\n" +
221+
" return GetValC<CapType, 0, Is...>(cuda::std::make_tuple(indices...));\n" +
222222
" }\n" +
223223
" }\n" +
224224
" template <typename CapType, int I = 0, typename... Is>\n" +
@@ -246,7 +246,7 @@ class tensor_impl_t {
246246
" return ReturnType{};\n" +
247247
" }\n" +
248248
" }\n" +
249-
" const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);\n" +
249+
" const index_t offset = GetOffsetOptimized<CapType>(indices...);\n" +
250250
" if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" +
251251
" return ldata_[offset];\n" +
252252
" } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {\n" +
@@ -272,7 +272,7 @@ class tensor_impl_t {
272272
" return dummy_;\n" +
273273
" }\n" +
274274
" }\n" +
275-
" const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);\n" +
275+
" const index_t offset = GetOffsetOptimized<CapType>(indices...);\n" +
276276
" if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" +
277277
" return ldata_[offset];\n" +
278278
" } else {\n" +
@@ -296,7 +296,7 @@ class tensor_impl_t {
296296
" template <typename CapType, int M = RANK, typename... Is>\n" +
297297
" __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* data_ptr(index_t block_idx, index_t ttl_threads) const noexcept\n" +
298298
" {\n"
299-
" //const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);\n" +
299+
" //const index_t offset = GetOffsetOptimized<CapType>(indices...);\n" +
300300
" //return ldata_ + offset;\n" +
301301
" return ldata_ + block_idx * ttl_threads * static_cast<index_t>(CapType::ept);\n" +
302302
" }\n" +
@@ -1107,7 +1107,7 @@ MATX_IGNORE_WARNING_POP_GCC
11071107
template <typename... Is>
11081108
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* GetPointer(Is... indices) const noexcept
11091109
{
1110-
return data_.ldata_ + GetOffsetOptimized<detail::ElementsPerThread::ONE>(indices...);
1110+
return data_.ldata_ + GetOffsetOptimized<detail::DefaultCapabilities>(indices...);
11111111
}
11121112

11131113
// Locates position of an element at given indices, or returns -1 when not
@@ -1204,76 +1204,80 @@ MATX_IGNORE_WARNING_POP_GCC
12041204
return desc_.IsContiguous();
12051205
}
12061206

1207-
template <typename detail::ElementsPerThread EPT, int I = 0, typename ...Is>
1207+
template <typename CapType, int I = 0, typename ...Is>
12081208
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetVal([[maybe_unused]] cuda::std::tuple<Is...> tup) {
12091209
if constexpr (I < sizeof...(Is)) {
12101210
MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
1211-
if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {
1212-
return GetVal<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(this->desc_.Stride(I) * static_cast<index_t>(EPT));
1213-
}
1214-
else {
1215-
return GetVal<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(this->desc_.Stride(I));
1216-
}
1211+
return GetVal<CapType, I+1, Is...>(tup) + DimStride<I, static_cast<int>(sizeof...(Is)), CapType>(cuda::std::get<I>(tup));
12171212
MATX_IGNORE_WARNING_POP_GCC
12181213
}
12191214
else {
12201215
return 0;
12211216
}
12221217
}
12231218

1224-
// Optimized offset calculation for ranks 1-4 with explicit stride multiplications
1225-
template <detail::ElementsPerThread EPT, typename... Is>
1219+
// Compute the stride contribution for a single dimension, eliding the
1220+
// load and multiply when the last dimension is known at dispatch time
1221+
// to have unit stride (via CapType::unit_stride_last).
1222+
template <int DIM, int RANK_VAL, typename CapType>
1223+
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type
1224+
DimStride(index_t idx_val) const {
1225+
constexpr bool is_last = (DIM == RANK_VAL - 1);
1226+
constexpr bool is_unit = CapType::unit_stride_last && is_last;
1227+
constexpr bool has_ept = (CapType::ept != detail::ElementsPerThread::ONE) && is_last;
1228+
1229+
if constexpr (is_unit && has_ept) {
1230+
return idx_val * static_cast<index_t>(CapType::ept);
1231+
} else if constexpr (is_unit) {
1232+
return idx_val;
1233+
} else if constexpr (has_ept) {
1234+
return idx_val * (this->desc_.Stride(DIM) * static_cast<index_t>(CapType::ept));
1235+
} else {
1236+
return idx_val * this->desc_.Stride(DIM);
1237+
}
1238+
}
1239+
1240+
// Optimized offset calculation for ranks 1-4 with explicit stride multiplications.
1241+
// When CapType::unit_stride_last is true, the stride load (ULDC) and
1242+
// multiply (IMAD) for the last dimension are elided.
1243+
template <typename CapType, typename... Is>
12261244
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetOffsetOptimized(Is... indices) const {
1227-
MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
1245+
MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
12281246
constexpr size_t rank = sizeof...(Is);
1229-
constexpr int EPT_int = static_cast<int>(EPT);
12301247
const cuda::std::array<index_t, rank> idx{indices...};
1231-
1248+
1249+
constexpr int R = static_cast<int>(rank);
1250+
12321251
if constexpr (rank == 1) {
1233-
if constexpr (EPT != detail::ElementsPerThread::ONE) {
1234-
return idx[0] * (this->desc_.Stride(0) * EPT_int);
1235-
} else {
1236-
return idx[0] * this->desc_.Stride(0);
1237-
}
1252+
return DimStride<0, R, CapType>(idx[0]);
12381253
}
12391254
else if constexpr (rank == 2) {
1240-
if constexpr (EPT != detail::ElementsPerThread::ONE) {
1241-
return idx[0] * this->desc_.Stride(0) + idx[1] * (this->desc_.Stride(1) * EPT_int);
1242-
} else {
1243-
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1);
1244-
}
1255+
return DimStride<0, R, CapType>(idx[0])
1256+
+ DimStride<1, R, CapType>(idx[1]);
12451257
}
12461258
else if constexpr (rank == 3) {
1247-
if constexpr (EPT != detail::ElementsPerThread::ONE) {
1248-
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * (this->desc_.Stride(2) * EPT_int);
1249-
} else {
1250-
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * this->desc_.Stride(2);
1251-
}
1259+
return DimStride<0, R, CapType>(idx[0])
1260+
+ DimStride<1, R, CapType>(idx[1])
1261+
+ DimStride<2, R, CapType>(idx[2]);
12521262
}
12531263
else if constexpr (rank == 4) {
1254-
if constexpr (EPT != detail::ElementsPerThread::ONE) {
1255-
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * this->desc_.Stride(2) + idx[3] * (this->desc_.Stride(3) * EPT_int);
1256-
} else {
1257-
return idx[0] * this->desc_.Stride(0) + idx[1] * this->desc_.Stride(1) + idx[2] * this->desc_.Stride(2) + idx[3] * this->desc_.Stride(3);
1258-
}
1264+
return DimStride<0, R, CapType>(idx[0])
1265+
+ DimStride<1, R, CapType>(idx[1])
1266+
+ DimStride<2, R, CapType>(idx[2])
1267+
+ DimStride<3, R, CapType>(idx[3]);
12591268
}
12601269
else {
12611270
// For rank > 4, fall back to the recursive implementation
1262-
return GetValC<EPT, 0, Is...>(cuda::std::make_tuple(indices...));
1271+
return GetValC<CapType, 0, Is...>(cuda::std::make_tuple(indices...));
12631272
}
1264-
MATX_IGNORE_WARNING_POP_GCC
1273+
MATX_IGNORE_WARNING_POP_GCC
12651274
}
12661275

1267-
template <detail::ElementsPerThread EPT, int I = 0, typename ...Is>
1276+
template <typename CapType, int I = 0, typename ...Is>
12681277
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetValC([[maybe_unused]] const cuda::std::tuple<Is...> tup) const {
12691278
if constexpr (I < sizeof...(Is)) {
12701279
MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
1271-
if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {
1272-
return GetValC<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(this->desc_.Stride(I) * static_cast<index_t>(EPT));
1273-
}
1274-
else {
1275-
return GetValC<EPT, I+1, Is...>(tup) + cuda::std::get<I>(tup)*(this->desc_.Stride(I));
1276-
}
1280+
return GetValC<CapType, I+1, Is...>(tup) + DimStride<I, static_cast<int>(sizeof...(Is)), CapType>(cuda::std::get<I>(tup));
12771281
MATX_IGNORE_WARNING_POP_GCC
12781282
}
12791283
else {
@@ -1299,7 +1303,7 @@ MATX_IGNORE_WARNING_POP_GCC
12991303
assert(data_.ldata_ != nullptr);
13001304
#endif
13011305
constexpr int EPT_int = static_cast<int>(CapType::ept);
1302-
const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);
1306+
const index_t offset = GetOffsetOptimized<CapType>(indices...);
13031307

13041308
if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
13051309
return data_.ldata_[offset];
@@ -1329,7 +1333,7 @@ MATX_IGNORE_WARNING_POP_GCC
13291333
{
13301334
static_assert(sizeof...(Is) == M, "Number of indices of data_ptr must match rank of tensor");
13311335
if constexpr (!is_sparse_data_v<TensorData>) {
1332-
const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);
1336+
const index_t offset = GetOffsetOptimized<CapType>(indices...);
13331337
return data_.ldata_ + offset;
13341338
}
13351339
else {
@@ -1363,7 +1367,7 @@ MATX_IGNORE_WARNING_POP_GCC
13631367
assert(data_.ldata_ != nullptr);
13641368
#endif
13651369
constexpr int EPT_int = static_cast<int>(CapType::ept);
1366-
const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);
1370+
const index_t offset = GetOffsetOptimized<CapType>(indices...);
13671371

13681372
if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
13691373
return data_.ldata_[offset];
@@ -1563,6 +1567,13 @@ MATX_IGNORE_WARNING_POP_GCC
15631567
return overlaps;
15641568
}
15651569
}
1570+
else if constexpr (Cap == OperatorCapability::UNIT_STRIDE_LAST) {
1571+
if constexpr (Rank() == 0) {
1572+
return true;
1573+
} else {
1574+
return (Stride(Rank() - 1) == 1);
1575+
}
1576+
}
15661577
else {
15671578
return detail::capability_attributes<Cap>::default_value;
15681579
}

0 commit comments

Comments
 (0)