Skip to content

Commit db0bff5

Browse files
authored
Merge pull request #557 from ValeevGroup/evaleev/feature/tot-scale-strided-gemm
arena: recast ToT outer-contraction scale as strided BLAS GEMM
2 parents c70fa07 + 266f0a4 commit db0bff5

1 file changed

Lines changed: 192 additions & 0 deletions

File tree

src/TiledArray/tensor/tensor.h

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3096,6 +3096,198 @@ class Tensor {
30963096
const integer ldb =
30973097
(gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? N : K);
30983098

3099+
// GEMM-based ToT scale path: for the scale contraction
3100+
// "m,k;a" * "k,n" -> "m,n;a" (left ToT, right plain scalar), recast each
3101+
// row m as one strided GEMM result_m(A_m x N) += left_m(A_m x K) *
3102+
// right(K x N), directly on the arena slab -- amortizing the per-cell AXPY
3103+
// setup over a single BLAS call. Applies for NoTranspose, matching scalar
3104+
// type, and "clean" rows (all cells present, uniform inner size A_m, laid
3105+
// out as one single-page block at constant stride >= A_m); other rows fall back
3106+
// to the per-cell AXPY loop.
3107+
if constexpr (detail::is_numeric_v<V> && is_tensor_view_v<U> &&
3108+
is_tensor_view_v<value_type>) {
3109+
using Real = std::remove_cv_t<typename value_type::value_type>;
3110+
if constexpr (std::is_same_v<std::remove_cv_t<V>, Real>) {
3111+
if (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose &&
3112+
gemm_helper.right_op() == TiledArray::math::blas::NoTranspose) {
3113+
for (integer b = 0; b != nbatch(); ++b) {
3114+
auto this_data = this->batch_data(b);
3115+
auto left_data = left.batch_data(b);
3116+
auto right_data = right.batch_data(b); // K x N row-major scalars
3117+
for (integer m = 0; m != M; ++m) {
3118+
auto* lc0 = left_data + (m * K); // left cells (m,0..K-1)
3119+
auto* rc0 = this_data + (m * N); // result cells (m,0..N-1)
3120+
// A "clean" row has all cells present, uniform inner size A, and
3121+
// laid out at one constant (padded) stride >= A (so the GEMM can run
3122+
// zero-copy directly on the slab). Else fall back to per-cell
3123+
// AXPY.
3124+
long A = -1;
3125+
bool clean = true;
3126+
for (integer k = 0; k != K && clean; ++k) {
3127+
const auto& c = lc0[k];
3128+
if (c.empty()) {
3129+
clean = false;
3130+
break;
3131+
}
3132+
long s = static_cast<long>(c.size());
3133+
if (A < 0)
3134+
A = s;
3135+
else if (A != s)
3136+
clean = false;
3137+
}
3138+
for (integer n = 0; n != N && clean; ++n) {
3139+
const auto& c = rc0[n];
3140+
if (c.empty()) {
3141+
clean = false;
3142+
break;
3143+
}
3144+
long s = static_cast<long>(c.size());
3145+
if (A < 0)
3146+
A = s;
3147+
else if (A != s)
3148+
clean = false;
3149+
}
3150+
// Arena cells are SIMD-padded, so the per-row inter-cell stride
3151+
// is the padded inner size (>= A). The strided GEMM requires the
3152+
// row's cells to be ONE contiguous run at constant stride -- only
3153+
// true for a single-page arena. An incrementally-built (un-
3154+
// compacted) ToT tile may span multiple pages, where the stride
3155+
// jumps at a page boundary; verify constant stride across ALL
3156+
// cells (so multi-page tiles fall back to the AXPY loop).
3157+
integer ldb = static_cast<integer>(A);
3158+
integer ldc = static_cast<integer>(A);
3159+
if (clean && A > 0) {
3160+
if (K > 1)
3161+
ldb = static_cast<integer>(lc0[1].data() - lc0[0].data());
3162+
if (N > 1)
3163+
ldc = static_cast<integer>(rc0[1].data() - rc0[0].data());
3164+
if (ldb < A || ldc < A) clean = false; // sanity
3165+
const std::ptrdiff_t sb = ldb, sc = ldc;
3166+
for (integer k = 0; clean && k != K; ++k)
3167+
if (lc0[k].data() != lc0[0].data() + k * sb) clean = false;
3168+
for (integer n = 0; clean && n != N; ++n)
3169+
if (rc0[n].data() != rc0[0].data() + n * sc) clean = false;
3170+
}
3171+
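if (A <= 0) continue; // NOTE(review): A stays -1 whenever left cell (m,0) is empty, so this also skips rows whose later left cells are non-empty -- confirm; such rows look like they should take the AXPY fallback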
3172+
if (clean) {
3173+
// result[m,n][a] += sum_k left[m,k][a] * right[k,n].
3174+
// Row-major gemm: C2(N x A) += right^T(N x K) * L2(K x A),
3175+
// where L2 = left row-m slab (K x A, ld=ldb), C2 = result row-m
3176+
// slab (N x A, ld=ldc), right is K x N (ld=N). ldb/ldc carry
3177+
// padding.
3178+
const integer Ai = static_cast<integer>(A);
3179+
TiledArray::math::blas::gemm(
3180+
TiledArray::math::blas::Transpose,
3181+
TiledArray::math::blas::NoTranspose,
3182+
/*M=*/N, /*N=*/Ai, /*K=*/K, Real(1),
3183+
/*A=*/right_data, /*lda=*/N,
3184+
/*B=*/lc0[0].data(), /*ldb=*/ldb, Real(1),
3185+
/*C=*/rc0[0].data(), /*ldc=*/ldc);
3186+
} else { // per-cell AXPY fallback for this row
3187+
for (integer n = 0; n != N; ++n) {
3188+
auto c_offset = m * N + n;
3189+
for (integer k = 0; k != K; ++k)
3190+
elem_muladd_op(*(this_data + c_offset),
3191+
*(left_data + (m * K + k)),
3192+
*(right_data + (k * N + n)));
3193+
}
3194+
}
3195+
}
3196+
}
3197+
return *this;
3198+
}
3199+
}
3200+
}
3201+
3202+
// GEMM-based scale path, mirror for T * ToT ("m,k" * "k,n;a" -> "m,n;a",
3203+
// left plain scalar, right ToT). Per column n: one GEMM
3204+
// result_n(M x A_n) += left(M x K) * right_n(K x A_n). The right/result
3205+
// column-n cells are strided over the slab (constant k-/m-stride within a
3206+
// single arena page); verify that, else fall back to per-cell AXPY.
3207+
if constexpr (detail::is_numeric_v<U> && is_tensor_view_v<V> &&
3208+
is_tensor_view_v<value_type>) {
3209+
using Real = std::remove_cv_t<typename value_type::value_type>;
3210+
if constexpr (std::is_same_v<std::remove_cv_t<U>, Real>) {
3211+
if (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose &&
3212+
gemm_helper.right_op() == TiledArray::math::blas::NoTranspose) {
3213+
for (integer b = 0; b != nbatch(); ++b) {
3214+
auto this_data = this->batch_data(b);
3215+
auto left_data = left.batch_data(b); // M x K row-major scalars
3216+
auto right_data = right.batch_data(b); // K x N ToT
3217+
for (integer n = 0; n != N; ++n) {
3218+
long A = -1;
3219+
bool clean = true;
3220+
for (integer k = 0; k != K && clean; ++k) {
3221+
const auto& c = right_data[k * N + n];
3222+
if (c.empty()) {
3223+
clean = false;
3224+
break;
3225+
}
3226+
long s = static_cast<long>(c.size());
3227+
if (A < 0)
3228+
A = s;
3229+
else if (A != s)
3230+
clean = false;
3231+
}
3232+
for (integer m = 0; m != M && clean; ++m) {
3233+
const auto& c = this_data[m * N + n];
3234+
if (c.empty()) {
3235+
clean = false;
3236+
break;
3237+
}
3238+
long s = static_cast<long>(c.size());
3239+
if (A < 0)
3240+
A = s;
3241+
else if (A != s)
3242+
clean = false;
3243+
}
3244+
integer ldb = static_cast<integer>(A); // k-stride, right col n
3245+
integer ldc = static_cast<integer>(A); // m-stride, result col n
3246+
if (clean && A > 0) {
3247+
if (K > 1)
3248+
ldb = static_cast<integer>(right_data[N + n].data() -
3249+
right_data[n].data());
3250+
if (M > 1)
3251+
ldc = static_cast<integer>(this_data[N + n].data() -
3252+
this_data[n].data());
3253+
if (ldb < A || ldc < A) clean = false;
3254+
const std::ptrdiff_t sb = ldb, sc = ldc;
3255+
for (integer k = 0; clean && k != K; ++k)
3256+
if (right_data[k * N + n].data() !=
3257+
right_data[n].data() + k * sb)
3258+
clean = false;
3259+
for (integer m = 0; clean && m != M; ++m)
3260+
if (this_data[m * N + n].data() !=
3261+
this_data[n].data() + m * sc)
3262+
clean = false;
3263+
}
3264+
if (A <= 0) continue;
3265+
if (clean) {
3266+
// C_n(M x A) += left(M x K) * B_n(K x A). Row-major gemm.
3267+
const integer Ai = static_cast<integer>(A);
3268+
TiledArray::math::blas::gemm(
3269+
TiledArray::math::blas::NoTranspose,
3270+
TiledArray::math::blas::NoTranspose,
3271+
/*M=*/M, /*N=*/Ai, /*K=*/K, Real(1),
3272+
/*A=*/left_data, /*lda=*/K,
3273+
/*B=*/right_data[n].data(), /*ldb=*/ldb, Real(1),
3274+
/*C=*/this_data[n].data(), /*ldc=*/ldc);
3275+
} else { // per-cell AXPY fallback for this column
3276+
for (integer m = 0; m != M; ++m) {
3277+
auto c_offset = m * N + n;
3278+
for (integer k = 0; k != K; ++k)
3279+
elem_muladd_op(*(this_data + c_offset),
3280+
*(left_data + (m * K + k)),
3281+
*(right_data + (k * N + n)));
3282+
}
3283+
}
3284+
}
3285+
}
3286+
return *this;
3287+
}
3288+
}
3289+
}
3290+
30993291
for (integer b = 0; b != nbatch(); ++b) {
31003292
auto this_data = this->batch_data(b);
31013293
auto left_data = left.batch_data(b);

0 commit comments

Comments
 (0)