|
| 1 | +/************************************************************************* |
| 2 | + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | + * |
| 4 | + * See LICENSE for license information. |
| 5 | + ************************************************************************/ |
| 6 | + |
| 7 | +#include <transformer_engine/recipe.h> |
| 8 | + |
| 9 | +#include <algorithm> |
| 10 | +#include <vector> |
| 11 | + |
| 12 | +#include "../common.h" |
| 13 | +#include "../util/logging.h" |
| 14 | +#include "../util/vectorized_pointwise.h" |
| 15 | +#include "recipe_common.cuh" |
| 16 | + |
| 17 | +namespace transformer_engine { |
| 18 | +namespace { |
| 19 | + |
// Threads per block for the kernels in this file; also used as their
// __launch_bounds__ value.
constexpr int multi_amax_kernel_threads = 512;

// Per-launch capacity. kMaxTensorsPerBatch * ~40 bytes per slot keeps the args
// struct within the 4KB kernel parameter limit with comfortable headroom.
constexpr int kMaxTensorsPerBatch = 64;

// Kernel argument block describing one batch of tensors. It is passed to the
// kernels BY VALUE, so its size counts against the kernel parameter space.
// Only the first `num_tensors` slots of every array are meaningful.
struct MultiAmaxArgs {
  // Input data pointer of each tensor (element type is chosen by the kernel's
  // template parameter).
  const void *input_list[kMaxTensorsPerBatch];
  // Rowwise amax destination of each tensor; the kernels treat it as float*
  // and skip it when it is null.
  void *output_rowwise_amax_list[kMaxTensorsPerBatch];
  // Columnwise amax destination of each tensor; skipped when null or when it
  // is the same address as the rowwise slot.
  void *output_columnwise_amax_list[kMaxTensorsPerBatch];
  // Number of elements in each input tensor.
  size_t input_numel[kMaxTensorsPerBatch];
  // Number of vector loads needed to cover each input tensor.
  size_t num_aligned_elements[kMaxTensorsPerBatch];
  // Number of valid slots in the arrays above.
  int num_tensors;
};

// Enforce the sizing assumption stated above at compile time: the kernels take
// this struct plus one `const float *` as parameters.
static_assert(sizeof(MultiAmaxArgs) + sizeof(const float *) <= 4096,
              "MultiAmaxArgs no longer fits in the 4KB kernel parameter space; "
              "reduce kMaxTensorsPerBatch");
| 33 | + |
// Resets every output amax slot referenced by `args` (rowwise and columnwise)
// to 0.0f in a single launch. One thread handles one tensor slot at a time and
// strides by the total number of launched threads.
//
// When `noop_ptr` is non-null and its first element equals 1.0f the kernel
// does nothing, matching the noop_ptr contract of the single-tensor amax path.
__launch_bounds__(multi_amax_kernel_threads) __global__
    void MultiZeroAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) {
  const bool skip_all = (noop_ptr != nullptr) && (noop_ptr[0] == 1.0f);
  if (skip_all) {
    return;
  }

  const int first_slot = blockIdx.x * blockDim.x + threadIdx.x;
  const int slot_step = blockDim.x * gridDim.x;
  for (int slot = first_slot; slot < args.num_tensors; slot += slot_step) {
    float *const rowwise = static_cast<float *>(args.output_rowwise_amax_list[slot]);
    float *const columnwise = static_cast<float *>(args.output_columnwise_amax_list[slot]);

    if (rowwise != nullptr) {
      *rowwise = 0.0f;
    }
    // Nothing more to do when there is no columnwise slot, or when it is the
    // very same address that was just cleared through `rowwise`.
    if (columnwise == nullptr || columnwise == rowwise) {
      continue;
    }
    *columnwise = 0.0f;
  }
}
| 54 | + |
// Absolute-maximum reduction over a batch of tensors.
//
// Launch layout: blockIdx.y selects the tensor slot in `args`; blockIdx.x and
// threadIdx.x select vector-load indices within that tensor, advanced by
// gridDim.x * blockDim.x per iteration. Every block folds its partial maximum
// into BOTH output slots of its tensor (rowwise + columnwise, written once if
// they alias) with atomicMaxFloat, so no separate copy between the two slots
// is needed afterwards.
//
// Template parameters:
//   nvec      - elements per vector load.
//   aligned   - forwarded to VectorizedLoader.
//   InputType - element type of the inputs.
//
// When `noop_ptr` is non-null and its first element equals 1.0f the kernel
// does nothing.
template <int nvec, bool aligned, typename InputType>
__launch_bounds__(multi_amax_kernel_threads) __global__
    void MultiAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) {
  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;
  }

  // All exits below depend only on block-uniform values, so a block either
  // leaves entirely or reaches the block reduction with every thread.
  const int tensor_id = blockIdx.y;
  if (tensor_id >= args.num_tensors) {
    return;
  }
  const size_t numel = args.input_numel[tensor_id];
  if (numel == 0) {
    return;
  }

  const size_t num_vectors = args.num_aligned_elements[tensor_id];
  const InputType *data = static_cast<const InputType *>(args.input_list[tensor_id]);
  VectorizedLoader<InputType, nvec, aligned> loader(data, numel);

  InputType running_max = InputType{0.f};
  const size_t first_vec = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t vec_step = gridDim.x * blockDim.x;
  for (size_t vec = first_vec; vec < num_vectors; vec += vec_step) {
    loader.load(vec, numel);
#pragma unroll
    for (int lane = 0; lane < nvec; ++lane) {
      const InputType elem = static_cast<InputType>(loader.separate()[lane]);
      // The accumulator starts at zero and only ever holds absolute values.
      __builtin_assume(running_max >= InputType{0.f});
      if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
#if __CUDA_ARCH__ >= 800
        running_max = __hmax(__habs(elem), running_max);
#else
        // Below SM80 the bf16 max is computed through float.
        running_max = static_cast<__nv_bfloat16>(
            fmaxf(fabsf(static_cast<float>(elem)), static_cast<float>(running_max)));
#endif
      } else if constexpr (std::is_same_v<InputType, __half>) {
        running_max = __hmax(__habs(elem), running_max);
      } else {
        running_max = fmaxf(fabsf(elem), running_max);
      }
    }
  }

  // Reduce amax over block.
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  running_max = reduce_max<multi_amax_kernel_threads / THREADS_PER_WARP>(running_max, warp_id);

  // Only thread 0 publishes the block's result.
  if (threadIdx.x != 0) {
    return;
  }
  const float block_amax = static_cast<float>(running_max);
  float *const rowwise = static_cast<float *>(args.output_rowwise_amax_list[tensor_id]);
  float *const columnwise = static_cast<float *>(args.output_columnwise_amax_list[tensor_id]);
  if (rowwise != nullptr) {
    atomicMaxFloat(rowwise, block_amax);
  }
  if (columnwise != nullptr && columnwise != rowwise) {
    atomicMaxFloat(columnwise, block_amax);
  }
}
| 118 | + |
// Launches the two kernels for one batch on `stream`: first MultiZeroAmaxKernel
// to clear every output amax slot, then MultiAmaxKernel to accumulate the
// per-tensor absolute maxima.
//
//   args      - batch description; passed to the kernels by value.
//   max_numel - largest element count among the batch's inputs; only used to
//               size the x dimension of the reduction grid. When it is 0 the
//               outputs are still zeroed but no reduction is launched.
//   align     - batch-wide alignment class selecting the kernel variant.
//   noop_ptr  - optional device flag forwarded to both kernels.
//   stream    - CUDA stream used for both launches.
template <typename InputType>
void launch_multi_amax_batch(const MultiAmaxArgs &args, size_t max_numel, Alignment align,
                             const float *noop_ptr, cudaStream_t stream) {
  // Zero all amax outputs in one launch.
  {
    constexpr int threads = multi_amax_kernel_threads;
    const int num_blocks = std::max(1, DIVUP(args.num_tensors, threads));
    MultiZeroAmaxKernel<<<num_blocks, threads, 0, stream>>>(args, noop_ptr);
    NVTE_CHECK_CUDA(cudaGetLastError());
  }

  if (max_numel == 0) {
    return;
  }

  // Grid: y = tensor index, x = work chunks sized for the largest tensor and
  // capped at 65535 blocks. The kernel strides over each tensor's vectors by
  // gridDim.x * blockDim.x, so the cap does not lose coverage. Blocks whose
  // starting index lies beyond a shorter tensor's vector count skip the load
  // loop and merely fold a zero into that tensor's (already zeroed) amax.
  constexpr int nvec = 32 / sizeof(InputType);
  constexpr size_t threads = multi_amax_kernel_threads;
  const size_t max_aligned = (max_numel + nvec - 1) / nvec;
  size_t num_blocks_x = DIVUP(max_aligned, threads);
  constexpr size_t max_blocks = 65535;
  num_blocks_x = std::min(num_blocks_x, max_blocks);
  num_blocks_x = std::max<size_t>(num_blocks_x, 1);
  dim3 grid(num_blocks_x, static_cast<unsigned int>(args.num_tensors), 1);

  switch (align) {
    case Alignment::SAME_ALIGNED:
      MultiAmaxKernel<nvec, true, InputType><<<grid, threads, 0, stream>>>(args, noop_ptr);
      break;
    case Alignment::SAME_UNALIGNED:
    case Alignment::DIFFERENT:
      // args.num_aligned_elements was computed for this `nvec`, so every
      // variant launched here must use the same nvec. A <1, true> fallback
      // would only visit about 1/nvec of each tensor. The unaligned variant is
      // already the one used for batches that mix aligned and unaligned
      // tensors, so it also serves as the conservative choice for DIFFERENT.
      MultiAmaxKernel<nvec, false, InputType><<<grid, threads, 0, stream>>>(args, noop_ptr);
      break;
  }
  NVTE_CHECK_CUDA(cudaGetLastError());
}
| 161 | + |
// Fills `args` with the slice [start, start + count) of `inputs` / `outputs`.
//
//   inputs, outputs - full tensor lists; entry k of one pairs with entry k of
//                     the other.
//   start, count    - slice to pack; count must not exceed
//                     kMaxTensorsPerBatch and the slice must lie inside both
//                     lists.
//   args            - destination; slots [0, count) and num_tensors are
//                     written, other slots are left untouched.
//
// Returns (largest element count in the slice, alignment class for the
// slice). The alignment class only ever degrades:
// SAME_ALIGNED -> SAME_UNALIGNED -> DIFFERENT.
template <typename InputType>
std::pair<size_t, Alignment> build_batch_args(const std::vector<Tensor *> &inputs,
                                              const std::vector<Tensor *> &outputs, size_t start,
                                              size_t count, MultiAmaxArgs &args) {
  // The arrays inside MultiAmaxArgs have a fixed capacity; refuse to write
  // past them or to read past the end of the tensor lists.
  NVTE_CHECK(count <= static_cast<size_t>(kMaxTensorsPerBatch),
             "build_batch_args: count=", count, " exceeds the per-launch capacity of ",
             kMaxTensorsPerBatch);
  NVTE_CHECK(start <= inputs.size() && count <= inputs.size() - start,
             "build_batch_args: slice [", start, ", ", start + count,
             ") is out of range for inputs of size ", inputs.size());
  NVTE_CHECK(start <= outputs.size() && count <= outputs.size() - start,
             "build_batch_args: slice [", start, ", ", start + count,
             ") is out of range for outputs of size ", outputs.size());

  constexpr int nvec = 32 / sizeof(InputType);
  size_t max_numel = 0;
  // Start from the most optimistic class and degrade as tensors are folded in.
  Alignment batch_align = Alignment::SAME_ALIGNED;
  for (size_t i = 0; i < count; ++i) {
    const Tensor &inp = *inputs[start + i];
    Tensor &out = *outputs[start + i];
    const size_t N = inp.data.numel();

    args.input_list[i] = inp.data.dptr;
    args.output_rowwise_amax_list[i] = out.amax.dptr;
    args.output_columnwise_amax_list[i] = out.columnwise_amax.dptr;
    args.input_numel[i] = N;
    args.num_aligned_elements[i] =
        get_num_aligned_elements(inp.data.dptr, N, nvec, sizeof(InputType));
    max_numel = std::max(max_numel, N);

    // Empty tensors do not influence the alignment decision.
    if (N == 0) {
      continue;
    }
    // A batch that contains any SAME_UNALIGNED tensor (whether or not others
    // are aligned) is classified SAME_UNALIGNED. DIFFERENT is only propagated
    // if CheckAlignment itself reports it, and is never downgraded afterwards.
    const Alignment a = CheckAlignment(N, nvec, static_cast<const InputType *>(inp.data.dptr));
    if (a == Alignment::DIFFERENT) {
      batch_align = Alignment::DIFFERENT;
    } else if (a == Alignment::SAME_UNALIGNED && batch_align == Alignment::SAME_ALIGNED) {
      batch_align = Alignment::SAME_UNALIGNED;
    }
  }
  args.num_tensors = static_cast<int>(count);
  return {max_numel, batch_align};
}
| 206 | + |
// Validates the tensor lists and computes the amax of every input into the
// amax buffer(s) of the matching output, batching up to kMaxTensorsPerBatch
// tensors per pair of kernel launches.
//
//   inputs_     - array of `num_tensors` unquantized input tensors that all
//                 share one dtype.
//   outputs_    - array of `num_tensors` output tensors; each needs at least
//                 one of amax / columnwise_amax allocated.
//   num_tensors - number of entries in both arrays; 0 is a no-op.
//   config_     - optional quantization config; only its noop_tensor is read.
//   stream      - CUDA stream for all launches.
//
// Any violated precondition is reported through NVTE_CHECK.
void multi_compute_amax_impl(const NVTETensor *inputs_, NVTETensor *outputs_, size_t num_tensors,
                             const NVTEQuantizationConfig config_, cudaStream_t stream) {
  if (num_tensors == 0) {
    return;
  }
  NVTE_CHECK(inputs_ != nullptr, "nvte_multi_compute_amax: inputs is NULL");
  NVTE_CHECK(outputs_ != nullptr, "nvte_multi_compute_amax: outputs is NULL");

  // Every input has to match the dtype of the first one. Reading it up front
  // keeps the variable initialized on every path.
  const DType input_dtype = convertNVTETensorCheck(inputs_[0])->data.dtype;

  // Convert, validate, collect into plain vectors.
  std::vector<Tensor *> inputs(num_tensors);
  std::vector<Tensor *> outputs(num_tensors);
  for (size_t i = 0; i < num_tensors; ++i) {
    inputs[i] = convertNVTETensorCheck(inputs_[i]);
    outputs[i] = convertNVTETensorCheck(outputs_[i]);
    const auto &inp = *inputs[i];
    auto &out = *outputs[i];
    NVTE_CHECK(inp.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "nvte_multi_compute_amax: input[",
               i, "] must be unquantized, got scaling_mode=", to_string(inp.scaling_mode));
    NVTE_CHECK(!is_fp8_dtype(inp.data.dtype), "nvte_multi_compute_amax: input[", i,
               "] must be unquantized, got dtype=", to_string(inp.data.dtype));
    NVTE_CHECK(inp.data.dtype == input_dtype,
               "nvte_multi_compute_amax: all inputs must share dtype; input[0]=",
               to_string(input_dtype), ", input[", i, "]=", to_string(inp.data.dtype));
    // The kernel dereferences the data pointer of every non-empty input, so a
    // missing buffer is rejected here rather than faulting on the device.
    NVTE_CHECK(inp.data.numel() == 0 || inp.data.dptr != nullptr,
               "nvte_multi_compute_amax: input[", i, "] has no data buffer");
    NVTE_CHECK(out.scaling_mode == NVTE_DELAYED_TENSOR_SCALING ||
                   out.scaling_mode == NVTE_NVFP4_1D_SCALING,
               "nvte_multi_compute_amax: output[", i, "] must be FP8 per-tensor or NVFP4 1D");
    NVTE_CHECK(out.amax.dptr != nullptr || out.columnwise_amax.dptr != nullptr,
               "nvte_multi_compute_amax: output[", i, "] has no amax buffer");
  }

  // Optional no-op flag taken from the quantization config; stays null when
  // there is no config or the config carries no noop tensor.
  const float *noop_ptr = nullptr;
  if (config_ != nullptr) {
    const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
    const NVTETensor noop = config_cpp->noop_tensor;
    if (noop != nullptr) {
      noop_ptr = static_cast<const float *>(convertNVTETensorCheck(noop)->data.dptr);
    }
  }

  // Chunk across kMaxTensorsPerBatch launches (single launch in the common 8-expert case).
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input_dtype, IType, {
    for (size_t start = 0; start < num_tensors; start += kMaxTensorsPerBatch) {
      const size_t count = std::min<size_t>(kMaxTensorsPerBatch, num_tensors - start);
      MultiAmaxArgs args = {};
      auto [max_numel, batch_align] = build_batch_args<IType>(inputs, outputs, start, count, args);
      launch_multi_amax_batch<IType>(args, max_numel, batch_align, noop_ptr, stream);
    }
  });  // NOLINT(*)
}
| 260 | + |
| 261 | +} // anonymous namespace |
| 262 | +} // namespace transformer_engine |
| 263 | + |
// C API entry point: computes the amax of each of the `num_tensors` tensors in
// `inputs` into the amax buffer(s) of the matching entry of `outputs`, using
// `stream` for all device work. `config` is optional and is forwarded
// unchanged. All validation and the actual work live in
// transformer_engine::multi_compute_amax_impl.
void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors,
                             const NVTEQuantizationConfig config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_compute_amax);
  transformer_engine::multi_compute_amax_impl(
      /*inputs_=*/inputs, /*outputs_=*/outputs, /*num_tensors=*/num_tensors,
      /*config_=*/config, /*stream=*/stream);
}
0 commit comments