@@ -3096,6 +3096,198 @@ class Tensor {
30963096 const integer ldb =
30973097 (gemm_helper.right_op () == TiledArray::math::blas::NoTranspose ? N : K);
30983098
3099+ // GEMM-based ToT scale path: for the scale contraction
3100+ // "m,k;a" * "k,n" -> "m,n;a" (left ToT, right plain scalar), recast each
3101+ // row m as one strided GEMM result_m(A_m x N) += left_m(A_m x K) *
3102+ // right(K x N), directly on the arena slab -- amortizing the per-cell AXPY
3103+ // setup over a single BLAS call. Applies for NoTranspose, matching scalar
3104+ // type, and "clean" rows (all cells present, uniform inner size A_m, laid
3105+ // out as one contiguous single-page stride-A_m block); other rows fall back
3106+ // to the per-cell AXPY loop.
3107+ if constexpr (detail::is_numeric_v<V> && is_tensor_view_v<U> &&
3108+ is_tensor_view_v<value_type>) {
3109+ using Real = std::remove_cv_t <typename value_type::value_type>;
3110+ if constexpr (std::is_same_v<std::remove_cv_t <V>, Real>) {
3111+ if (gemm_helper.left_op () == TiledArray::math::blas::NoTranspose &&
3112+ gemm_helper.right_op () == TiledArray::math::blas::NoTranspose) {
3113+ for (integer b = 0 ; b != nbatch (); ++b) {
3114+ auto this_data = this ->batch_data (b);
3115+ auto left_data = left.batch_data (b);
3116+ auto right_data = right.batch_data (b); // K x N row-major scalars
3117+ for (integer m = 0 ; m != M; ++m) {
3118+ auto * lc0 = left_data + (m * K); // left cells (m,0..K-1)
3119+ auto * rc0 = this_data + (m * N); // result cells (m,0..N-1)
3120+ // A "clean" row has all cells present, uniform inner size A, and
3121+ // laid out as one contiguous stride-A block (so the GEMM can run
3122+ // zero-copy directly on the slab). Else fall back to per-cell
3123+ // AXPY.
3124+ long A = -1 ;
3125+ bool clean = true ;
3126+ for (integer k = 0 ; k != K && clean; ++k) {
3127+ const auto & c = lc0[k];
3128+ if (c.empty ()) {
3129+ clean = false ;
3130+ break ;
3131+ }
3132+ long s = static_cast <long >(c.size ());
3133+ if (A < 0 )
3134+ A = s;
3135+ else if (A != s)
3136+ clean = false ;
3137+ }
3138+ for (integer n = 0 ; n != N && clean; ++n) {
3139+ const auto & c = rc0[n];
3140+ if (c.empty ()) {
3141+ clean = false ;
3142+ break ;
3143+ }
3144+ long s = static_cast <long >(c.size ());
3145+ if (A < 0 )
3146+ A = s;
3147+ else if (A != s)
3148+ clean = false ;
3149+ }
3150+ // Arena cells are SIMD-padded, so the per-row inter-cell stride
3151+ // is the padded inner size (>= A). The strided GEMM requires the
3152+ // row's cells to be ONE contiguous run at constant stride -- only
3153+ // true for a single-page arena. An incrementally-built (un-
3154+ // compacted) ToT tile may span multiple pages, where the stride
3155+ // jumps at a page boundary; verify constant stride across ALL
3156+ // cells (so multi-page tiles fall back to the AXPY loop).
3157+ integer ldb = static_cast <integer>(A);
3158+ integer ldc = static_cast <integer>(A);
3159+ if (clean && A > 0 ) {
3160+ if (K > 1 )
3161+ ldb = static_cast <integer>(lc0[1 ].data () - lc0[0 ].data ());
3162+ if (N > 1 )
3163+ ldc = static_cast <integer>(rc0[1 ].data () - rc0[0 ].data ());
3164+ if (ldb < A || ldc < A) clean = false ; // sanity
3165+ const std::ptrdiff_t sb = ldb, sc = ldc;
3166+ for (integer k = 0 ; clean && k != K; ++k)
3167+ if (lc0[k].data () != lc0[0 ].data () + k * sb) clean = false ;
3168+ for (integer n = 0 ; clean && n != N; ++n)
3169+ if (rc0[n].data () != rc0[0 ].data () + n * sc) clean = false ;
3170+ }
3171+ if (A <= 0 ) continue ; // empty row -> nothing to do
3172+ if (clean) {
3173+ // result[m,n][a] += sum_k left[m,k][a] * right[k,n].
3174+ // Row-major gemm: C2(N x A) += right^T(N x K) * L2(K x A),
3175+ // where L2 = left row-m slab (K x A, ld=ldb), C2 = result row-m
3176+ // slab (N x A, ld=ldc), right is K x N (ld=N). ldb/ldc carry
3177+ // padding.
3178+ const integer Ai = static_cast <integer>(A);
3179+ TiledArray::math::blas::gemm (
3180+ TiledArray::math::blas::Transpose,
3181+ TiledArray::math::blas::NoTranspose,
3182+ /* M=*/ N, /* N=*/ Ai, /* K=*/ K, Real (1 ),
3183+ /* A=*/ right_data, /* lda=*/ N,
3184+ /* B=*/ lc0[0 ].data (), /* ldb=*/ ldb, Real (1 ),
3185+ /* C=*/ rc0[0 ].data (), /* ldc=*/ ldc);
3186+ } else { // per-cell AXPY fallback for this row
3187+ for (integer n = 0 ; n != N; ++n) {
3188+ auto c_offset = m * N + n;
3189+ for (integer k = 0 ; k != K; ++k)
3190+ elem_muladd_op (*(this_data + c_offset),
3191+ *(left_data + (m * K + k)),
3192+ *(right_data + (k * N + n)));
3193+ }
3194+ }
3195+ }
3196+ }
3197+ return *this ;
3198+ }
3199+ }
3200+ }
3201+
3202+ // GEMM-based scale path, mirror for T * ToT ("m,k" * "k,n;a" -> "m,n;a",
3203+ // left plain scalar, right ToT). Per column n: one GEMM
3204+ // result_n(M x A_n) += left(M x K) * right_n(K x A_n). The right/result
3205+ // column-n cells are strided over the slab (constant k-/m-stride within a
3206+ // single arena page); verify that, else fall back to per-cell AXPY.
3207+ if constexpr (detail::is_numeric_v<U> && is_tensor_view_v<V> &&
3208+ is_tensor_view_v<value_type>) {
3209+ using Real = std::remove_cv_t <typename value_type::value_type>;
3210+ if constexpr (std::is_same_v<std::remove_cv_t <U>, Real>) {
3211+ if (gemm_helper.left_op () == TiledArray::math::blas::NoTranspose &&
3212+ gemm_helper.right_op () == TiledArray::math::blas::NoTranspose) {
3213+ for (integer b = 0 ; b != nbatch (); ++b) {
3214+ auto this_data = this ->batch_data (b);
3215+ auto left_data = left.batch_data (b); // M x K row-major scalars
3216+ auto right_data = right.batch_data (b); // K x N ToT
3217+ for (integer n = 0 ; n != N; ++n) {
3218+ long A = -1 ;
3219+ bool clean = true ;
3220+ for (integer k = 0 ; k != K && clean; ++k) {
3221+ const auto & c = right_data[k * N + n];
3222+ if (c.empty ()) {
3223+ clean = false ;
3224+ break ;
3225+ }
3226+ long s = static_cast <long >(c.size ());
3227+ if (A < 0 )
3228+ A = s;
3229+ else if (A != s)
3230+ clean = false ;
3231+ }
3232+ for (integer m = 0 ; m != M && clean; ++m) {
3233+ const auto & c = this_data[m * N + n];
3234+ if (c.empty ()) {
3235+ clean = false ;
3236+ break ;
3237+ }
3238+ long s = static_cast <long >(c.size ());
3239+ if (A < 0 )
3240+ A = s;
3241+ else if (A != s)
3242+ clean = false ;
3243+ }
3244+ integer ldb = static_cast <integer>(A); // k-stride, right col n
3245+ integer ldc = static_cast <integer>(A); // m-stride, result col n
3246+ if (clean && A > 0 ) {
3247+ if (K > 1 )
3248+ ldb = static_cast <integer>(right_data[N + n].data () -
3249+ right_data[n].data ());
3250+ if (M > 1 )
3251+ ldc = static_cast <integer>(this_data[N + n].data () -
3252+ this_data[n].data ());
3253+ if (ldb < A || ldc < A) clean = false ;
3254+ const std::ptrdiff_t sb = ldb, sc = ldc;
3255+ for (integer k = 0 ; clean && k != K; ++k)
3256+ if (right_data[k * N + n].data () !=
3257+ right_data[n].data () + k * sb)
3258+ clean = false ;
3259+ for (integer m = 0 ; clean && m != M; ++m)
3260+ if (this_data[m * N + n].data () !=
3261+ this_data[n].data () + m * sc)
3262+ clean = false ;
3263+ }
3264+ if (A <= 0 ) continue ;
3265+ if (clean) {
3266+ // C_n(M x A) += left(M x K) * B_n(K x A). Row-major gemm.
3267+ const integer Ai = static_cast <integer>(A);
3268+ TiledArray::math::blas::gemm (
3269+ TiledArray::math::blas::NoTranspose,
3270+ TiledArray::math::blas::NoTranspose,
3271+ /* M=*/ M, /* N=*/ Ai, /* K=*/ K, Real (1 ),
3272+ /* A=*/ left_data, /* lda=*/ K,
3273+ /* B=*/ right_data[n].data (), /* ldb=*/ ldb, Real (1 ),
3274+ /* C=*/ this_data[n].data (), /* ldc=*/ ldc);
3275+ } else { // per-cell AXPY fallback for this column
3276+ for (integer m = 0 ; m != M; ++m) {
3277+ auto c_offset = m * N + n;
3278+ for (integer k = 0 ; k != K; ++k)
3279+ elem_muladd_op (*(this_data + c_offset),
3280+ *(left_data + (m * K + k)),
3281+ *(right_data + (k * N + n)));
3282+ }
3283+ }
3284+ }
3285+ }
3286+ return *this ;
3287+ }
3288+ }
3289+ }
3290+
30993291 for (integer b = 0 ; b != nbatch (); ++b) {
31003292 auto this_data = this ->batch_data (b);
31013293 auto left_data = left.batch_data (b);
0 commit comments