@@ -99,9 +99,52 @@ class tensor_impl_t {
9999 using data_type = TensorData;
100100 using shape_type = typename Desc::shape_type;
101101 using stride_type = typename Desc::stride_type;
102+ using shape_container = typename Desc::shape_container;
103+ using stride_container = typename Desc::stride_container;
102104 using matxoplvalue = bool ;
103105 using self_type = tensor_impl_t <T, RANK, Desc, TensorData>;
104106
107+ // Planar complex wrappers store real/imag in separate contiguous planes:
108+ // [real_0..real_n-1][imag_0..imag_n-1]. Since there is no contiguous T object
109+ // at element i, operator() cannot return a true T&. This proxy provides
110+ // reference-like read/write semantics for expression assignment paths.
111+ struct PlanarComplexProxy {
112+ self_type *self;
113+ index_t offset;
114+
115+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ operator T () const
116+ {
117+ return self->LoadPlanarComplex (offset);
118+ }
119+
120+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ PlanarComplexProxy &operator =(const T &rhs)
121+ {
122+ self->StorePlanarComplex (offset, rhs);
123+ return *this ;
124+ }
125+
126+ template <typename U>
127+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ PlanarComplexProxy &operator =(const U &rhs)
128+ requires requires (const U &u) { u.real (); u.imag (); }
129+ {
130+ T tmp{};
131+ tmp.real (rhs.real ());
132+ tmp.imag (rhs.imag ());
133+ self->StorePlanarComplex (offset, tmp);
134+ return *this ;
135+ }
136+
137+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto real () const
138+ {
139+ return self->LoadPlanarComplex (offset).real ();
140+ }
141+
142+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto imag () const
143+ {
144+ return self->LoadPlanarComplex (offset).imag ();
145+ }
146+ };
147+
105148 // Type specifier for signaling this is a matx operation
106149 using matxop = bool ;
107150
@@ -1031,7 +1074,8 @@ MATX_IGNORE_WARNING_POP_GCC
10311074 s[i] = this ->Stride (d);
10321075 }
10331076
1034- return Desc{std::move (n), std::move (s)};
1077+ auto new_desc = Desc{std::move (n), std::move (s)};
1078+ return new_desc;
10351079 }
10361080
10371081 __MATX_INLINE__ auto Permute (const cuda::std::array<int32_t , RANK> &dims) const
@@ -1306,7 +1350,12 @@ MATX_IGNORE_WARNING_POP_GCC
13061350 const index_t offset = GetOffsetOptimized<CapType>(indices...);
13071351
13081352 if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
1309- return data_.ldata_ [offset];
1353+ if constexpr (is_planar_complex_v<T>) {
1354+ return LoadPlanarComplex (offset);
1355+ }
1356+ else {
1357+ return data_.ldata_ [offset];
1358+ }
13101359 } else if constexpr (EPT_int * sizeof (T) <= MAX_VEC_WIDTH_BYTES ) {
13111360 return *reinterpret_cast <detail::Vector<T, EPT_int>*>(data_.ldata_ + offset);
13121361 } else {
@@ -1370,7 +1419,12 @@ MATX_IGNORE_WARNING_POP_GCC
13701419 const index_t offset = GetOffsetOptimized<CapType>(indices...);
13711420
13721421 if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
1373- return data_.ldata_ [offset];
1422+ if constexpr (is_planar_complex_v<T>) {
1423+ return PlanarComplexProxy{this , offset};
1424+ }
1425+ else {
1426+ return data_.ldata_ [offset];
1427+ }
13741428 } else {
13751429 return *reinterpret_cast <detail::Vector<T, EPT_int>*>(data_.ldata_ + offset);
13761430 }
@@ -1390,7 +1444,7 @@ MATX_IGNORE_WARNING_POP_GCC
13901444 template <typename CapType>
13911445 __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype (auto ) operator()(const cuda::std::array<index_t , RANK> &idx) const noexcept
13921446 {
1393- return cuda::std::apply ([&](auto &&...args ) -> T {
1447+ return cuda::std::apply ([&](auto &&...args ) -> decltype ( auto ) {
13941448 return this ->operator ()<CapType>(args...);
13951449 }, idx);
13961450 }
@@ -1404,7 +1458,7 @@ MATX_IGNORE_WARNING_POP_GCC
14041458 template <typename CapType>
14051459 __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype (auto ) operator()(const cuda::std::array<index_t , RANK> &idx) noexcept
14061460 {
1407- return cuda::std::apply ([&](auto &&...args ) -> T& {
1461+ return cuda::std::apply ([&](auto &&...args ) -> decltype ( auto ) {
14081462 return this ->operator ()<CapType>(args...);
14091463 }, idx);
14101464 }
@@ -1417,7 +1471,7 @@ MATX_IGNORE_WARNING_POP_GCC
14171471 */
14181472 __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype (auto ) operator()(const cuda::std::array<index_t , RANK> &idx) const noexcept
14191473 {
1420- return cuda::std::apply ([&](auto &&...args ) -> T {
1474+ return cuda::std::apply ([&](auto &&...args ) -> decltype ( auto ) {
14211475 return this ->operator ()<DefaultCapabilities>(args...);
14221476 }, idx);
14231477 }
@@ -1430,7 +1484,7 @@ MATX_IGNORE_WARNING_POP_GCC
14301484 */
14311485 __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype (auto ) operator()(const cuda::std::array<index_t , RANK> &idx) noexcept
14321486 {
1433- return cuda::std::apply ([&](auto &&...args ) -> T& {
1487+ return cuda::std::apply ([&](auto &&...args ) -> decltype ( auto ) {
14341488 return this ->operator ()<DefaultCapabilities>(args...);
14351489 }, idx);
14361490 }
@@ -1441,6 +1495,10 @@ MATX_IGNORE_WARNING_POP_GCC
14411495 // Since tensors are a "leaf" operator type, we will never have an operator passed to a tensor as the
14421496 // type, but only POD types.
14431497 if constexpr (Cap == detail::OperatorCapability::ELEMENTS_PER_THREAD) {
1498+ if constexpr (is_planar_complex_v<T>) {
1499+ return cuda::std::array<detail::ElementsPerThread, 2 >{detail::ElementsPerThread::ONE, detail::ElementsPerThread::ONE};
1500+ }
1501+
14441502 if constexpr (Rank () == 0 ) {
14451503 return cuda::std::array<detail::ElementsPerThread, 2 >{detail::ElementsPerThread::ONE, detail::ElementsPerThread::ONE};
14461504 }
@@ -1713,6 +1771,27 @@ MATX_IGNORE_WARNING_POP_GCC
17131771 protected:
17141772 TensorData data_;
17151773 Desc desc_;
1774+
1775+ private:
1776+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T LoadPlanarComplex (index_t offset) const
1777+ {
1778+ using Scalar = typename T::value_type;
1779+ const auto *base = reinterpret_cast <const Scalar *>(data_.ldata_ );
1780+ const index_t total = this ->TotalSize ();
1781+ T out{};
1782+ out.real (base[offset]);
1783+ out.imag (base[offset + total]);
1784+ return out;
1785+ }
1786+
1787+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void StorePlanarComplex (index_t offset, const T &v)
1788+ {
1789+ using Scalar = typename T::value_type;
1790+ auto *base = reinterpret_cast <Scalar *>(data_.ldata_ );
1791+ const index_t total = this ->TotalSize ();
1792+ base[offset] = v.real ();
1793+ base[offset + total] = v.imag ();
1794+ }
17161795};
17171796
17181797}
0 commit comments