Skip to content

Commit c699120

Browse files
committed
refine code
1 parent f364ace commit c699120

148 files changed

Lines changed: 9403 additions & 4887 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

CMakeLists.txt

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,8 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
2828
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
2929
set(MPS_FILES csrc/mps_ops.mm)
3030
set(METAL_FILES csrc/mps_kernels.metal)
31-
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp csrc/xpu_cutlass.cpp csrc/xpu_cutlass-cute.cpp csrc/xpu_cutlass_fusion.cpp)
31+
#set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp csrc/xpu_cutlass.cpp csrc/xpu_cutlass-cute.cpp csrc/xpu_cutlass_fusion.cpp)
32+
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp csrc/xpu_cutlass_fusion.cpp)
3233
# C++ sources are always included
3334
list(APPEND SRC_FILES ${CPP_FILES})
3435

@@ -312,7 +313,14 @@ if(BUILD_MPS)
312313
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
313314
endif()
314315
if(BUILD_XPU)
315-
set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=intel_gpu_pvc;-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier;-Xs; -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
316+
set(SYCL_FLAGS
317+
-fsycl
318+
--offload-compress
319+
-fsycl-targets=intel_gpu_pvc
320+
-Xspirv-translator -spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate
321+
-Xs
322+
-options "-cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required"
323+
)
316324
set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=intel_gpu_pvc;-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier;")
317325

318326
set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20)

csrc/xpu_cutlass_fusion.cpp

Lines changed: 153 additions & 363 deletions
Large diffs are not rendered by default.

include/cute/algorithm/cooperative_gemm.hpp

Lines changed: 49 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -98,19 +98,23 @@ epilogue_predication(ThrMMA<Args...> const& thr_mma,
9898
}
9999
}
100100

101-
template<class Alpha, class TRC, class RCLayout,
101+
template<class ... Args,
102+
class Alpha, class TRC, class RCLayout,
102103
class Beta, class TSC, class SCLayout,
103104
class CLoadTransformOp, class CStoreTransformOp,
104-
class SmemCopyOpC>
105+
class SmemCopyLdOpC, class SmemCopyStOpC>
105106
CUTE_HOST_DEVICE
106107
void
107-
epilogue_no_predication(Alpha const& alpha,
108+
epilogue_no_predication(uint32_t thread_idx,
109+
ThrMMA<Args...> const& thr_mma,
110+
Alpha const& alpha,
108111
Tensor<TRC, RCLayout> & tCrC,
109112
Beta const& beta,
110-
Tensor<TSC, SCLayout> & tCsC,
113+
Tensor<TSC, SCLayout> & sC,
111114
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
112115
CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C
113-
SmemCopyOpC const& sC_copy_op)
116+
SmemCopyLdOpC const& sC_copy_ld_op,
117+
SmemCopyStOpC const& sC_copy_st_op)
114118
{
115119
using InputTypeC = typename TSC::value_type;
116120
using ComputeTypeC = typename TRC::value_type;
@@ -125,18 +129,33 @@ epilogue_no_predication(Alpha const& alpha,
125129
CUTE_GCC_UNREACHABLE;
126130
} ();
127131

128-
Tensor tCrDi = make_fragment_like(tCsC);
129132
Tensor tCrD = make_fragment_like(tCrC);
133+
Tensor tCrDi = make_fragment_like<InputTypeC>(tCrD);
134+
130135
if(!isBetaZero) {
131-
copy(sC_copy_op, tCsC, tCrDi);
136+
auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom<SmemCopyLdOpC, InputTypeC>{}, thr_mma);
137+
auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx);
138+
Tensor tCsC = smem_thr_copy_C.partition_S(sC);
139+
Tensor tCrDi_copy_view = smem_thr_copy_C.retile_D(tCrDi);
140+
CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M
141+
CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N
142+
copy(smem_tiled_copy_C, tCsC, tCrDi_copy_view);
143+
132144
// Transform C on/after load
133145
cute::transform(tCrDi, tCrD, sC_load_op);
134146
}
135147
// C = alpha * (A * B) + beta * C
136148
axpby(alpha, tCrC, beta, tCrD);
137149
// Transform C before/on store
138150
cute::transform(tCrD, tCrDi, sC_store_op);
139-
copy(sC_copy_op, tCrDi, tCsC);
151+
152+
auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom<SmemCopyStOpC, InputTypeC>{}, thr_mma);
153+
auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx);
154+
Tensor tCsC = smem_thr_copy_C.partition_D(sC);
155+
Tensor tCrDi_copy_view = smem_thr_copy_C.retile_S(tCrDi);
156+
CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M
157+
CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N
158+
copy(smem_tiled_copy_C, tCrDi_copy_view, tCsC);
140159
}
141160

142161
// Predicated Cooperative GEMM
@@ -283,23 +302,23 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
283302

