diff --git a/src/backend/cuda/op/copy_analysis.cc b/src/backend/cuda/op/copy_analysis.cc index bbb672b18..52cd723bd 100644 --- a/src/backend/cuda/op/copy_analysis.cc +++ b/src/backend/cuda/op/copy_analysis.cc @@ -493,8 +493,13 @@ CopyFacts AnalyzeCopyFacts(const CopyNode &op, const CopyAnalysisContext &ctx) { const LayoutMap &layout_map = ctx.layout_map != nullptr ? *ctx.layout_map : empty_layout_map; bool is_cutedsl = TargetIsCuTeDSL(ctx.target); - facts.layout_dependent_tma_available = - facts.has_layout_map && !is_cutedsl && !ctx.buffer_oob; + // Issue #2180: only the descriptor-based 2D TMA path needs the OOB gate. + // The 1D bulk-copy path emits `cp.async.bulk`, which has the same OOB + // semantics as plain T.copy(); gating it on `buffer_oob` causes + // InferLayout to fall through to the 2D path for dynamic-outer-shape + // tensors and install a swizzle-shaped shared layout, which then forces + // Lower() into LowerBulk and triggers the 256-element splitting. + facts.layout_dependent_tma_available = facts.has_layout_map && !is_cutedsl; if (facts.layout_dependent_tma_available) { facts.can_bulk_load_1d = diff --git a/testing/python/language/test_tilelang_language_tma_1d.py b/testing/python/language/test_tilelang_language_tma_1d.py index 9cb79c10c..fdb783a2d 100644 --- a/testing/python/language/test_tilelang_language_tma_1d.py +++ b/testing/python/language/test_tilelang_language_tma_1d.py @@ -46,10 +46,47 @@ def run_elementwise_add(M, N): assert "tma_load" in code and "CUtensorMap" in code +def _lower_issue_2180_kernel(K, dtype): + M = T.dynamic("M") + + @T.prim_func + def gemm(A: T.Tensor([M, K], dtype)): + with T.Kernel(M, threads=256): + var = T.alloc_var(T.int32, init=0) + a_shared = T.alloc_shared(K, dtype=dtype) + mbar = T.alloc_barrier(256) + T.tma_copy(A[var, 0:K], a_shared, barrier=mbar) + + artifact = tilelang.lower(gemm, target={"kind": "cuda", "arch": "sm_90a"}) + return artifact.kernel_source + + +def _check_single_1d_tma(code): + n_tma_load = code.count("tl::tma_load(") + has_desc = "CUtensorMap" in code + assert n_tma_load == 1, f"Issue #2180: expected exactly 1 tl::tma_load, got {n_tma_load}.\nGenerated source:\n{code}" + assert not has_desc, f"Issue #2180: expected 1D bulk-copy without CUtensorMap descriptor.\nGenerated source:\n{code}" + + +def test_issue_2180_full_row_fp32_k1024(): + _check_single_1d_tma(_lower_issue_2180_kernel(K=1024, dtype=T.float32)) + + +def test_issue_2180_full_row_fp32_k512(): + _check_single_1d_tma(_lower_issue_2180_kernel(K=512, dtype=T.float32)) + + +def test_issue_2180_full_row_fp16_k1024(): + _check_single_1d_tma(_lower_issue_2180_kernel(K=1024, dtype=T.float16)) + + def main(): run_elementwise_add(128, 128) run_elementwise_add(256, 128) run_elementwise_add(256, 256) + test_issue_2180_full_row_fp32_k1024() + test_issue_2180_full_row_fp32_k512() + test_issue_2180_full_row_fp16_k1024() if __name__ == "__main__":