Skip to content

Commit a54b3c6

Browse files
committed
compiler: fix complex cxx arithmetic
1 parent 1520fa2 commit a54b3c6

2 files changed

Lines changed: 54 additions & 28 deletions

File tree

devito/passes/iet/languages/CXX.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def std_arith(prefix=''):
1919
prefix = prefix if prefix.endswith(' ') else f'{prefix} '
2020
return f"""
2121
#include <complex>
22+
#include <type_traits>
23+
24+
// ---- scalar <op> complex<T> (scalar promoted to T) --------------------
2225
2326
template<typename _Tp, typename _Ti>
2427
{prefix}std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){{
@@ -32,7 +35,7 @@ def std_arith(prefix=''):
3235
3336
template<typename _Tp, typename _Ti>
3437
{prefix}std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){{
35-
_Tp denom = b.real() * b.real () + b.imag() * b.imag();
38+
_Tp denom = b.real() * b.real() + b.imag() * b.imag();
3639
return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
3740
}}
3841
@@ -53,14 +56,37 @@ def std_arith(prefix=''):
5356
5457
template<typename _Tp, typename _Ti>
5558
{prefix}std::complex<_Tp> operator - (const _Ti & a, const std::complex<_Tp> & b){{
56-
return std::complex<_Tp>(a = b.real(), b.imag());
59+
return std::complex<_Tp>(a - b.real(), -b.imag());
5760
}}
5861
5962
template<typename _Tp, typename _Ti>
6063
{prefix}std::complex<_Tp> operator - (const std::complex<_Tp> & b, const _Ti & a){{
6164
return std::complex<_Tp>(b.real() - a, b.imag());
6265
}}
6366
67+
// ---- mixed-precision complex<T1> <op> complex<T2> ----------------------
68+
// Promote both sides to std::complex<common_type_t<T1,T2>> and delegate to
69+
// the standard library's same-type operator. The enable_if disables the
70+
// overload when T1 == T2 so we don't collide with std::complex's own ops.
71+
72+
#define _MIXED_COMPLEX_OP(OP) \\
73+
template<typename _Tp1, typename _Tp2, \\
74+
typename _Tr = std::common_type_t<_Tp1, _Tp2>, \\
75+
typename = std::enable_if_t<!std::is_same<_Tp1, _Tp2>::value>> \\
76+
{prefix}std::complex<_Tr> \\
77+
operator OP (const std::complex<_Tp1> & a, \\
78+
const std::complex<_Tp2> & b) {{ \\
79+
return std::complex<_Tr>(a.real(), a.imag()) \\
80+
OP std::complex<_Tr>(b.real(), b.imag()); \\
81+
}}
82+
83+
_MIXED_COMPLEX_OP(*)
84+
_MIXED_COMPLEX_OP(/)
85+
_MIXED_COMPLEX_OP(+)
86+
_MIXED_COMPLEX_OP(-)
87+
88+
#undef _MIXED_COMPLEX_OP
89+
6490
"""
6591

6692

0 commit comments

Comments
 (0)