284303
// Create register tensors for the MMA to operate on
285304
Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K)
305+
Tensor tCrAi = make_fragment_like<InputTypeA>(tCrA);
286306
Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K)
307+
Tensor tCrBi = make_fragment_like<InputTypeB>(tCrB);
287308

288309
using CopyOpAType = SmemCopyOpA;
289310
using CopyOpBType = SmemCopyOpB;
290311

291312
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, InputTypeA>{}, thr_mma);
292313
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
293314
Tensor tCsA = smem_thr_copy_A.partition_S(sA);
294-
Tensor tCrAi = make_fragment_like(tCsA);
295315
Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi);
296316
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M
297317
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K
298318

299319
auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, InputTypeB>{}, thr_mma);
300320
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
301321
Tensor tCsB = smem_thr_copy_B.partition_S(sB);
302-
Tensor tCrBi = make_fragment_like(tCsB);
303322
Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi);
304323
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N
305324
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K
@@ -346,7 +365,7 @@ template <class... Args,
346365
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
347366
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
348367
class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
349-
class SmemCopyOpC = DefaultCopy>
368+
class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy>
350369
CUTE_HOST_DEVICE
351370
void
352371
cooperative_gemm(uint32_t thread_idx,
@@ -356,13 +375,14 @@ cooperative_gemm(uint32_t thread_idx,
356375
Tensor<TB, BLayout> const& sB,
357376
Beta const& beta,
358377
Tensor<TC, CLayout> & sC,
359-
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
360-
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
361-
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
362-
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
363-
SmemCopyOpA const& sA_copy_op = {},
364-
SmemCopyOpB const& sB_copy_op = {},
365-
SmemCopyOpC const& sC_copy_op = {})
378+
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
379+
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
380+
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
381+
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
382+
SmemCopyOpA const& sA_copy_op = {},
383+
SmemCopyOpB const& sB_copy_op = {},
384+
SmemCopyLdOpC const& sC_copy_ld_op = {},
385+
SmemCopyStOpC const& sC_copy_st_op = {})
366386
{
367387
CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{});
368388
CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{});
@@ -394,7 +414,7 @@ cooperative_gemm(uint32_t thread_idx,
394414
thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op
395415
);
396416
detail::epilogue_no_predication(
397-
alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op
417+
thread_idx, thr_mma,alpha, tCrC, beta, sC, sC_load_op, sC_store_op, sC_copy_ld_op, sC_copy_st_op
398418
);
399419
} else {
400420
detail::cooperative_gemm_predication(
@@ -466,7 +486,7 @@ template <class... Args,
466486
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
467487
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
468488
class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
469-
class SmemCopyOpC = DefaultCopy>
489+
class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy>
470490
CUTE_HOST_DEVICE
471491
void
472492
cooperative_gemm(uint32_t thread_idx,
@@ -476,17 +496,18 @@ cooperative_gemm(uint32_t thread_idx,
476496
Tensor<TB, BLayout> const& sB,
477497
Beta const& beta,
478498
Tensor<TC, CLayout> && sC,
479-
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
480-
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
481-
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
482-
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
483-
SmemCopyOpA const& sA_copy_op = {},
484-
SmemCopyOpB const& sB_copy_op = {},
485-
SmemCopyOpC const& sC_copy_op = {})
499+
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
500+
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
501+
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
502+
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
503+
SmemCopyOpA const& sA_copy_op = {},
504+
SmemCopyOpB const& sB_copy_op = {},
505+
SmemCopyLdOpC const& sC_copy_ld_op = {},
506+
SmemCopyStOpC const& sC_copy_st_op = {})
486507
{
487508
cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
488509
sA_load_op, sB_load_op, sC_load_op, sC_store_op,
489-
sA_copy_op, sB_copy_op, sC_copy_op);
510+
sA_copy_op, sB_copy_op, sC_copy_ld_op, sC_copy_st_op);
490511
}
491512

492513
// Legacy overload of cute::gemm for backwards-compatibility

include/cute/algorithm/tuple_algorithms.hpp

Lines changed: 4 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
#include <cute/config.hpp>
3434

3535
#include <cute/util/type_traits.hpp>
36+
#include <cute/container/type_list.hpp>
3637
#include <cute/container/tuple.hpp>
3738
#include <cute/algorithm/functional.hpp>
3839
#include <cute/numeric/integer_sequence.hpp>
@@ -283,34 +284,13 @@ transform_leaf(T0 const& t0, T1 const& t1, F&& f)
283284
// find and find_if
284285
//
285286

