[Bug][TMA] Skip OOB gate for 1D TMA bulk-copy eligibility (#2180) #2235

Draft
mygitljf wants to merge 1 commit into tile-ai:main from mygitljf:fix/issue-2180-tma-1d-no-split

@mygitljf mygitljf commented May 20, 2026

Summary

Fixes #2180.

On a [M, K] tensor with M = T.dynamic('M') and K = T.const('K'), T.tma_copy(A[var, 0:K], a_shared, barrier=mbar) lowered to four split tl::tma_load calls (offsets 0/256/512/768) through a 2D CUtensorMap descriptor, instead of a single tl::tma_load that issues one 1D cp.async.bulk for the full row.

Root cause

In AnalyzeCopyFacts (src/backend/cuda/op/copy_analysis.cc), the 1D and 2D TMA eligibility checks shared a single gate that included !ctx.buffer_oob:

facts.layout_dependent_tma_available =
    facts.has_layout_map && !is_cutedsl && !ctx.buffer_oob;
if (facts.layout_dependent_tma_available) {
  facts.can_bulk_load_1d  = CheckBulkLoad1D(...);
  facts.can_bulk_store_1d = CheckBulkStore1D(...);
}

When M is dynamic and var = T.alloc_var(init=0), the analyzer cannot prove var + 1 <= M, so layout_inference.cc sets buffer_oob = true. That correctly disqualifies the descriptor-based 2D path, but it also (incorrectly) suppresses 1D eligibility, so:

  1. SelectInst falls back to kBulkLoad.
  2. InferBulkLayout installs a swizzle-shaped (FloorDiv/FloorMod 256) shared layout.
  3. Copy::Lower re-runs SelectInst with buffer_oob = false, but the shared layout is no longer linear, so CheckBulkCopy1D returns false again.
  4. LowerBulk runs, hits inner_box_dim = 1024 > 256, and emits For(i, 0, 4) tma_load(...) — the four split loads in the issue.
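
The splitting in step 4 is plain arithmetic and can be sketched as below. This is a simplified model, not the LowerBulk implementation: `split_offsets` and `max_box` are hypothetical names, and the 256-element limit is taken from the `inner_box_dim = 1024 > 256` condition described above.

```python
def split_offsets(inner_extent: int, max_box: int = 256) -> list[int]:
    """Element offsets of the loads needed to cover `inner_extent` elements
    when a single load covers at most `max_box` elements.

    Hypothetical helper that models only the loop bounds, not the codegen.
    """
    if inner_extent <= max_box:
        return [0]  # fits in one box, no splitting
    return list(range(0, inner_extent, max_box))


# K = 1024 reproduces the four offsets reported in the issue.
offsets = split_offsets(1024)
print(len(offsets), offsets)
```

Under this model the number of loads grows as K / 256, which is why the 1D path (one load regardless of K) is preferable whenever the copy qualifies for it.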

The 1D bulk-copy path emits cp.async.bulk (not cp.async.bulk.tensor) and has the same OOB semantics as a plain T.copy(). It does not need the descriptor-only OOB gate.

Fix

Drop !ctx.buffer_oob from the 1D eligibility check. CheckBulkCopy1D (contiguous innermost slice, full-extent trailing dim, element count match) is the only contract the 1D path needs. The 2D descriptor path is unaffected: it depends on facts.can_bulk_load / facts.can_bulk_store, which are computed independently below and are not gated on buffer_oob.

-  facts.layout_dependent_tma_available =
-      facts.has_layout_map && !is_cutedsl && !ctx.buffer_oob;
+  // Issue #2180: only the descriptor-based 2D TMA path needs the OOB gate.
+  // The 1D bulk-copy path emits `cp.async.bulk`, which has the same OOB
+  // semantics as plain T.copy(); gating it on `buffer_oob` causes
+  // InferLayout to fall through to the 2D path for dynamic-outer-shape
+  // tensors and install a swizzle-shaped shared layout, which then forces
+  // Lower() into LowerBulk and triggers the 256-element splitting.
+  facts.layout_dependent_tma_available = facts.has_layout_map && !is_cutedsl;
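
The effect of the one-line change can be illustrated with a toy Python model of the gate. Everything here is an assumption made for illustration: `CopyCtx`, `can_bulk_1d`, and the `passes_check_1d` field (a stand-in for the result of CheckBulkCopy1D) do not exist in the code base, and the real analysis tracks more state than this.

```python
from dataclasses import dataclass


@dataclass
class CopyCtx:
    """Hypothetical, reduced view of the inputs to the 1D eligibility gate."""
    has_layout_map: bool
    is_cutedsl: bool
    buffer_oob: bool
    passes_check_1d: bool  # stand-in for CheckBulkCopy1D(...)


def can_bulk_1d(ctx: CopyCtx, gate_on_oob: bool) -> bool:
    """Model of the 1D gate; gate_on_oob=True mimics the pre-patch behavior."""
    available = ctx.has_layout_map and not ctx.is_cutedsl
    if gate_on_oob:
        available = available and not ctx.buffer_oob
    return available and ctx.passes_check_1d


# Dynamic M with an index the analyzer cannot bound: buffer_oob is True.
ctx = CopyCtx(has_layout_map=True, is_cutedsl=False,
              buffer_oob=True, passes_check_1d=True)
print("pre-patch :", can_bulk_1d(ctx, gate_on_oob=True))
print("post-patch:", can_bulk_1d(ctx, gate_on_oob=False))
```

In this model the only input that flips between the two calls is whether `buffer_oob` participates in the gate, mirroring the single boolean expression changed by the patch.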

Verification

Pre-patch generated source for the issue's repro (extracted from the codegen string):

extern "C" __global__ void __launch_bounds__(256, 1)
gemm_kernel(__grid_constant__ const CUtensorMap A_desc, int M) {
  ...
  if (tl::tl_shuffle_elect<256>()) {
    mbar[0].expect_transaction(4096);
    tl::tma_load(A_desc, mbar[0], (&(a_shared[  0])),   0, var);
    tl::tma_load(A_desc, mbar[0], (&(a_shared[256])), 256, var);
    tl::tma_load(A_desc, mbar[0], (&(a_shared[512])), 512, var);
    tl::tma_load(A_desc, mbar[0], (&(a_shared[768])), 768, var);
  }
}

Post-patch:

extern "C" __global__ void __launch_bounds__(256, 1)
gemm_kernel(float* __restrict__ A, int M) {
  ...
  mbar[0].expect_transaction(4096);
  tl::tma_load((&(a_shared[0])), (&(A[(((int64_t)var) * (int64_t)1024)])), mbar[0], 4096);
}

A single 1D cp.async.bulk issue, no CUtensorMap descriptor — matches the expected behavior in the issue.
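
The constants in the post-patch source can be cross-checked with a short calculation. The sketch below assumes a row-major [M, K] tensor and a full-row copy; `row_copy_params` and `DTYPE_BYTES` are hypothetical names introduced only for this check.

```python
# Element sizes in bytes for the dtypes used by the regression cases.
DTYPE_BYTES = {"float32": 4, "float16": 2}


def row_copy_params(k: int, dtype: str, row: int) -> dict:
    """Byte count and global element offset for copying row `row` of a
    row-major [M, K] tensor in one bulk transfer (illustrative only)."""
    return {
        "transaction_bytes": k * DTYPE_BYTES[dtype],
        "global_elem_offset": row * k,
    }


# fp32, K = 1024: 4096 bytes, matching expect_transaction(4096) and the
# `var * 1024` index in the generated tl::tma_load call.
params = row_copy_params(1024, "float32", row=3)
print(params)
```

By the same arithmetic, the fp32/K=512 and fp16/K=1024 configurations would each move 2048 bytes per row, assuming they follow the same full-row pattern.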

Tests

3 new regression cases in testing/python/language/test_tilelang_language_tma_1d.py, asserting on kernel_source produced by tilelang.lower(target={"kind": "cuda", "arch": "sm_90a"}):

  • test_issue_2180_full_row_fp32_k1024
  • test_issue_2180_full_row_fp32_k512
  • test_issue_2180_full_row_fp16_k1024

Each asserts exactly one tl::tma_load and no CUtensorMap substring in the generated source. All three pass locally.
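
A minimal sketch of that kind of source-level assertion is shown below. It is not the helper from the test file: `is_single_1d_tma` is a hypothetical name, and the sample strings are abbreviated from the pre- and post-patch kernels quoted in the Verification section.

```python
def is_single_1d_tma(kernel_source: str) -> bool:
    """True if the source has exactly one tl::tma_load call and no
    CUtensorMap descriptor (a plain substring check, nothing more)."""
    one_load = kernel_source.count("tl::tma_load(") == 1
    no_descriptor = "CUtensorMap" not in kernel_source
    return one_load and no_descriptor


# Abbreviated post-patch kernel: raw pointer argument, one load.
POST_PATCH = (
    "gemm_kernel(float* __restrict__ A, int M) {\n"
    "  tl::tma_load((&(a_shared[0])), "
    "(&(A[(((int64_t)var) * (int64_t)1024)])), mbar[0], 4096);\n"
    "}"
)

# Abbreviated pre-patch kernel: descriptor argument, four split loads.
PRE_PATCH = "gemm_kernel(__grid_constant__ const CUtensorMap A_desc, int M) {\n" + "".join(
    f"  tl::tma_load(A_desc, mbar[0], (&(a_shared[{off}])), {off}, var);\n"
    for off in (0, 256, 512, 768)
) + "}"

print(is_single_1d_tma(POST_PATCH), is_single_1d_tma(PRE_PATCH))
```

Counting the substring `tl::tma_load(` with the opening parenthesis avoids matching the name inside comments or other identifiers, at the cost of being sensitive to codegen formatting changes.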

Existing tests checked locally (A100, sm_80):

  • test_tilelang_language_tma_1d.py — 3 passed
  • test_tilelang_language_tma_copy.py — 0 failed (5 skipped, require sm_90)
  • test_tilelang_language_tma_store.py — 0 failed (4 skipped, require sm_90)

Notes

  • Verified at the codegen-string level on an A100 (sm_80) host. End-to-end launch + numerical correctness require sm_90 and are left to CI's H100 runner.
  • The patch is intentionally minimal (one boolean expression). Other 1D-TMA-related cleanups (e.g. legalize_pairwise_extents rank handling, threading layout_map through Classify*) were considered but are not needed to fix this specific issue and are out of scope here.

Summary by CodeRabbit

  • Bug Fixes

    • Improved Tensor Memory Accelerator (TMA) path selection logic for operations with dynamic dimensions, enabling more efficient code generation when layout information is available.
  • Tests

    • Added regression tests for TMA GEMM-like kernels with dynamic dimensions to ensure correct kernel generation.

When the global tensor has a dynamic outer shape and the index cannot
be statically proven within bounds, `buffer_oob` is set to true. The
existing logic in `AnalyzeCopyFacts` then masked out 1D bulk-copy
eligibility along with the descriptor-based 2D path, causing
`InferLayout` to install a swizzle-shaped shared layout via
`InferBulkLayout` and forcing `Copy::Lower` to fall through to the
2D `LowerBulk`. There the `inner_box_dim > 256` branch issues
`(K / 256)` separate `tl::tma_load` calls instead of a single 1D
bulk copy.

The 1D bulk-copy path emits `cp.async.bulk` (not `cp.async.bulk.tensor`)
and has the same OOB semantics as a plain `T.copy()` - it does not
need the descriptor-only OOB gate. Drop `!ctx.buffer_oob` from the
1D eligibility check so dynamic-outer-shape 1D copies that already
satisfy `CheckBulkCopy1D` (contiguous innermost slice, full-extent
trailing dim, element count match) keep the single-issue path.

Repro: see issue tile-ai#2180 - `T.tma_copy(A[var, 0:K], a_shared, barrier=mbar)`
with `M=T.dynamic('M')`, `K=T.const('K')` on `A[M, K]` was lowering to
4 split `tl::tma_load` calls; with this patch it lowers to a single
`tl::tma_load(smem, gmem, mbar, total_bytes)` and no `CUtensorMap`
descriptor. Verified locally on A100 at the codegen-string level
(end-to-end launch requires sm_90+ and is left to CI).

Tests: 3 new regression cases in
`testing/python/language/test_tilelang_language_tma_1d.py` covering
fp32/K=1024, fp32/K=512, fp16/K=1024 with `T.dynamic` outer shape.
All existing TMA tests pass without regression.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai Bot commented May 20, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a334b809-3667-4295-85bf-14b36bf4949b

📥 Commits

Reviewing files that changed from the base of the PR and between 5a22d62 and 73589fc.

📒 Files selected for processing (2)
  • src/backend/cuda/op/copy_analysis.cc
  • testing/python/language/test_tilelang_language_tma_1d.py

📝 Walkthrough

Fixed TMA availability gating in copy_analysis.cc to enable the 1D bulk-copy code path by removing the ctx.buffer_oob requirement, preventing unnecessary descriptor-based TMA splitting. Added four regression tests validating single tl::tma_load emission and absence of CUtensorMap for Issue #2180.

Changes

Issue #2180 TMA 1D load path fix

| Layer / File(s) | Summary |
|---|---|
| TMA availability gating fix in copy_analysis (src/backend/cuda/op/copy_analysis.cc) | Changed layout_dependent_tma_available logic to enable 1D bulk-copy path when layout map exists and target is not CuTeDSL, removing prior !ctx.buffer_oob gating. Added comments explaining OOB handling for descriptor-based 2D TMA versus 1D bulk-copy paths. |
| Regression test helpers and test suite (testing/python/language/test_tilelang_language_tma_1d.py) | Added _lower_issue_2180_kernel to JIT/lower GEMM kernels with dynamic M and tma_copy targeting sm_90a, and _check_single_1d_tma to validate generated code contains exactly one tl::tma_load( and omits CUtensorMap. Implemented four test functions for K={1024,512} across fp32 and fp16, and updated main() to invoke these tests. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • tile-ai/tilelang#2137: Fixes 1D TMA store layout inference in copy.cc and adds matching 1D tma_store regression tests, addressing the same descriptor-vs-bulk-copy lowering issue from a different angle.
  • tile-ai/tilelang#1989: Targets 1D TMA codegen correctness and adds regression tests asserting compiled kernels emit 1D TMA (cp.async.bulk.tensor) and run correctly.
  • tile-ai/tilelang#766: Demonstrates 1D TMA/bulk-copy path usage in an example with the Scale tensor, providing a practical reference for the same TMA code path.

Poem

🐰 A TMA that split into four became one,
No more CUtensorMap, the 1D path is won!
With layout maps checked, the gating grows wise,
Tests guard the fix from future demise. ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and specifically describes the main change: fixing the OOB gate condition for 1D TMA bulk-copy eligibility, directly addressing issue #2180. |
| Linked Issues check | ✅ Passed | The PR directly addresses issue #2180 by removing the !ctx.buffer_oob gate from layout-dependent TMA logic and adds regression test coverage for the fix. |
| Out of Scope Changes check | ✅ Passed | All changes are scoped to fixing the TMA 1D bulk-copy issue: core fix in copy_analysis.cc and regression tests in test_tilelang_language_tma_1d.py. |


@LeiWang1999 LeiWang1999 self-requested a review May 21, 2026 01:41
@mygitljf mygitljf marked this pull request as draft May 21, 2026 16:41

Development

Successfully merging this pull request may close these issues.

[BUG] TMA load unnecessary spliting
