[BUG] TMA load unnecessary splitting #2180

@bucket-xv

Required prerequisites

What version of TileLang are you using?

0.1.9+cuda.gitbeef5cf4

System information

0.1.9+cuda.gitbeef5cf4

Problem description

A 1D TMA copy of a single row (A[var, 0:K]) is lowered as a 2D TMA load and split into four separate loads.

Reproducible example code

The Python snippets:

import tilelang
from tilelang import language as T
import torch
@tilelang.jit
def gemm(A):
    M = T.dynamic('M')
    K = T.const('K')
    A: T.Tensor[[M, K], T.float32]
    with T.Kernel(M, threads=256) as pid:
        var = T.alloc_var(T.int32, init=0)
        a_shared = T.alloc_shared(K, dtype=T.float32)
        mbar = T.alloc_barrier(256)
        T.tma_copy(A[var, 0:K], a_shared, barrier=mbar)
    return None
# Get kernel object by calling compile()
A = torch.empty((23,1024), dtype=torch.float32, device='cuda')
kernel = gemm.compile(A)
# Print kernel source
print(kernel.get_kernel_source())

Traceback

#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940
#define _tl_orig_alignas alignas
#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64)
#include <cuda.h>
#undef alignas
#define alignas _tl_orig_alignas
#endif
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif

extern "C" __global__ void gemm_kernel(__grid_constant__ const CUtensorMap A_desc, int M);
extern "C" __global__ void __launch_bounds__(256, 1) gemm_kernel(__grid_constant__ const CUtensorMap A_desc, int M) {
  __shared__ __align__(16) uint64_t mbar_mem[1];
  auto mbar = reinterpret_cast<Barrier*>(mbar_mem);
  int var = 0;
  extern __shared__ __align__(1024) float a_shared[];
  if (tl::tl_shuffle_elect<0>()) {
    tl::prefetch_tma_descriptor(A_desc);
  }
  if (tl::tl_shuffle_elect<0>()) {
    mbar[0].init(256);
  }
  tl::fence_barrier_init();
  __syncthreads();
  var = 0;
  if (tl::tl_shuffle_elect<256>()) {
    mbar[0].expect_transaction(4096);
    tl::tma_load(A_desc, mbar[0], (&(a_shared[0])), 0, var);
    tl::tma_load(A_desc, mbar[0], (&(a_shared[256])), 256, var);
    tl::tma_load(A_desc, mbar[0], (&(a_shared[512])), 512, var);
    tl::tma_load(A_desc, mbar[0], (&(a_shared[768])), 768, var);
  }
}

Expected behavior

The copy should not be split into 4 TMA loads; it should be emitted as a single 1D TMA load. The expected TMA size does not look right either.
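For reference, here is a small tally of the sizes implied by the generated source above. It only restates numbers from the listing (K = 1024, the four column offsets, and the value passed to expect_transaction) and does not say what the compiler should emit. The per-load element count is an assumption inferred from the offset stride, since the box size is stored in the descriptor and is not visible in the kernel source.

```python
# Numbers copied from the issue: K = 1024 float32 elements per row,
# four tl::tma_load calls at column offsets 0, 256, 512, 768,
# and mbar[0].expect_transaction(4096).
K = 1024
BYTES_PER_FLOAT32 = 4
offsets = [0, 256, 512, 768]
expected_transaction_bytes = 4096

# Assumption: each load covers the span up to the next offset
# (or up to K for the last one).
chunk_elems = [end - start for start, end in zip(offsets, offsets[1:] + [K])]
chunk_bytes = [n * BYTES_PER_FLOAT32 for n in chunk_elems]
row_bytes = K * BYTES_PER_FLOAT32

print("elements per load:", chunk_elems)
print("bytes per load:   ", chunk_bytes)
print("row bytes:", row_bytes,
      "| sum of loads:", sum(chunk_bytes),
      "| expect_transaction:", expected_transaction_bytes)
```

Under that assumption, a single 1D load of the row would move the same number of bytes as the four loads combined, so the split changes the number of TMA instructions issued rather than the amount of data transferred.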

Additional context

No response


Labels

bug
