Skip to content

Commit c81f03b

Browse files
Merge branch 'develop' into users/jlichtne/ALMIOPEN-1994-bump-rock-hash
2 parents 944ee8b + 308af93 commit c81f03b

91 files changed

Lines changed: 3679 additions & 527 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

projects/composablekernel/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,10 @@ SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
741741
# Developer builds: promote every warning to an error and enable all Clang
# diagnostics, then switch off the lifetime-safety diagnostics again.
# -Wno-unknown-warning-option is passed as well, which keeps compilers that do
# not recognise one of the -Wno-lifetime-safety-* names from reporting it.
if(BUILD_DEV)
    add_compile_options(
        -Werror
        -Weverything
        -Wno-lifetime-safety-intra-tu-suggestions
        -Wno-lifetime-safety-cross-tu-suggestions
        -Wno-lifetime-safety-lifetimebound-violation
        -Wno-unknown-warning-option
    )
endif()
745749
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
746750

projects/composablekernel/cmake/EnableCompilerWarnings.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ else()
5050
-Wsign-compare
5151
-Wno-extra-semi-stmt
5252
-Wno-unused-template
53+
-Wno-lifetime-safety-intra-tu-suggestions
54+
-Wno-lifetime-safety-cross-tu-suggestions
55+
-Wno-lifetime-safety-lifetimebound-violation
56+
-Wno-unknown-warning-option
5357
)
5458
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
5559
list(APPEND CMAKE_COMPILER_WARNINGS
@@ -76,6 +80,10 @@ else()
7680
-Wno-unsafe-buffer-usage
7781
-Wno-unused-lambda-capture
7882
-Wno-nvcc-compat
83+
-Wno-lifetime-safety-intra-tu-suggestions
84+
-Wno-lifetime-safety-cross-tu-suggestions
85+
-Wno-lifetime-safety-lifetimebound-violation
86+
-Wno-unknown-warning-option
7987
)
8088
if(CK_CXX_STANDARD GREATER_EQUAL 20)
8189
list(APPEND CMAKE_COMPILER_WARNINGS -Wno-c++20-compat)

