@@ -154,71 +154,71 @@ class tensor_impl_t {
154154 " T *ldata_;\n " +
155155 " constexpr static cuda::std::array<index_t, " + std::to_string (Rank ()) + " > strides_ = { " + detail::array_to_string (desc_.Strides ()) + " };\n " +
156156 " constexpr static cuda::std::array<index_t, " + std::to_string (Rank ()) + " > sizes_ = { " + detail::array_to_string (desc_.Shape ()) + " };\n " +
157- " template <detail::ElementsPerThread EPT , int I = 0, typename ...Is>\n " +
157+ " template <typename CapType , int I = 0, typename ...Is>\n " +
158158 " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetVal([[maybe_unused]] cuda::std::tuple<Is...> tup) {\n " +
159159 " if constexpr (I < sizeof...(Is)) {\n " +
160- " if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n " +
161- " return GetVal<EPT , I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(EPT ));\n " +
160+ " if constexpr (CapType::ept != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n " +
161+ " return GetVal<CapType , I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(CapType::ept ));\n " +
162162 " }\n " +
163163 " else {\n " +
164- " return GetVal<EPT , I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n " +
164+ " return GetVal<CapType , I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n " +
165165 " }\n " +
166166 " }\n " +
167167 " else {\n " +
168168 " return 0;\n " +
169169 " }\n " +
170170 " }\n " +
171- " template <detail::ElementsPerThread EPT , int I = 0, typename ...Is>\n " +
171+ " template <typename CapType , int I = 0, typename ...Is>\n " +
172172 " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetValC([[maybe_unused]] const cuda::std::tuple<Is...> tup) const {\n " +
173173 " if constexpr (I < sizeof...(Is)) {\n " +
174- " if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n " +
175- " return GetValC<EPT , I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(EPT ));\n " +
174+ " if constexpr (CapType::ept != detail::ElementsPerThread::ONE && I == sizeof...(Is) - 1) {\n " +
175+ " return GetValC<CapType , I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I] * static_cast<index_t>(CapType::ept ));\n " +
176176 " }\n " +
177177 " else {\n " +
178- " return GetValC<EPT , I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n " +
178+ " return GetValC<CapType , I+1, Is...>(tup) + cuda::std::get<I>(tup)*(strides_[I]);\n " +
179179 " }\n " +
180180 " }\n " +
181181 " else {\n " +
182182 " return 0;\n " +
183183 " }\n " +
184- " }\n " +
185- " template <detail::ElementsPerThread EPT , typename... Is>\n " +
184+ " }\n " +
185+ " template <typename CapType , typename... Is>\n " +
186186 " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetOffsetOptimized(Is... indices) const {\n " +
187187 " constexpr size_t rank = sizeof...(Is);\n " +
188- " constexpr int EPT_int = static_cast<int>(EPT );\n " +
188+ " constexpr int EPT_int = static_cast<int>(CapType::ept );\n " +
189189 " const cuda::std::array<index_t, rank> idx{indices...};\n " +
190190 " \n " +
191191 " if constexpr (rank == 1) {\n " +
192- " if constexpr (EPT != detail::ElementsPerThread::ONE) {\n " +
192+ " if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n " +
193193 " return idx[0] * (strides_[0] * EPT_int);\n " +
194194 " } else {\n " +
195195 " return idx[0] * strides_[0];\n " +
196196 " }\n " +
197197 " }\n " +
198198 " else if constexpr (rank == 2) {\n " +
199- " if constexpr (EPT != detail::ElementsPerThread::ONE) {\n " +
199+ " if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n " +
200200 " return idx[0] * strides_[0] + idx[1] * (strides_[1] * EPT_int);\n " +
201201 " } else {\n " +
202202 " return idx[0] * strides_[0] + idx[1] * strides_[1];\n " +
203203 " }\n " +
204204 " }\n " +
205205 " else if constexpr (rank == 3) {\n " +
206- " if constexpr (EPT != detail::ElementsPerThread::ONE) {\n " +
206+ " if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n " +
207207 " return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * (strides_[2] * EPT_int);\n " +
208208 " } else {\n " +
209209 " return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2];\n " +
210210 " }\n " +
211211 " }\n " +
212212 " else if constexpr (rank == 4) {\n " +
213- " if constexpr (EPT != detail::ElementsPerThread::ONE) {\n " +
213+ " if constexpr (CapType::ept != detail::ElementsPerThread::ONE) {\n " +
214214 " return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2] + idx[3] * (strides_[3] * EPT_int);\n " +
215215 " } else {\n " +
216216 " return idx[0] * strides_[0] + idx[1] * strides_[1] + idx[2] * strides_[2] + idx[3] * strides_[3];\n " +
217217 " }\n " +
218218 " }\n " +
219219 " else {\n " +
220220 " // For rank > 4, fall back to the recursive implementation\n " +
221- " return GetValC<EPT , 0, Is...>(cuda::std::make_tuple(indices...));\n " +
221+ " return GetValC<CapType , 0, Is...>(cuda::std::make_tuple(indices...));\n " +
222222 " }\n " +
223223 " }\n " +
224224 " template <typename CapType, int I = 0, typename... Is>\n " +
@@ -246,7 +246,7 @@ class tensor_impl_t {
246246 " return ReturnType{};\n " +
247247 " }\n " +
248248 " }\n " +
249- " const index_t offset = GetOffsetOptimized<CapType::ept >(indices...);\n " +
249+ " const index_t offset = GetOffsetOptimized<CapType>(indices...);\n " +
250250 " if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n " +
251251 " return ldata_[offset];\n " +
252252 " } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {\n " +
@@ -272,7 +272,7 @@ class tensor_impl_t {
272272 " return dummy_;\n " +
273273 " }\n " +
274274 " }\n " +
275- " const index_t offset = GetOffsetOptimized<CapType::ept >(indices...);\n " +
275+ " const index_t offset = GetOffsetOptimized<CapType>(indices...);\n " +
276276 " if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n " +
277277 " return ldata_[offset];\n " +
278278 " } else {\n " +
@@ -296,7 +296,7 @@ class tensor_impl_t {
296296 " template <typename CapType, int M = RANK, typename... Is>\n " +
297297 " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* data_ptr(index_t block_idx, index_t ttl_threads) const noexcept\n " +
298298 " {\n "
299- " //const index_t offset = GetOffsetOptimized<CapType::ept >(indices...);\n " +
299+ " //const index_t offset = GetOffsetOptimized<CapType>(indices...);\n " +
300300 " //return ldata_ + offset;\n " +
301301 " return ldata_ + block_idx * ttl_threads * static_cast<index_t>(CapType::ept);\n " +
302302 " }\n " +
@@ -1107,7 +1107,7 @@ MATX_IGNORE_WARNING_POP_GCC
11071107 template <typename ... Is>
11081108 __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* GetPointer (Is... indices) const noexcept
11091109 {
1110- return data_.ldata_ + GetOffsetOptimized<detail::ElementsPerThread::ONE >(indices...);
1110+ return data_.ldata_ + GetOffsetOptimized<detail::DefaultCapabilities >(indices...);
11111111 }
11121112
11131113 // Locates position of an element at given indices, or returns -1 when not
@@ -1204,76 +1204,80 @@ MATX_IGNORE_WARNING_POP_GCC
12041204 return desc_.IsContiguous ();
12051205 }
12061206
1207- template <typename detail::ElementsPerThread EPT , int I = 0 , typename ...Is>
1207+ template <typename CapType , int I = 0 , typename ...Is>
12081208 __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetVal ([[maybe_unused]] cuda::std::tuple<Is...> tup) {
12091209 if constexpr (I < sizeof ...(Is)) {
12101210MATX_IGNORE_WARNING_PUSH_GCC (" -Wmaybe-uninitialized" )
1211- if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof ...(Is) - 1 ) {
1212- return GetVal<EPT, I+1 , Is...>(tup) + cuda::std::get<I>(tup)*(this ->desc_ .Stride (I) * static_cast <index_t >(EPT));
1213- }
1214- else {
1215- return GetVal<EPT, I+1 , Is...>(tup) + cuda::std::get<I>(tup)*(this ->desc_ .Stride (I));
1216- }
1211+ return GetVal<CapType, I+1 , Is...>(tup) + DimStride<I, static_cast <int >(sizeof ...(Is)), CapType>(cuda::std::get<I>(tup));
12171212MATX_IGNORE_WARNING_POP_GCC
12181213 }
12191214 else {
12201215 return 0 ;
12211216 }
12221217 }
12231218
1224- // Optimized offset calculation for ranks 1-4 with explicit stride multiplications
1225- template <detail::ElementsPerThread EPT, typename ... Is>
1219+ // Returns one dimension's offset contribution: idx_val multiplied by that dimension's stride.
1220+ // For the last dimension (DIM == RANK_VAL - 1) the term is also scaled by CapType::ept when ept != ONE,
1221+ // and the desc_.Stride(DIM) read is skipped (stride taken as 1) when CapType::unit_stride_last is set.
1222+ template <int DIM, int RANK_VAL, typename CapType>
1223+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type
1224+ DimStride (index_t idx_val) const {
1225+ constexpr bool is_last = (DIM == RANK_VAL - 1 ); // both special cases below apply only to the last dimension
1226+ constexpr bool is_unit = CapType::unit_stride_last && is_last; // last-dimension stride is treated as 1
1227+ constexpr bool has_ept = (CapType::ept != detail::ElementsPerThread::ONE) && is_last; // last-dimension index is scaled by ept
1228+
1229+ if constexpr (is_unit && has_ept) { // unit stride with ept scaling: no stride read, multiply by the constant ept only
1230+ return idx_val * static_cast <index_t >(CapType::ept);
1231+ } else if constexpr (is_unit) { // unit stride and ept == ONE: the index is the offset term as-is
1232+ return idx_val;
1233+ } else if constexpr (has_ept) { // stride read from desc_, then scaled by ept
1234+ return idx_val * (this ->desc_ .Stride (DIM) * static_cast <index_t >(CapType::ept));
1235+ } else { // generic case: index times the stride read from desc_
1236+ return idx_val * this ->desc_ .Stride (DIM);
1237+ }
1238+ }
1239+
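1240+ // Optimized offset calculation: ranks 1-4 sum explicit per-dimension terms; higher ranks fall back to GetValC.
1241+ // When CapType::unit_stride_last is true, the last dimension's stride load (ULDC) is elided; its
1242+ // multiply (IMAD) is elided as well when CapType::ept is ONE, otherwise the index is scaled by the constant ept only.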
1243+ template <typename CapType, typename ... Is>
12261244 __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetOffsetOptimized (Is... indices) const {
1227- MATX_IGNORE_WARNING_PUSH_GCC (" -Wmaybe-uninitialized" )
1245+ MATX_IGNORE_WARNING_PUSH_GCC (" -Wmaybe-uninitialized" )
12281246 constexpr size_t rank = sizeof ...(Is);
1229- constexpr int EPT_int = static_cast <int >(EPT);
12301247 const cuda::std::array<index_t , rank> idx{indices...};
1231-
1248+
1249+ constexpr int R = static_cast <int >(rank);
1250+
12321251 if constexpr (rank == 1 ) {
1233- if constexpr (EPT != detail::ElementsPerThread::ONE) {
1234- return idx[0 ] * (this ->desc_ .Stride (0 ) * EPT_int);
1235- } else {
1236- return idx[0 ] * this ->desc_ .Stride (0 );
1237- }
1252+ return DimStride<0 , R, CapType>(idx[0 ]);
12381253 }
12391254 else if constexpr (rank == 2 ) {
1240- if constexpr (EPT != detail::ElementsPerThread::ONE) {
1241- return idx[0 ] * this ->desc_ .Stride (0 ) + idx[1 ] * (this ->desc_ .Stride (1 ) * EPT_int);
1242- } else {
1243- return idx[0 ] * this ->desc_ .Stride (0 ) + idx[1 ] * this ->desc_ .Stride (1 );
1244- }
1255+ return DimStride<0 , R, CapType>(idx[0 ])
1256+ + DimStride<1 , R, CapType>(idx[1 ]);
12451257 }
12461258 else if constexpr (rank == 3 ) {
1247- if constexpr (EPT != detail::ElementsPerThread::ONE) {
1248- return idx[0 ] * this ->desc_ .Stride (0 ) + idx[1 ] * this ->desc_ .Stride (1 ) + idx[2 ] * (this ->desc_ .Stride (2 ) * EPT_int);
1249- } else {
1250- return idx[0 ] * this ->desc_ .Stride (0 ) + idx[1 ] * this ->desc_ .Stride (1 ) + idx[2 ] * this ->desc_ .Stride (2 );
1251- }
1259+ return DimStride<0 , R, CapType>(idx[0 ])
1260+ + DimStride<1 , R, CapType>(idx[1 ])
1261+ + DimStride<2 , R, CapType>(idx[2 ]);
12521262 }
12531263 else if constexpr (rank == 4 ) {
1254- if constexpr (EPT != detail::ElementsPerThread::ONE) {
1255- return idx[0 ] * this ->desc_ .Stride (0 ) + idx[1 ] * this ->desc_ .Stride (1 ) + idx[2 ] * this ->desc_ .Stride (2 ) + idx[3 ] * (this ->desc_ .Stride (3 ) * EPT_int);
1256- } else {
1257- return idx[0 ] * this ->desc_ .Stride (0 ) + idx[1 ] * this ->desc_ .Stride (1 ) + idx[2 ] * this ->desc_ .Stride (2 ) + idx[3 ] * this ->desc_ .Stride (3 );
1258- }
1264+ return DimStride<0 , R, CapType>(idx[0 ])
1265+ + DimStride<1 , R, CapType>(idx[1 ])
1266+ + DimStride<2 , R, CapType>(idx[2 ])
1267+ + DimStride<3 , R, CapType>(idx[3 ]);
12591268 }
12601269 else {
12611270 // For rank > 4, fall back to the recursive implementation
1262- return GetValC<EPT , 0 , Is...>(cuda::std::make_tuple (indices...));
1271+ return GetValC<CapType , 0 , Is...>(cuda::std::make_tuple (indices...));
12631272 }
1264- MATX_IGNORE_WARNING_POP_GCC
1273+ MATX_IGNORE_WARNING_POP_GCC
12651274 }
12661275
1267- template <detail::ElementsPerThread EPT , int I = 0 , typename ...Is>
1276+ template <typename CapType , int I = 0 , typename ...Is>
12681277 __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type GetValC ([[maybe_unused]] const cuda::std::tuple<Is...> tup) const {
12691278 if constexpr (I < sizeof ...(Is)) {
12701279MATX_IGNORE_WARNING_PUSH_GCC (" -Wmaybe-uninitialized" )
1271- if constexpr (EPT != detail::ElementsPerThread::ONE && I == sizeof ...(Is) - 1 ) {
1272- return GetValC<EPT, I+1 , Is...>(tup) + cuda::std::get<I>(tup)*(this ->desc_ .Stride (I) * static_cast <index_t >(EPT));
1273- }
1274- else {
1275- return GetValC<EPT, I+1 , Is...>(tup) + cuda::std::get<I>(tup)*(this ->desc_ .Stride (I));
1276- }
1280+ return GetValC<CapType, I+1 , Is...>(tup) + DimStride<I, static_cast <int >(sizeof ...(Is)), CapType>(cuda::std::get<I>(tup));
12771281MATX_IGNORE_WARNING_POP_GCC
12781282 }
12791283 else {
@@ -1299,7 +1303,7 @@ MATX_IGNORE_WARNING_POP_GCC
12991303 assert (data_.ldata_ != nullptr );
13001304#endif
13011305 constexpr int EPT_int = static_cast <int >(CapType::ept);
1302- const index_t offset = GetOffsetOptimized<CapType::ept >(indices...);
1306+ const index_t offset = GetOffsetOptimized<CapType>(indices...);
13031307
13041308 if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
13051309 return data_.ldata_ [offset];
@@ -1329,7 +1333,7 @@ MATX_IGNORE_WARNING_POP_GCC
13291333 {
13301334 static_assert (sizeof ...(Is) == M, " Number of indices of data_ptr must match rank of tensor" );
13311335 if constexpr (!is_sparse_data_v<TensorData>) {
1332- const index_t offset = GetOffsetOptimized<CapType::ept >(indices...);
1336+ const index_t offset = GetOffsetOptimized<CapType>(indices...);
13331337 return data_.ldata_ + offset;
13341338 }
13351339 else {
@@ -1363,7 +1367,7 @@ MATX_IGNORE_WARNING_POP_GCC
13631367 assert (data_.ldata_ != nullptr );
13641368#endif
13651369 constexpr int EPT_int = static_cast <int >(CapType::ept);
1366- const index_t offset = GetOffsetOptimized<CapType::ept >(indices...);
1370+ const index_t offset = GetOffsetOptimized<CapType>(indices...);
13671371
13681372 if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
13691373 return data_.ldata_ [offset];
@@ -1563,6 +1567,13 @@ MATX_IGNORE_WARNING_POP_GCC
15631567 return overlaps;
15641568 }
15651569 }
1570+ else if constexpr (Cap == OperatorCapability::UNIT_STRIDE_LAST) {
1571+ if constexpr (Rank () == 0 ) {
1572+ return true ;
1573+ } else {
1574+ return (Stride (Rank () - 1 ) == 1 );
1575+ }
1576+ }
15661577 else {
15671578 return detail::capability_attributes<Cap>::default_value;
15681579 }
0 commit comments