Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/backend/cuda/op/copy_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,13 @@ CopyFacts AnalyzeCopyFacts(const CopyNode &op, const CopyAnalysisContext &ctx) {
const LayoutMap &layout_map =
ctx.layout_map != nullptr ? *ctx.layout_map : empty_layout_map;
bool is_cutedsl = TargetIsCuTeDSL(ctx.target);
facts.layout_dependent_tma_available =
facts.has_layout_map && !is_cutedsl && !ctx.buffer_oob;
// Issue #2180: only the descriptor-based 2D TMA path needs the OOB gate.
// The 1D bulk-copy path emits `cp.async.bulk`, which has the same OOB
// semantics as plain T.copy(); gating it on `buffer_oob` causes
// InferLayout to fall through to the 2D path for dynamic-outer-shape
// tensors and install a swizzle-shaped shared layout, which then forces
// Lower() into LowerBulk and triggers the 256-element splitting.
facts.layout_dependent_tma_available = facts.has_layout_map && !is_cutedsl;

if (facts.layout_dependent_tma_available) {
facts.can_bulk_load_1d =
Expand Down
37 changes: 37 additions & 0 deletions testing/python/language/test_tilelang_language_tma_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,47 @@ def run_elementwise_add(M, N):
assert "tma_load" in code and "CUtensorMap" in code


def _lower_issue_2180_kernel(K, dtype):
M = T.dynamic("M")

@T.prim_func
def gemm(A: T.Tensor([M, K], dtype)):
with T.Kernel(M, threads=256):
var = T.alloc_var(T.int32, init=0)
a_shared = T.alloc_shared(K, dtype=dtype)
mbar = T.alloc_barrier(256)
T.tma_copy(A[var, 0:K], a_shared, barrier=mbar)

artifact = tilelang.lower(gemm, target={"kind": "cuda", "arch": "sm_90a"})
return artifact.kernel_source


def _check_single_1d_tma(code):
n_tma_load = code.count("tl::tma_load(")
has_desc = "CUtensorMap" in code
assert n_tma_load == 1, f"Issue #2180: expected exactly 1 tl::tma_load, got {n_tma_load}.\nGenerated source:\n{code}"
assert not has_desc, f"Issue #2180: expected 1D bulk-copy without CUtensorMap descriptor.\nGenerated source:\n{code}"


def test_issue_2180_full_row_fp32_k1024():
_check_single_1d_tma(_lower_issue_2180_kernel(K=1024, dtype=T.float32))


def test_issue_2180_full_row_fp32_k512():
_check_single_1d_tma(_lower_issue_2180_kernel(K=512, dtype=T.float32))


def test_issue_2180_full_row_fp16_k1024():
_check_single_1d_tma(_lower_issue_2180_kernel(K=1024, dtype=T.float16))


def main():
run_elementwise_add(128, 128)
run_elementwise_add(256, 128)
run_elementwise_add(256, 256)
test_issue_2180_full_row_fp32_k1024()
test_issue_2180_full_row_fp32_k512()
test_issue_2180_full_row_fp16_k1024()


if __name__ == "__main__":
Expand Down
Loading