projects/composablekernel/dispatcher/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ endif()
5959
# Compiler warnings
6060
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
6161
target_compile_options(ck_tile_dispatcher PRIVATE
62-
-Wall -Wextra -Wpedantic
62+
-Wall -Wextra -Wpedantic -Wno-lifetime-safety-intra-tu-suggestions -Wno-lifetime-safety-cross-tu-suggestions -Wno-lifetime-safety-lifetimebound-violation -Wno-unknown-warning-option
6363
)
6464
elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
6565
target_compile_options(ck_tile_dispatcher PRIVATE

projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
# SPDX-License-Identifier: MIT
33

44
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
5-
# Currently only gfx9 and gfx12 archs are supported by FMHA
6-
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
5+
# Currently only gfx9, gfx11, and gfx12 archs are supported by FMHA
6+
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx1[12]")
77
if(NOT INST_TARGETS)
88
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx11, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
99
return()

projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@
2828
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
2929
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
3030
// auto generated by generate.py
31+
#if defined(__HIP_DEVICE_COMPILE__) && \\
32+
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \\
33+
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \\
34+
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__))
35+
#undef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
36+
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
37+
#endif
3138
#include "fmha_bwd.hpp"
3239
3340
"""

projects/composablekernel/include/ck_tile/core/config.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,15 @@
173173
#endif
174174

175175
// Floating-point buffer atomic add.
// A definition supplied before this point is kept as-is. Otherwise the macro is
// 1 for host compilation and for gfx9 / gfx12 device compilation, and 0 for any
// other device target.
#if !defined(CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
#if !defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__) || defined(__gfx12__)
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
#endif
183185

184186
#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
185187
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1

projects/composablekernel/include/ck_tile/core/numeric/mxfp_scale.hpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,101 @@ struct Packed4Scale
103103
}
104104
};
105105

106+
template <typename ScaleType>
107+
struct Packed8Scale
108+
{
109+
using scale_type = ScaleType;
110+
using raw_type = uint64_t;
111+
using raw_scale_type = typename ScaleType::raw_type;
112+
113+
static constexpr int num_pack = 8;
114+
union
115+
{
116+
raw_type data_;
117+
raw_scale_type scales_[num_pack]; // Direct byte/element access
118+
};
119+
120+
// Constructors
121+
CK_TILE_HOST_DEVICE constexpr Packed8Scale() = default;
122+
CK_TILE_HOST_DEVICE constexpr Packed8Scale(raw_type val) : data_(val) {}
123+
CK_TILE_HOST_DEVICE constexpr Packed8Scale(
124+
float s0, float s1, float s2, float s3, float s4, float s5, float s6, float s7)
125+
{
126+
set_scales_from_float(s0, s1, s2, s3, s4, s5, s6, s7);
127+
}
128+
129+
CK_TILE_HOST_DEVICE constexpr Packed8Scale(ScaleType s0,
130+
ScaleType s1,
131+
ScaleType s2,
132+
ScaleType s3,
133+
ScaleType s4,
134+
ScaleType s5,
135+
ScaleType s6,
136+
ScaleType s7)
137+
{
138+
set_scales(s0, s1, s2, s3, s4, s5, s6, s7);
139+
}
140+
141+
CK_TILE_HOST_DEVICE constexpr void set_scales_from_float(
142+
float s0, float s1, float s2, float s3, float s4, float s5, float s6, float s7)
143+
{
144+
set_scales(ScaleType(s0),
145+
ScaleType(s1),
146+
ScaleType(s2),
147+
ScaleType(s3),
148+
ScaleType(s4),
149+
ScaleType(s5),
150+
ScaleType(s6),
151+
ScaleType(s7));
152+
}
153+
154+
CK_TILE_HOST_DEVICE constexpr void set_scales(ScaleType s0,
155+
ScaleType s1,
156+
ScaleType s2,
157+
ScaleType s3,
158+
ScaleType s4,
159+
ScaleType s5,
160+
ScaleType s6,
161+
ScaleType s7)
162+
{
163+
data_ = 0;
164+
pack_scale(s0, 7);
165+
pack_scale(s1, 6);
166+
pack_scale(s2, 5);
167+
pack_scale(s3, 4);
168+
pack_scale(s4, 3);
169+
pack_scale(s5, 2);
170+
pack_scale(s6, 1);
171+
pack_scale(s7, 0);
172+
}
173+
174+
CK_TILE_HOST_DEVICE constexpr operator raw_type() const { return data_; }
175+
CK_TILE_HOST_DEVICE constexpr raw_type& data() [[clang::lifetimebound]] { return data_; }
176+
CK_TILE_HOST_DEVICE constexpr raw_type data() const { return data_; }
177+
178+
CK_TILE_HOST_DEVICE constexpr float unpack_to_float(int i) const
179+
{
180+
return static_cast<float>(unpack_scale(i));
181+
}
182+
183+
CK_TILE_HOST_DEVICE constexpr ScaleType unpack_scale(int i) const
184+
{
185+
return ScaleType(scales_[i]);
186+
}
187+
188+
CK_TILE_HOST_DEVICE constexpr void pack_from_float(float scale, int i)
189+
{
190+
pack_scale(ScaleType(scale), i);
191+
}
192+
193+
CK_TILE_HOST_DEVICE constexpr void pack_scale(ScaleType scale, int i)
194+
{
195+
scales_[i] = scale.get();
196+
}
197+
};
198+
106199
// Type alias for e8m0_t scales
107200
using Packed4Scale_E8M0 = Packed4Scale<e8m0_t>;
201+
using Packed8Scale_E8M0 = Packed8Scale<e8m0_t>;
108202

109203
} // namespace ck_tile

projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ struct StreamKKernel
363363

364364
// Determine the total size along the K dimension the workgroup is using in this
365365
// iteration (used to construct tensor views).
366-
index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
366+
index_t k_size = amd_wave_read_first_lane(
367+
kargs.tile_partitioner.get_k_size(num_loop_sk, local_iter_end));
367368

368369
// Get the K offsets for the A and B tensors
369370
auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(

projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,19 @@ struct StreamKTilePartitionerBase
139139
CK_TILE_DEVICE auto
140140
get_output_tile_index(index_t tile_idx) const noexcept -> tuple<index_t, index_t>;
141141

142+
/**
143+
* @brief Calculates the total size along the K dimension the workgroup is using in this
144+
* Stream-K loop iteration
145+
*
146+
* @param num_macro_tiles The number of macro tiles along the K dimension this workgroup is
147+
* assigned.
148+
* @param local_iter_end The workgroup's non-inclusive end iteration that is local to its
149+
* current tile.
150+
* @return index_t The K dimension size for the current Stream-K loop iteration.
151+
*/
152+
CK_TILE_DEVICE index_t get_k_size(index_t num_macro_tiles,
153+
index_t local_iter_end) const noexcept;
154+
142155
/**
143156
* @brief Calculates the total space needed for the partials and flags buffers.
144157
*
@@ -208,6 +221,17 @@ struct StreamKTilePartitionerBase
208221
*/
209222
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
210223

224+
/**
225+
* @brief Returns the k dimension for the GEMM problem.
226+
*/
227+
CK_TILE_HOST_DEVICE index_t get_k() const noexcept;
228+
229+
/**
230+
* @brief Returns the remainder along the k dimension when k is not evenly divisible by
231+
* KPerBlock.
232+
*/
233+
CK_TILE_HOST_DEVICE index_t get_remainder_along_k() const noexcept;
234+
211235
/**
212236
* @brief Returns an estimate of the number of workgroups writing to the same macro tile in C.
213237
*/
@@ -244,6 +268,8 @@ struct StreamKTilePartitionerBase
244268
index_t extra_iters_;
245269
index_t total_dp_iters_;
246270
index_t n_;
271+
index_t k_;
272+
index_t remainder_along_k_;
247273
};
248274

249275
/**

projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ namespace ck_tile {
88
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
99
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
1010
index_t m, index_t n, index_t k, index_t max_active_wgs)
11-
: max_active_wgs_{max_active_wgs}, n_{n}
11+
: max_active_wgs_{max_active_wgs}, n_{n}, k_{k}
1212
{
13-
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
14-
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);
13+
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
14+
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);
15+
remainder_along_k_ = k % KPerBlock;
1516

1617
bool big_enough = num_tiles_ > max_active_wgs_;
1718
index_t remainder_tiles = num_tiles_ % max_active_wgs_;
@@ -250,6 +251,21 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() c
250251
return n_;
251252
}
252253

254+
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
255+
CK_TILE_HOST_DEVICE index_t
256+
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_k() const noexcept
257+
{
258+
return k_;
259+
}
260+
261+
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
262+
CK_TILE_HOST_DEVICE index_t
263+
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_remainder_along_k()
264+
const noexcept
265+
{
266+
return remainder_along_k_;
267+
}
268+
253269
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
254270
CK_TILE_HOST index_t
255271
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
@@ -334,6 +350,29 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::remap_xcd
334350
return block_1d_id;
335351
}
336352

353+
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
354+
CK_TILE_DEVICE index_t
355+
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_k_size(
356+
index_t num_macro_tiles, index_t local_iter_end) const noexcept
357+
{
358+
// Determine if this workgroup is responsible for the last macro tile in the K dimension
359+
bool last_tile = get_iters_per_tile() == local_iter_end;
360+
index_t k_size;
361+
// If there is no remainder or if the workgroup was not assigned the last macro tile along K,
362+
// then their k_size will be a multiple of KPerBlock.
363+
if(!remainder_along_k_ || !last_tile)
364+
{
365+
k_size = num_macro_tiles * KPerBlock;
366+
}
367+
// Otherwise, there's a remainder. So, k_size is not a multiple of KPerBlock.
368+
else
369+
{
370+
k_size = (num_macro_tiles - 1) * KPerBlock + remainder_along_k_;
371+
}
372+
373+
return k_size;
374+
}
375+
337376
template <typename BlockGemmShapeType,
338377
StreamKReductionStrategy ReductionStrategyType,
339378
bool Persistent>

0 commit comments

Comments
 (0)