
Commit 242b88b

Added planar types to speed up complex half precision GEMMs (#1142)
1 parent: d2f550c · commit: 242b88b

26 files changed

Lines changed: 680 additions & 95 deletions

include/matx/core/allocator.h

Lines changed: 19 additions & 7 deletions
```diff
@@ -128,24 +128,36 @@ struct MemTracker {
 
     matxMemoryStats.currentBytesAllocated -= bytes;
 
+    // Check if the CUDA context is still valid before attempting to free.
+    // During static destruction at program exit, the CUDA context may already
+    // be destroyed, making cudaFree/cudaFreeAsync calls fail with
+    // CUDA_ERROR_CONTEXT_IS_DESTROYED.
+    auto is_cuda_free = [&]() {
+      if (iter->second.kind == MATX_HOST_MALLOC_MEMORY) return true; // not CUDA
+      int dev;
+      return cudaGetDevice(&dev) == cudaSuccess;
+    };
+
     switch (iter->second.kind) {
       case MATX_MANAGED_MEMORY:
         [[fallthrough]];
       case MATX_DEVICE_MEMORY:
-        cudaFree(ptr);
+        if (is_cuda_free()) cudaFree(ptr);
        break;
      case MATX_HOST_MEMORY:
-        cudaFreeHost(ptr);
+        if (is_cuda_free()) cudaFreeHost(ptr);
        break;
      case MATX_HOST_MALLOC_MEMORY:
        free(ptr);
        break;
      case MATX_ASYNC_DEVICE_MEMORY:
-        if constexpr (std::is_same_v<no_stream_t, StreamType>) {
-          cudaFreeAsync(ptr, iter->second.stream);
-        }
-        else {
-          cudaFreeAsync(ptr, st.stream);
+        if (is_cuda_free()) {
+          if constexpr (std::is_same_v<no_stream_t, StreamType>) {
+            cudaFreeAsync(ptr, iter->second.stream);
+          }
+          else {
+            cudaFreeAsync(ptr, st.stream);
+          }
        }
        break;
      default:
```
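The guard added above targets frees that can run during static destruction, after the CUDA context has already been torn down. A minimal standalone sketch of the same pattern (illustrative only, not MatX code; the DeviceBuffer class is hypothetical):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// If the process is tearing down, cudaGetDevice() no longer succeeds, so the
// destructor skips cudaFree() instead of failing with
// CUDA_ERROR_CONTEXT_IS_DESTROYED. The driver reclaims device memory at exit.
struct DeviceBuffer {
  void *ptr_ = nullptr;
  explicit DeviceBuffer(std::size_t bytes) { cudaMalloc(&ptr_, bytes); }
  ~DeviceBuffer() {
    int dev;
    if (ptr_ != nullptr && cudaGetDevice(&dev) == cudaSuccess) {
      cudaFree(ptr_);  // context still valid; safe to free
    }
  }
};

// e.g. a static object whose destructor runs during program exit:
static DeviceBuffer g_scratch{1 << 20};
```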

include/matx/core/half_complex.h

Lines changed: 24 additions & 0 deletions
```diff
@@ -1055,6 +1055,30 @@ tanh(const matxHalfComplex<T> &x)
 using matxFp16Complex = matxHalfComplex<matxFp16>; ///< Alias for a MatX fp16 complex wrapper
 using matxBf16Complex = matxHalfComplex<matxBf16>; ///< Alias for a MatX bf16 complex wrapper
 
+struct matxFp16ComplexPlanar : public matxFp16Complex {
+  using matxFp16Complex::matxFp16Complex;
+  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxFp16ComplexPlanar() = default;
+  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxFp16ComplexPlanar(const matxFp16Complex &rhs) : matxFp16Complex(rhs) {}
+  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxFp16ComplexPlanar &operator=(const matxFp16Complex &rhs)
+  {
+    this->x = rhs.x;
+    this->y = rhs.y;
+    return *this;
+  }
+};
+
+struct matxBf16ComplexPlanar : public matxBf16Complex {
+  using matxBf16Complex::matxBf16Complex;
+  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxBf16ComplexPlanar() = default;
+  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxBf16ComplexPlanar(const matxBf16Complex &rhs) : matxBf16Complex(rhs) {}
+  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxBf16ComplexPlanar &operator=(const matxBf16Complex &rhs)
+  {
+    this->x = rhs.x;
+    this->y = rhs.y;
+    return *this;
+  }
+};
+
 }; // namespace matx
 
 // cuda::std::numeric_limits specializations for matxFp16Complex and matxBf16Complex
```
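The planar wrappers add no data members; they exist so the tensor and transform layers can dispatch on the storage layout by type. A short value-level sketch (assuming matxHalfComplex's two-argument (real, imag) constructor defined elsewhere in this header):

```cpp
// Element values convert freely between the interleaved and planar wrappers;
// only the tensor-level memory layout differs between the two type families.
matxFp16Complex       a{matxFp16{1.5f}, matxFp16{-0.5f}}; // interleaved element
matxFp16ComplexPlanar p = a;      // converting ctor copies x (real) and y (imag)
matxFp16Complex       b = p;      // slices back to the base type
p = matxFp16Complex{matxFp16{2.0f}, matxFp16{3.0f}};      // operator= defined above
```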

include/matx/core/operator_utils.h

Lines changed: 3 additions & 1 deletion
```diff
@@ -103,7 +103,9 @@ namespace matx {
     if(supported) {
       return make_tensor<typename Op::value_type>(in.Data(), in.Descriptor());
     } else {
-      return make_tensor<typename Op::value_type>(in.Shape(), space, stream);
+      // Fresh allocation is row-major contiguous; copying an affine stride
+      // descriptor onto it would break transforms (e.g. cuFFT batch distance).
+      return make_tensor<typename Op::value_type>(Shape(in), space, stream);
     }
   }
 }
```
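A concrete case the new comment is guarding against (hypothetical shapes; Permute and Shape appear in the other files of this commit):

```cpp
// A permuted 2x3 view has shape {3, 2} but strides {1, 3}. A fresh allocation
// of that shape is row-major with strides {2, 1}, so reusing the view's
// descriptor would mis-describe the new buffer; only the shape carries over.
auto a     = make_tensor<float>({2, 3});
auto ap    = a.Permute({1, 0});                // shape {3, 2}, strides {1, 3}
auto fresh = make_tensor<float>(Shape(ap));    // shape {3, 2}, strides {2, 1}
```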

include/matx/core/tensor.h

Lines changed: 20 additions & 0 deletions
```diff
@@ -179,6 +179,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
       detail::tensor_impl_t<T, RANK, Desc>{std::forward<D2>(desc)},
       storage_{std::move(s)}
   {
+    ValidatePlanarLayoutOnCreate_();
     this->SetLocalData(storage_.data());
   }
 
@@ -194,6 +195,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
       detail::tensor_impl_t<T, RANK, D2>{std::forward<D2>(desc)},
       storage_{std::move(s)}
   {
+    ValidatePlanarLayoutOnCreate_();
     this->SetLocalData(ldata);
   }
 
@@ -210,6 +212,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
       detail::tensor_impl_t<T, RANK, D2>{std::forward<D2>(desc)},
       storage_{make_owning_storage<T>(this->desc_.TotalSize())}
   {
+    ValidatePlanarLayoutOnCreate_();
     this->SetLocalData(storage_.data());
   }
 
@@ -225,6 +228,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
       detail::tensor_impl_t<T, RANK, Desc>(cuda::std::array<index_t, 0>{}),
       storage_{make_owning_storage<T>(1)}
   {
+    ValidatePlanarLayoutOnCreate_();
     this->SetLocalData(storage_.data());
   }
 
@@ -239,6 +243,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
       detail::tensor_impl_t<T, RANK, Desc>(shape),
       storage_{make_owning_storage<T>(this->desc_.TotalSize())}
   {
+    ValidatePlanarLayoutOnCreate_();
     this->SetLocalData(storage_.data());
   }
 
@@ -956,6 +961,7 @@ MATX_LOOP_UNROLL
   Reset(T *const data, ShapeType &&shape) noexcept
   {
     this->desc_.InitFromShape(std::forward<ShapeType>(shape));
+    ValidatePlanarLayoutOnCreate_();
     // For non-owning storage, we need to recreate the storage object
     storage_ = make_non_owning_storage<T>(data, this->desc_.TotalSize());
     this->SetData(data);
@@ -977,6 +983,7 @@ MATX_LOOP_UNROLL
   {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
 
+    ValidatePlanarLayoutOnCreate_();
     // For non-owning storage, we need to recreate the storage object
     storage_ = make_non_owning_storage<T>(data, this->desc_.TotalSize());
     this->SetData(data);
@@ -998,6 +1005,7 @@ MATX_LOOP_UNROLL
   __MATX_HOST__ __MATX_INLINE__ void
   Reset(T *const data, T *const ldata) noexcept
   {
+    ValidatePlanarLayoutOnCreate_();
     // For non-owning storage, we need to recreate the storage object
     storage_ = make_non_owning_storage<T>(data, this->desc_.TotalSize());
     this->SetData(ldata);
@@ -1541,6 +1549,18 @@ MATX_LOOP_UNROLL
   }
 
 private:
+  __MATX_HOST__ __MATX_INLINE__ void ValidatePlanarLayoutOnCreate_() const
+  {
+    if constexpr (is_planar_complex_v<T>) {
+      if constexpr (RANK > 0) {
+        MATX_ASSERT_STR(this->Stride(RANK - 1) == 1, matxInvalidDim,
+            "Planar complex tensors must have unit innermost stride");
+      }
+      MATX_ASSERT_STR(this->IsContiguous(), matxInvalidDim,
+          "Planar complex tensors must be contiguous (non-unity strides are not supported)");
+    }
+  }
+
   Storage<T> storage_;
   std::string name_ = std::string("tensor_") + std::to_string(RANK) + "_" + detail::to_short_str<T>();
 };
```
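A brief sketch of what the new check allows and rejects (illustrative; the failing case is described rather than written out because it depends on how a strided descriptor is built):

```cpp
// Owning, contiguous planar tensors pass: the imaginary plane sits
// TotalSize() scalars after the real plane, which is exactly what
// LoadPlanarComplex/StorePlanarComplex in tensor_impl.h assume.
auto a = make_tensor<matxFp16ComplexPlanar>({64, 64});

// Constructing a planar tensor over memory with a non-unit innermost stride,
// or any non-contiguous descriptor, now fails at creation time with
// matxInvalidDim instead of silently reading the wrong imaginary plane.
```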

include/matx/core/tensor_desc.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -41,6 +41,7 @@
 #include <type_traits>
 #include <cstdint>
 #include "matx/core/error.h"
+#include "matx/core/type_utils.h"
 
 namespace matx {
 
```
include/matx/core/tensor_impl.h

Lines changed: 86 additions & 7 deletions
```diff
@@ -99,9 +99,52 @@ class tensor_impl_t {
   using data_type = TensorData;
   using shape_type = typename Desc::shape_type;
   using stride_type = typename Desc::stride_type;
+  using shape_container = typename Desc::shape_container;
+  using stride_container = typename Desc::stride_container;
   using matxoplvalue = bool;
   using self_type = tensor_impl_t<T, RANK, Desc, TensorData>;
 
+  // Planar complex wrappers store real/imag in separate contiguous planes:
+  // [real_0..real_n-1][imag_0..imag_n-1]. Since there is no contiguous T object
+  // at element i, operator() cannot return a true T&. This proxy provides
+  // reference-like read/write semantics for expression assignment paths.
+  struct PlanarComplexProxy {
+    self_type *self;
+    index_t offset;
+
+    __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ operator T() const
+    {
+      return self->LoadPlanarComplex(offset);
+    }
+
+    __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ PlanarComplexProxy &operator=(const T &rhs)
+    {
+      self->StorePlanarComplex(offset, rhs);
+      return *this;
+    }
+
+    template <typename U>
+    __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ PlanarComplexProxy &operator=(const U &rhs)
+      requires requires(const U &u) { u.real(); u.imag(); }
+    {
+      T tmp{};
+      tmp.real(rhs.real());
+      tmp.imag(rhs.imag());
+      self->StorePlanarComplex(offset, tmp);
+      return *this;
+    }
+
+    __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto real() const
+    {
+      return self->LoadPlanarComplex(offset).real();
+    }
+
+    __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto imag() const
+    {
+      return self->LoadPlanarComplex(offset).imag();
+    }
+  };
+
   // Type specifier for signaling this is a matx operation
   using matxop = bool;
 
@@ -1031,7 +1074,8 @@ MATX_IGNORE_WARNING_POP_GCC
       s[i] = this->Stride(d);
     }
 
-    return Desc{std::move(n), std::move(s)};
+    auto new_desc = Desc{std::move(n), std::move(s)};
+    return new_desc;
   }
 
   __MATX_INLINE__ auto Permute(const cuda::std::array<int32_t, RANK> &dims) const
@@ -1306,7 +1350,12 @@ MATX_IGNORE_WARNING_POP_GCC
     const index_t offset = GetOffsetOptimized<CapType>(indices...);
 
     if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
-      return data_.ldata_[offset];
+      if constexpr (is_planar_complex_v<T>) {
+        return LoadPlanarComplex(offset);
+      }
+      else {
+        return data_.ldata_[offset];
+      }
     } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {
       return *reinterpret_cast<detail::Vector<T, EPT_int>*>(data_.ldata_ + offset);
     } else {
@@ -1370,7 +1419,12 @@ MATX_IGNORE_WARNING_POP_GCC
     const index_t offset = GetOffsetOptimized<CapType>(indices...);
 
     if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
-      return data_.ldata_[offset];
+      if constexpr (is_planar_complex_v<T>) {
+        return PlanarComplexProxy{this, offset};
+      }
+      else {
+        return data_.ldata_[offset];
+      }
     } else {
       return *reinterpret_cast<detail::Vector<T, EPT_int>*>(data_.ldata_ + offset);
     }
@@ -1390,7 +1444,7 @@ MATX_IGNORE_WARNING_POP_GCC
   template <typename CapType>
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array<index_t, RANK> &idx) const noexcept
   {
-    return cuda::std::apply([&](auto &&...args) -> T {
+    return cuda::std::apply([&](auto &&...args) -> decltype(auto) {
       return this->operator()<CapType>(args...);
     }, idx);
   }
@@ -1404,7 +1458,7 @@ MATX_IGNORE_WARNING_POP_GCC
   template <typename CapType>
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array<index_t, RANK> &idx) noexcept
   {
-    return cuda::std::apply([&](auto &&...args) -> T& {
+    return cuda::std::apply([&](auto &&...args) -> decltype(auto) {
       return this->operator()<CapType>(args...);
     }, idx);
   }
@@ -1417,7 +1471,7 @@ MATX_IGNORE_WARNING_POP_GCC
   */
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array<index_t, RANK> &idx) const noexcept
   {
-    return cuda::std::apply([&](auto &&...args) -> T {
+    return cuda::std::apply([&](auto &&...args) -> decltype(auto) {
       return this->operator()<DefaultCapabilities>(args...);
     }, idx);
   }
@@ -1430,7 +1484,7 @@ MATX_IGNORE_WARNING_POP_GCC
   */
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array<index_t, RANK> &idx) noexcept
   {
-    return cuda::std::apply([&](auto &&...args) -> T& {
+    return cuda::std::apply([&](auto &&...args) -> decltype(auto) {
       return this->operator()<DefaultCapabilities>(args...);
     }, idx);
   }
@@ -1441,6 +1495,10 @@ MATX_IGNORE_WARNING_POP_GCC
     // Since tensors are a "leaf" operator type, we will never have an operator passed to a tensor as the
     // type, but only POD types.
     if constexpr (Cap == detail::OperatorCapability::ELEMENTS_PER_THREAD) {
+      if constexpr (is_planar_complex_v<T>) {
+        return cuda::std::array<detail::ElementsPerThread, 2>{detail::ElementsPerThread::ONE, detail::ElementsPerThread::ONE};
+      }
+
       if constexpr (Rank() == 0) {
         return cuda::std::array<detail::ElementsPerThread, 2>{detail::ElementsPerThread::ONE, detail::ElementsPerThread::ONE};
       }
@@ -1713,6 +1771,27 @@ MATX_IGNORE_WARNING_POP_GCC
 protected:
   TensorData data_;
   Desc desc_;
+
+private:
+  __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T LoadPlanarComplex(index_t offset) const
+  {
+    using Scalar = typename T::value_type;
+    const auto *base = reinterpret_cast<const Scalar *>(data_.ldata_);
+    const index_t total = this->TotalSize();
+    T out{};
+    out.real(base[offset]);
+    out.imag(base[offset + total]);
+    return out;
+  }
+
+  __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void StorePlanarComplex(index_t offset, const T &v)
+  {
+    using Scalar = typename T::value_type;
+    auto *base = reinterpret_cast<Scalar *>(data_.ldata_);
+    const index_t total = this->TotalSize();
+    base[offset] = v.real();
+    base[offset + total] = v.imag();
+  }
 };
 
 }
```
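Stripped of the MatX machinery, LoadPlanarComplex/StorePlanarComplex and the proxy are plain offset arithmetic over a split buffer. A self-contained sketch of the same pattern (not MatX code; std::complex<float> stands in for the half-precision complex type):

```cpp
#include <complex>
#include <cstddef>
#include <vector>

// Planar storage of n complex values: [re_0..re_{n-1}][im_0..im_{n-1}].
struct PlanarVec {
  std::vector<float> scalars;  // holds 2 * n floats
  std::size_t n;

  // Reference-like proxy: no std::complex<float> object exists in memory at
  // element i, so indexing returns this instead of a real reference.
  struct Proxy {
    PlanarVec *self;
    std::size_t i;
    operator std::complex<float>() const {                   // load
      return {self->scalars[i], self->scalars[i + self->n]};
    }
    Proxy &operator=(std::complex<float> v) {                 // store
      self->scalars[i]           = v.real();
      self->scalars[i + self->n] = v.imag();
      return *this;
    }
  };

  Proxy operator[](std::size_t i) { return Proxy{this, i}; }
};

// Usage mirrors how MatX's operator() behaves for planar complex types:
// PlanarVec v{std::vector<float>(2 * 8), 8};
// v[3] = {1.0f, 2.0f};                 // goes through Proxy::operator=
// std::complex<float> c = v[3];        // goes through Proxy::operator T()
```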

include/matx/core/tensor_utils.h

Lines changed: 14 additions & 6 deletions
```diff
@@ -129,14 +129,22 @@ namespace matx
       return "f32";
     if constexpr (std::is_same_v<T, double>)
       return "f64";
-    if constexpr (std::is_same_v<T, matxHalf<__half>>)
+    if constexpr (std::is_same_v<T, matxFp16>)
       return "f16";
-    if constexpr (std::is_same_v<T, matxHalf<__nv_bfloat16>>)
+    if constexpr (std::is_same_v<T, matxBf16>)
       return "bf16";
     else
       return "x" + std::to_string(sizeof(T)*8);
   }
   else {
+    if constexpr (std::is_same_v<T, matxFp16ComplexPlanar>)
+      return "f16cp";
+    if constexpr (std::is_same_v<T, matxBf16ComplexPlanar>)
+      return "bf16cp";
+    if constexpr (std::is_same_v<T, matxFp16Complex>)
+      return "f16c";
+    if constexpr (std::is_same_v<T, matxBf16Complex>)
+      return "bf16c";
     if constexpr (std::is_same_v<typename T::value_type, int32_t>)
       return "i32c";
     if constexpr (std::is_same_v<typename T::value_type, uint32_t>)
@@ -149,10 +157,6 @@ namespace matx
       return "f32c";
     if constexpr (std::is_same_v<typename T::value_type, double>)
       return "f64c";
-    if constexpr (std::is_same_v<typename T::value_type, matxHalf<__half>>)
-      return "f16";
-    if constexpr (std::is_same_v<typename T::value_type, matxHalf<__nv_bfloat16>>)
-      return "bf16";
     else
       return "x" + std::to_string(sizeof(typename T::value_type)*8) + "c";
   }
@@ -199,6 +203,10 @@ namespace matx
       return {kDLComplex, 32, 1};
     if constexpr (std::is_same_v<T, matxBf16Complex>)
       return {kDLComplex, 32, 1}; // Wrong, but no other choice
+    if constexpr (std::is_same_v<T, matxFp16ComplexPlanar>)
+      return {kDLComplex, 32, 1};
+    if constexpr (std::is_same_v<T, matxBf16ComplexPlanar>)
+      return {kDLComplex, 32, 1}; // Wrong, but no other choice
     if constexpr (std::is_same_v<T, float>)
       return {kDLFloat, 32, 1};
     if constexpr (std::is_same_v<T, double>)
```
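One visible effect of the new short-name cases (the name format comes from the name_ member shown in the tensor.h hunk above; the exact string is an inference from that code):

```cpp
auto t = make_tensor<matxFp16ComplexPlanar>({4, 4});
// Default debug name becomes "tensor_2_f16cp"; an interleaved matxFp16Complex
// tensor of the same rank keeps "tensor_2_f16c".
```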