286-
namespace detail {
287-
288-
template <class T, class F, int I, int... Is>
289-
CUTE_HOST_DEVICE constexpr
290-
auto
291-
find_if(T const& t, F&& f, seq<I,Is...>)
292-
{
293-
if constexpr (decltype(f(get<I>(t)))::value) {
294-
return cute::C<I>{};
295-
} else
296-
if constexpr (sizeof...(Is) == 0) {
297-
return cute::C<I+1>{};
298-
} else {
299-
return find_if(t, f, seq<Is...>{});
300-
}
301-
302-
CUTE_GCC_UNREACHABLE;
303-
}
304-
305-
} // end namespace detail
306-
307287
template <class T, class F>
308288
CUTE_HOST_DEVICE constexpr
309289
auto
310290
find_if(T const& t, F&& f)
311291
{
312292
if constexpr (is_tuple<T>::value) {
313-
return detail::find_if(t, f, tuple_seq<T>{});
293+
return detail::tapply(t, f, [] (auto... a) { return cute::C<find_true_v<decltype(a)::value...>>{}; }, tuple_seq<T>{});
314294
} else {
315295
return cute::C<decltype(f(t))::value ? 0 : 1>{};
316296
}
@@ -332,7 +312,7 @@ auto
332312
any_of(T const& t, F&& f)
333313
{
334314
if constexpr (is_tuple<T>::value) {
335-
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq<T>{});
315+
return detail::tapply(t, f, [] (auto... a) { return (false_type{} || ... || a); }, tuple_seq<T>{});
336316
} else {
337317
return f(t);
338318
}
@@ -346,7 +326,7 @@ auto
346326
all_of(T const& t, F&& f)
347327
{
348328
if constexpr (is_tuple<T>::value) {
349-
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq<T>{});
329+
return detail::tapply(t, f, [] (auto... a) { return (true_type{} && ... && a); }, tuple_seq<T>{});
350330
} else {
351331
return f(t);
352332
}

include/cute/arch/cluster_sm90.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
#pragma once
3232

3333
#include <cute/config.hpp>
34+
#include <cute/numeric/numeric_types.hpp>
3435

3536
// Config
3637
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \

include/cute/arch/config.hpp

Lines changed: 60 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,27 @@
7272
# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED
7373
#endif
7474

75+
#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED))
76+
# define CUTE_ARCH_TMA_SM90_ENABLED
77+
# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED
78+
# define CUTE_ARCH_STSM_SM90_ENABLED
79+
# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED
80+
# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED
81+
# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED
82+
# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED
83+
# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED
84+
#endif
85+
86+
#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)
87+
# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED
88+
#endif
89+
90+
#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED))
91+
# define CUTE_ARCH_TMA_SM90_ENABLED
92+
# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED
93+
# define CUTE_ARCH_STSM_SM90_ENABLED
94+
#endif
95+
7596
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED))
7697
# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED
7798
#endif
@@ -91,8 +112,11 @@
91112
#endif
92113

93114
// {add, mul, fma}.f32x2 PTX
94-
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED))
95-
#define CUTE_ARCH_FLOAT2_MATH_ENABLED
115+
#if defined(CUTLASS_ARCH_MMA_SM100_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
116+
// Enable CuTe MMA Atoms
117+
# define CUTE_ARCH_FFMA2_SM100_ENABLED
118+
// Enable f32x2 PTX generation
119+
# define CUTE_ARCH_FLOAT2_MATH_ENABLED
96120
#endif
97121

98122
#if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)
@@ -109,3 +133,37 @@
109133
# endif
110134
#endif
111135

136+
#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)
137+
# define CUTE_ARCH_LDSM_SM100A_ENABLED
138+
# define CUTE_ARCH_STSM_SM100A_ENABLED
139+
# define CUTE_ARCH_TCGEN05_TMEM_ENABLED
140+
# define CUTE_ARCH_TMA_SM100_ENABLED
141+
# define CUTE_ARCH_FLOAT2_MATH_ENABLED
142+
#endif
143+
144+
#if defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)
145+
# define CUTE_ARCH_LDSM_SM100A_ENABLED
146+
# define CUTE_ARCH_STSM_SM100A_ENABLED
147+
# define CUTE_ARCH_TCGEN05_TMEM_ENABLED
148+
# define CUTE_ARCH_TMA_SM100_ENABLED
149+
#endif
150+
151+
#if defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)
152+
# define CUTE_ARCH_LDSM_SM100A_ENABLED
153+
# define CUTE_ARCH_STSM_SM100A_ENABLED
154+
#endif
155+
156+
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\
157+
defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\
158+
defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED))
159+
# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9))
160+
# define CUTE_ARCH_LOAD256_SM100A_ENABLED
161+
# define CUTE_ARCH_STORE256_SM100A_ENABLED
162+
# endif
163+
#endif
164+
165+
// {add, mul, fma}.f32x2 PTX
166+
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)
167+
#define CUTE_ARCH_FLOAT2_MATH_ENABLED
168+
#endif
169+

0 commit comments

Comments
 (0)