From 433fd67e6faaee487fb4baf4772a2ca8e1557bfb Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 21 Apr 2026 13:21:33 -0700 Subject: [PATCH 1/2] Fix CuTe composition stride-divisibility check (#3177) composition_impl() used a strict weakening of the divisibility condition: it accepted any rhs stride smaller than the current lhs mode shape, regardless of whether the shape was actually divisible by the stride. For A=(4,6,8):(2,3,5), B=6:3, this lets composition(A,B) compile and return (_2,_3):(_6,_3), but C(2)=3 != A(B(2))=7. Replace the weak check with the stronger condition used by pycute (layout.py:211). Fixes #3177 --- include/cute/layout.hpp | 9 ++++++--- test/unit/cute/core/composition.cpp | 31 +++++++++++++++-------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 6b8b102d91..65713a9d15 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -1070,14 +1070,17 @@ composition_impl(LShape const& lhs_shape, [[maybe_unused]] LStride const& lhs_st [[maybe_unused]] auto curr_stride = get(lhs_stride); // Strong divisibility condition -- requires composition to be statically verifiable. - //CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition"); + // Composition C = A o B is well defined for the current mode iff B's stride is either + // (a) a multiple of the mode's shape ((rest_stride % curr_shape) == 0), so the mode is skipped entirely, or + // (b) a divisor of the mode's shape ((curr_shape % rest_stride) == 0), so the mode is partially traversed by an integral number of strides. + //CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or ((curr_shape % rest_stride) == Int<0>{}), "Stride Divisibility Condition"); // Weak divisibility condition -- verify the divisibility condition whenever possible if constexpr (is_static::value and is_static::value) { - CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition"); + CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or ((curr_shape % rest_stride) == Int<0>{}), "Stride Divisibility Condition"); } else { // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure - //assert((((rest_stride % curr_shape) == 0) or (rest_stride < curr_shape)) && "Stride Divisibility Condition"); + //assert((((rest_stride % curr_shape) == 0) or ((curr_shape % rest_stride) == 0)) && "Stride Divisibility Condition"); } // next_shape: ceil(exclusive_prefix_product(lhs_shape) / rhs_stride) diff --git a/test/unit/cute/core/composition.cpp b/test/unit/cute/core/composition.cpp index 7040acf22c..d37338f331 100644 --- a/test/unit/cute/core/composition.cpp +++ b/test/unit/cute/core/composition.cpp @@ -456,21 +456,22 @@ TEST(CuTe_core, Composition) test_composition(a, b); } - { - auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); - auto b = make_layout(_2{}, _3{}); - - test_composition(a, b); - } - - { - auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); - auto b = make_layout(_3{}, _3{}); - - test_composition(a, b); - } - - // Should fail to a static divisibility condition + // Should fail the strong stride-divisibility condition: rhs stride 3 neither + // divides nor is divided by lhs mode-0 shape 8. The previous "weak" check + // (rest_stride < curr_shape) accepted these cases; for these particular + // rhs sizes they happen to produce the right answer because the rhs + // coordinates never reach the mode boundary, but the same algorithm produces + // wrong answers for compositions that do cross it (see NVIDIA/cutlass#3177). + // { + // auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); + // auto b = make_layout(_2{}, _3{}); + // test_composition(a, b); + // } + // { + // auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); + // auto b = make_layout(_3{}, _3{}); + // test_composition(a, b); + // } // { // auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); // auto b = make_layout(_4{}, _3{}); From 2e9ee1ff0078ea51d7c0bed3480a962c5c132d95 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Thu, 30 Apr 2026 21:35:17 -0700 Subject: [PATCH 2/2] =?UTF-8?q?Restore=20=C2=A73.3.3=20truncation=20cases?= =?UTF-8?q?=20in=20the=20divisibility=20check?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The strong divisibility check from the previous commit fixes the wrong-answer composition from #3177, but rejects the paper's §3.3.3 "apparent violation" cases that produce well-defined results, e.g. A = (4,2,8):(3,12,97), B = 3:3 -> 3:9 After the public composition() coalesces A to (8,8):(3,97), the strong check sees `8 % 3 != 0` and refuses to compile, even though A(0)=0, A(3)=9, A(6)=18 is well-defined. Add a third disjunct that accepts the safe-truncation pattern: when B's entire image fits inside the current LHS mode, higher modes are unreachable and cannot perturb the result. This is the §3.3.3 distinction between "apparent" and "real" divisibility violations. Predicate now accepts iff at least one of: (a) (rest_stride % curr_shape) == 0 -- skip mode entirely (b) (curr_shape % rest_stride) == 0 -- partial traversal (c) (rest_shape - 1) * rest_stride < curr_shape -- safe truncation: B's image stays within the current mode Verification matrix: Case Pre-coalesce LHS Decision ---------------------------------- --------------------- -------- paper §3.3.3 ok (returns 3:9) (8,8):(3,97) o 3:3 accept paper §3.3.3 fail-left (8,8):(3,97) o 4:3 reject paper §3.3.3 fail-right (4,2,8):(3,15,97) o 3:3 reject wrong-answer bug #3177 (4,6,8):(2,3,5) o 6:3 reject CuTe test (8,8):(8,1) o 2:3 (8,8):(8,1) o 2:3 accept CuTe test (8,8):(8,1) o 3:3 (8,8):(8,1) o 3:3 accept CuTe test (8,8):(8,1) o 4:3 (8,8):(8,1) o 4:3 reject Reference: arXiv:2603.02298 §3.3.3. --- include/cute/layout.hpp | 22 ++++++++++--- test/unit/cute/core/composition.cpp | 51 +++++++++++++++++++++++------ 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 65713a9d15..841eac0a78 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -1075,12 +1075,26 @@ composition_impl(LShape const& lhs_shape, [[maybe_unused]] LStride const& lhs_st // (b) a divisor of the mode's shape ((curr_shape % rest_stride) == 0), so the mode is partially traversed by an integral number of strides. //CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or ((curr_shape % rest_stride) == Int<0>{}), "Stride Divisibility Condition"); - // Weak divisibility condition -- verify the divisibility condition whenever possible - if constexpr (is_static::value and is_static::value) { - CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or ((curr_shape % rest_stride) == Int<0>{}), "Stride Divisibility Condition"); + // Composition C = A o B is well defined for the current mode iff at least one of: + // (a) rest_stride is a multiple of curr_shape ((rest_stride % curr_shape) == 0): + // the mode is skipped entirely. + // (b) curr_shape is a multiple of rest_stride ((curr_shape % rest_stride) == 0): + // the mode is partially traversed by an integral number of strides. + // (c) every point in B's image stays inside the current mode: + // (rest_shape - 1) * rest_stride < curr_shape. + // Higher modes are unreachable and cannot perturb the result; + // this is the §3.3.3 "apparent violation" pattern resolved by truncation. + if constexpr (is_static::value and + is_static::value and + is_static::value) { + CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or + ((curr_shape % rest_stride) == Int<0>{}) or + (((rest_shape - Int<1>{}) * rest_stride) < curr_shape), + "Stride Divisibility Condition"); } else { // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure - //assert((((rest_stride % curr_shape) == 0) or ((curr_shape % rest_stride) == 0)) && "Stride Divisibility Condition"); + //assert((((rest_stride % curr_shape) == 0) or ((curr_shape % rest_stride) == 0) or + // (((rest_shape - 1) * rest_stride) < curr_shape)) and "Stride Divisibility Condition"); } // next_shape: ceil(exclusive_prefix_product(lhs_shape) / rhs_stride) diff --git a/test/unit/cute/core/composition.cpp b/test/unit/cute/core/composition.cpp index d37338f331..399dcaf116 100644 --- a/test/unit/cute/core/composition.cpp +++ b/test/unit/cute/core/composition.cpp @@ -456,25 +456,56 @@ TEST(CuTe_core, Composition) test_composition(a, b); } - // Should fail the strong stride-divisibility condition: rhs stride 3 neither - // divides nor is divided by lhs mode-0 shape 8. The previous "weak" check - // (rest_stride < curr_shape) accepted these cases; for these particular - // rhs sizes they happen to produce the right answer because the rhs - // coordinates never reach the mode boundary, but the same algorithm produces - // wrong answers for compositions that do cross it (see NVIDIA/cutlass#3177). + // Safe truncation: rhs stride 3 neither divides nor is divided by lhs + // mode-0 shape 8, but the rhs image stays inside that mode so higher modes + // are unreachable. Accepted via the safe-truncation disjunct + // ((rest_shape - 1) * rest_stride < curr_shape). + { + auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); + auto b = make_layout(_2{}, _3{}); + test_composition(a, b); + } + { + auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); + auto b = make_layout(_3{}, _3{}); + test_composition(a, b); + } + + // Paper §3.3.3 "apparent violation": (4,2,8):(3,12,97) coalesces to + // (8,8):(3,97), then o 3:3 truncates safely to 3:9. + // arXiv:2603.02298 §3.3.3. + { + auto a = make_layout(Shape<_4,_2,_8>{}, Stride<_3,_12,Int<97>>{}); + auto b = make_layout(_3{}, _3{}); + test_composition(a, b); + } + + // Should fail the divisibility condition: rhs image crosses the mode + // boundary (every disjunct of the predicate fails). // { // auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); - // auto b = make_layout(_2{}, _3{}); + // auto b = make_layout(_4{}, _3{}); // test_composition(a, b); // } + // Paper §3.3.3 fail-left: same A as above but B=4:3 doesn't truncate + // (after coalesce, (4-1)*3 = 9 >= 8). // { - // auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); + // auto a = make_layout(Shape<_4,_2,_8>{}, Stride<_3,_12,Int<97>>{}); + // auto b = make_layout(_4{}, _3{}); + // test_composition(a, b); + // } + // Paper §3.3.3 fail-right: A's middle stride 15 prevents the coalesce + // that would otherwise expose a safe truncation. + // { + // auto a = make_layout(Shape<_4,_2,_8>{}, Stride<_3,Int<15>,Int<97>>{}); // auto b = make_layout(_3{}, _3{}); // test_composition(a, b); // } + // Wrong-answer case from #3177: the previous weak check accepted this and + // returned (_2,_3):(_6,_3), but C(2) = 3 != A(B(2)) = 7. // { - // auto a = make_layout(Shape<_8,_8>{}, Stride<_8,_1>{}); - // auto b = make_layout(_4{}, _3{}); + // auto a = make_layout(Shape<_4,_6,_8>{}, Stride<_2,_3,_5>{}); + // auto b = make_layout(_6{}, _3{}); // test_composition(a, b); // }