Skip to content

Commit ba5efa2

Browse files
committed
device: small UMTensor cleanups
- shift_to: call Tensor::shift_to instead of const_cast'ing the range. TA::Tensor exposes a public shift_to member (unlike btas::Tensor), so the const_cast inherited from btas_um_tensor.h is unnecessary here. - apply_scale_factor: flatten 3-level nested if constexpr into one else-if-constexpr cascade.
1 parent 3fabd11 commit ba5efa2

1 file changed

Lines changed: 13 additions & 20 deletions

File tree

src/TiledArray/device/tensor.h

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -145,22 +145,18 @@ inline void apply_scale_factor(T* data, std::size_t n, const Scalar factor,
145145
if constexpr (TiledArray::detail::is_blas_numeric_v<Scalar> ||
146146
std::is_arithmetic_v<Scalar>) {
147147
::blas::scal(n, factor, data, 1, queue);
148-
} else {
149-
if constexpr (TiledArray::detail::is_complex_v<T>) {
150-
TA_EXCEPTION(
151-
"UMTensor scale with ComplexConjugate factor on complex T is not "
152-
"implemented (requires a fused conjugation kernel)");
153-
} else {
154-
if constexpr (std::is_same_v<
155-
Scalar, TiledArray::detail::ComplexConjugate<void>>) {
156-
// conjugation on a real tensor is a no-op
157-
} else if constexpr (std::is_same_v<
158-
Scalar,
159-
TiledArray::detail::ComplexConjugate<
160-
TiledArray::detail::ComplexNegTag>>) {
161-
::blas::scal(n, static_cast<T>(-1), data, 1, queue);
162-
}
163-
}
148+
} else if constexpr (TiledArray::detail::is_complex_v<T>) {
149+
TA_EXCEPTION(
150+
"UMTensor scale with ComplexConjugate factor on complex T is not "
151+
"implemented (requires a fused conjugation kernel)");
152+
} else if constexpr (std::is_same_v<
153+
Scalar,
154+
TiledArray::detail::ComplexConjugate<void>>) {
155+
// conjugation on a real tensor is a no-op
156+
} else if constexpr (std::is_same_v<
157+
Scalar, TiledArray::detail::ComplexConjugate<
158+
TiledArray::detail::ComplexNegTag>>) {
159+
::blas::scal(n, static_cast<T>(-1), data, 1, queue);
164160
}
165161
}
166162

@@ -546,10 +542,7 @@ inline UMTensor<T> shift(const UMTensor<T>& arg, const Index& bound_shift) {
546542
template <typename T, typename Index>
547543
requires TiledArray::detail::is_numeric_v<T>
548544
inline UMTensor<T>& shift_to(UMTensor<T>& arg, const Index& bound_shift) {
549-
// `range()` only exposes a const accessor; cast is safe because we are the
550-
// tile's owner here and only the range bounds change, not the data layout.
551-
const_cast<TiledArray::Range&>(arg.range()).inplace_shift(bound_shift);
552-
return arg;
545+
return arg.shift_to(bound_shift);
553546
}
554547

555548
template <typename T, typename Index>

0 commit comments

Comments
 (0)