Skip to content

Commit fcbd964

Browse files
fanshiqing and jiemingz committed
Generalized Tensor Parallelism (GTP) init commit
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com> Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
1 parent 76c2a9e commit fcbd964

13 files changed

Lines changed: 1144 additions & 179 deletions

File tree

transformer_engine/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ list(APPEND transformer_engine_cuda_sources
207207
recipe/current_scaling.cu
208208
recipe/delayed_scaling.cu
209209
recipe/fp8_block_scaling.cu
210+
recipe/multi_amax.cu
210211
comm_gemm_overlap/userbuffers/userbuffers.cu)
211212

212213
list(APPEND transformer_engine_cuda_arch_specific_sources

transformer_engine/common/include/transformer_engine/recipe.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,26 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
9999
void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
100100
const NVTEQuantizationConfig config, cudaStream_t stream);
101101

102+
/*! \brief Compute amax for a list of independent tensors in a single kernel launch.
103+
*
104+
* Unlike nvte_group_amax (which requires a single contiguous input split along dim 0),
105+
* this API accepts arrays of independent input tensors, each with its own allocation.
106+
* Designed for the GTP grouped-experts case where per-expert weights live in separate
107+
* buffers. For each i in [0, num_tensors), computes amax(inputs[i]) and writes it to
108+
* outputs[i]'s amax buffer. outputs[i] must be an FP8 per-tensor scaling or NVFP4 1D
109+
* scaling tensor. All inputs must share the same dtype. If the list exceeds the
110+
* per-launch batch capacity, it is internally chunked.
111+
*
112+
* \param[in] inputs Array of input tensors (unquantized). Size num_tensors.
113+
* \param[in,out] outputs Array of output tensors. Only the amax is updated.
114+
* Size num_tensors.
115+
* \param[in] num_tensors Number of tensors.
116+
* \param[in] config Quantization configuration (for noop_tensor). May be NULL.
117+
* \param[in] stream CUDA stream used for the operation.
118+
*/
119+
void nvte_multi_compute_amax(const NVTETensor* inputs, NVTETensor* outputs, size_t num_tensors,
120+
const NVTEQuantizationConfig config, cudaStream_t stream);
121+
102122
/*! \brief Update an FP8 tensor's scale based on its amax.
103123
*
104124
* This is only supported for FP8 tensors with per-tensor scaling.
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
/*************************************************************************
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
*
4+
* See LICENSE for license information.
5+
************************************************************************/
6+
7+
#include <transformer_engine/recipe.h>
8+
9+
#include <algorithm>
10+
#include <vector>
11+
12+
#include "../common.h"
13+
#include "../util/logging.h"
14+
#include "../util/vectorized_pointwise.h"
15+
#include "recipe_common.cuh"
16+
17+
namespace transformer_engine {
18+
namespace {
19+
20+
// Threads per block used by both multi-tensor amax kernels.
constexpr int multi_amax_kernel_threads = 512;

// Per-launch capacity. Each slot costs about 40 bytes on a 64-bit target
// (three pointers plus two size_t counters), so 64 slots keep the args struct
// inside the 4KB kernel parameter budget. The static_assert below enforces it.
constexpr int kMaxTensorsPerBatch = 64;

// Argument bundle passed BY VALUE to the multi-tensor amax kernels. Only the
// first num_tensors entries of each array are meaningful.
struct MultiAmaxArgs {
  // Input data pointer of each tensor.
  const void *input_list[kMaxTensorsPerBatch];
  // Rowwise amax destination of each tensor (may be null).
  void *output_rowwise_amax_list[kMaxTensorsPerBatch];
  // Columnwise amax destination of each tensor (may be null or alias the rowwise slot).
  void *output_columnwise_amax_list[kMaxTensorsPerBatch];
  // Element count of each input tensor.
  size_t input_numel[kMaxTensorsPerBatch];
  // Number of vectorized loads needed to cover each input tensor.
  size_t num_aligned_elements[kMaxTensorsPerBatch];
  // Number of valid slots in this batch.
  int num_tensors;
};

// Growing kMaxTensorsPerBatch or adding fields must not silently push the
// by-value kernel argument past the documented 4KB budget.
static_assert(sizeof(MultiAmaxArgs) <= 4096,
              "MultiAmaxArgs is passed by value to kernels and must fit within the 4KB kernel "
              "parameter budget; reduce kMaxTensorsPerBatch");
33+
34+
// Clears the amax slots of every tensor in the batch with a single launch.
// When a tensor's rowwise and columnwise slots are the same address it is
// written only once. If noop_ptr is non-null and holds 1.0f the kernel leaves
// all outputs untouched, using the same noop test as MultiAmaxKernel.
__launch_bounds__(multi_amax_kernel_threads) __global__
    void MultiZeroAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) {
  const bool skip = (noop_ptr != nullptr) && (noop_ptr[0] == 1.0f);
  if (skip) {
    return;
  }

  // One thread per tensor slot; the loop lets any grid size cover the batch.
  const int step = blockDim.x * gridDim.x;
  for (int slot = blockIdx.x * blockDim.x + threadIdx.x; slot < args.num_tensors; slot += step) {
    float *const rowwise = static_cast<float *>(args.output_rowwise_amax_list[slot]);
    float *const colwise = static_cast<float *>(args.output_columnwise_amax_list[slot]);
    if (rowwise != nullptr) {
      *rowwise = 0.0f;
    }
    // Avoid a redundant store when both slots point at the same float.
    if (colwise != nullptr && colwise != rowwise) {
      *colwise = 0.0f;
    }
  }
}
54+
55+
// Per-tensor absolute-maximum reduction for one batch of tensors.
//
// Launch layout: blockIdx.y selects the tensor; blockIdx.x together with
// threadIdx.x strides over that tensor's vectorized elements, so any grid
// width covers the whole tensor. blockDim.x must equal
// multi_amax_kernel_threads because the block reduction below is sized from
// that constant.
//
// Each block reduces its partial maximum and thread 0 folds it with
// atomicMaxFloat into the tensor's rowwise and columnwise amax slots (a slot
// shared by both is updated once). Writing both slots here subsumes the
// per-expert D2D copy that the single-tensor path does after its amax kernel.
// The slots are expected to be zeroed before launch (see MultiZeroAmaxKernel).
//
// If noop_ptr is non-null and holds 1.0f the kernel exits without touching
// any output.
template <int nvec, bool aligned, typename InputType>
__launch_bounds__(multi_amax_kernel_threads) __global__
    void MultiAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) {
  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;
  }

  // The early exits below depend only on blockIdx and per-tensor data, so a
  // block either leaves as a whole or reaches the block reduction as a whole.
  const int t_idx = blockIdx.y;
  if (t_idx >= args.num_tensors) {
    return;
  }

  const InputType *input = static_cast<const InputType *>(args.input_list[t_idx]);
  const size_t N = args.input_numel[t_idx];
  if (N == 0) {
    return;
  }

  // args.num_aligned_elements is filled on the host for the host's vector
  // width (32 / sizeof(InputType) elements per load). A scalar instantiation
  // (nvec == 1) consumes one element per iteration, so it has to walk all N
  // elements; reusing the host count would stop early and under-report amax.
  const size_t M = (nvec == 1) ? N : args.num_aligned_elements[t_idx];

  VectorizedLoader<InputType, nvec, aligned> loader(input, N);
  InputType max = InputType{0.f};
  const int warp_id = threadIdx.x / THREADS_PER_WARP;

  // Index math is done in size_t so it cannot wrap for any grid size.
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; tid < M;
       tid += stride) {
    loader.load(tid, N);
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const InputType val = static_cast<InputType>(loader.separate()[i]);
      __builtin_assume(max >= InputType{0.f});
      if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
#if __CUDA_ARCH__ >= 800
        max = __hmax(__habs(val), max);
#else
        // Pre-SM80 fallback: do the bf16 comparison in float.
        max = static_cast<__nv_bfloat16>(
            fmaxf(fabsf(static_cast<float>(val)), static_cast<float>(max)));
#endif
      } else if constexpr (std::is_same_v<InputType, __half>) {
        max = __hmax(__habs(val), max);
      } else {
        max = fmaxf(fabsf(val), max);
      }
    }
  }

  // Reduce amax over block.
  max = reduce_max<multi_amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
  if (threadIdx.x == 0) {
    float *rw = static_cast<float *>(args.output_rowwise_amax_list[t_idx]);
    float *cw = static_cast<float *>(args.output_columnwise_amax_list[t_idx]);
    if (rw != nullptr) {
      atomicMaxFloat(rw, static_cast<float>(max));
    }
    // Skip the second atomic when both slots alias the same float.
    if (cw != nullptr && cw != rw) {
      atomicMaxFloat(cw, static_cast<float>(max));
    }
  }
}
118+
119+
// Host-side launcher for one batch: clears every amax slot with
// MultiZeroAmaxKernel, then runs the MultiAmaxKernel instantiation that matches
// the batch alignment. Both launches are enqueued on `stream` and each is
// followed by a cudaGetLastError check.
//
//   args       batch description, passed by value to the kernels
//   max_numel  largest element count among the tensors of the batch
//   align      alignment class chosen for the whole batch
//   noop_ptr   optional device flag forwarded to both kernels (may be null)
template <typename InputType>
void launch_multi_amax_batch(const MultiAmaxArgs &args, size_t max_numel, Alignment align,
                             const float *noop_ptr, cudaStream_t stream) {
  // Phase 1: reset the outputs. One thread per tensor, never fewer than one block.
  constexpr int zero_threads = multi_amax_kernel_threads;
  const int zero_blocks = std::max(1, DIVUP(args.num_tensors, zero_threads));
  MultiZeroAmaxKernel<<<zero_blocks, zero_threads, 0, stream>>>(args, noop_ptr);
  NVTE_CHECK_CUDA(cudaGetLastError());

  // Every tensor of the batch is empty: there is nothing to reduce.
  if (max_numel == 0) {
    return;
  }

  // Phase 2: the reduction. grid.y indexes the tensor; grid.x is sized for the
  // largest tensor and capped at 65535 blocks. Blocks past the end of a shorter
  // tensor fall through the kernel's loop bound without doing any work.
  constexpr int nvec = 32 / sizeof(InputType);
  constexpr size_t block_size = multi_amax_kernel_threads;
  constexpr size_t grid_x_limit = 65535;
  const size_t widest_vec_count = (max_numel + nvec - 1) / nvec;
  const size_t grid_x = std::clamp<size_t>(DIVUP(widest_vec_count, block_size), 1, grid_x_limit);
  const dim3 grid(grid_x, static_cast<unsigned int>(args.num_tensors), 1);

  if (align == Alignment::SAME_ALIGNED) {
    MultiAmaxKernel<nvec, true, InputType><<<grid, block_size, 0, stream>>>(args, noop_ptr);
  } else if (align == Alignment::SAME_UNALIGNED) {
    MultiAmaxKernel<nvec, false, InputType><<<grid, block_size, 0, stream>>>(args, noop_ptr);
  } else if (align == Alignment::DIFFERENT) {
    // Heterogeneous alignment across tensors: use the scalar (nvec=1,
    // aligned=true) instantiation, which places no alignment demand on the
    // pointers beyond that of a single element.
    MultiAmaxKernel<1, true, InputType><<<grid, block_size, 0, stream>>>(args, noop_ptr);
  }
  NVTE_CHECK_CUDA(cudaGetLastError());
}
161+
162+
// Fills one MultiAmaxArgs batch from the slice [start, start + count) of the
// full input/output lists.
//
//   inputs, outputs  parallel lists of tensors; must both cover the slice
//   start, count     slice to pack; count must not exceed kMaxTensorsPerBatch
//   args             destination; slots [0, count) and num_tensors are written
//
// Returns (largest element count in the batch, alignment class for the batch).
// The alignment class starts at SAME_ALIGNED and is downgraded to
// SAME_UNALIGNED as soon as any non-empty tensor is unaligned. Empty tensors do
// not influence it.
template <typename InputType>
std::pair<size_t, Alignment> build_batch_args(const std::vector<Tensor *> &inputs,
                                              const std::vector<Tensor *> &outputs, size_t start,
                                              size_t count, MultiAmaxArgs &args) {
  // The arrays inside MultiAmaxArgs have a fixed capacity; refuse to write past
  // them or to read past the end of either tensor list.
  NVTE_CHECK(count <= static_cast<size_t>(kMaxTensorsPerBatch), "build_batch_args: batch of ",
             count, " tensors exceeds the per-launch capacity of ", kMaxTensorsPerBatch);
  NVTE_CHECK(start <= inputs.size() && count <= inputs.size() - start,
             "build_batch_args: slice [", start, ", ", start + count,
             ") is out of range for inputs of size ", inputs.size());
  NVTE_CHECK(start <= outputs.size() && count <= outputs.size() - start,
             "build_batch_args: slice [", start, ", ", start + count,
             ") is out of range for outputs of size ", outputs.size());

  constexpr int nvec = 32 / sizeof(InputType);
  size_t max_numel = 0;
  Alignment batch_align = Alignment::SAME_ALIGNED;

  for (size_t i = 0; i < count; ++i) {
    const Tensor &inp = *inputs[start + i];
    Tensor &out = *outputs[start + i];
    const size_t N = inp.data.numel();

    args.input_list[i] = inp.data.dptr;
    args.output_rowwise_amax_list[i] = out.amax.dptr;
    args.output_columnwise_amax_list[i] = out.columnwise_amax.dptr;
    args.input_numel[i] = N;
    // Computed per pointer, so each tensor carries its own vectorized extent.
    args.num_aligned_elements[i] =
        get_num_aligned_elements(inp.data.dptr, N, nvec, sizeof(InputType));
    max_numel = std::max(max_numel, N);

    // Fold this tensor's alignment into the batch decision. A batch that mixes
    // aligned and unaligned tensors is classified SAME_UNALIGNED (not
    // DIFFERENT). NOTE(review): this presumes the unaligned kernel path copes
    // with per-tensor alignment, including fully aligned tensors; confirm
    // against VectorizedLoader.
    if (N > 0) {
      const Alignment a = CheckAlignment(N, nvec, static_cast<const InputType *>(inp.data.dptr));
      if (a == Alignment::DIFFERENT) {
        batch_align = Alignment::DIFFERENT;
      } else if (a == Alignment::SAME_UNALIGNED && batch_align == Alignment::SAME_ALIGNED) {
        batch_align = Alignment::SAME_UNALIGNED;
      }
    }
  }

  args.num_tensors = static_cast<int>(count);
  return {max_numel, batch_align};
}
206+
207+
// Implementation behind nvte_multi_compute_amax.
//
// Validates every input/output pair, resolves the optional noop flag from the
// quantization config, then processes the list in batches of at most
// kMaxTensorsPerBatch tensors, calling launch_multi_amax_batch once per batch.
//
//   inputs_      array of num_tensors unquantized input tensors, all one dtype
//   outputs_     array of num_tensors output tensors; only amax buffers are used
//   num_tensors  list length; 0 makes the call a no-op
//   config_      optional quantization config supplying noop_tensor (may be null)
//   stream       CUDA stream on which all kernels are enqueued
//
// Validation failures are reported through NVTE_CHECK.
void multi_compute_amax_impl(const NVTETensor *inputs_, NVTETensor *outputs_, size_t num_tensors,
                             const NVTEQuantizationConfig config_, cudaStream_t stream) {
  if (num_tensors == 0) {
    return;
  }
  NVTE_CHECK(inputs_ != nullptr, "nvte_multi_compute_amax: inputs is NULL");
  NVTE_CHECK(outputs_ != nullptr, "nvte_multi_compute_amax: outputs is NULL");

  // Convert, validate, collect into plain vectors.
  std::vector<Tensor *> inputs(num_tensors);
  std::vector<Tensor *> outputs(num_tensors);
  // Value-initialized so it is never read indeterminate; overwritten with the
  // dtype of input[0] on the first iteration.
  DType input_dtype{};
  for (size_t i = 0; i < num_tensors; ++i) {
    inputs[i] = convertNVTETensorCheck(inputs_[i]);
    outputs[i] = convertNVTETensorCheck(outputs_[i]);
    const auto &inp = *inputs[i];
    auto &out = *outputs[i];
    NVTE_CHECK(inp.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "nvte_multi_compute_amax: input[",
               i, "] must be unquantized, got scaling_mode=", to_string(inp.scaling_mode));
    NVTE_CHECK(!is_fp8_dtype(inp.data.dtype), "nvte_multi_compute_amax: input[", i,
               "] must be unquantized, got dtype=", to_string(inp.data.dtype));
    // The kernel dereferences the data pointer of every non-empty tensor, so a
    // missing buffer must be rejected on the host.
    NVTE_CHECK(inp.data.numel() == 0 || inp.data.dptr != nullptr,
               "nvte_multi_compute_amax: input[", i, "] has no data buffer");
    if (i == 0) {
      input_dtype = inp.data.dtype;
    } else {
      NVTE_CHECK(inp.data.dtype == input_dtype,
                 "nvte_multi_compute_amax: all inputs must share dtype; input[0]=",
                 to_string(input_dtype), ", input[", i, "]=", to_string(inp.data.dtype));
    }
    NVTE_CHECK(out.scaling_mode == NVTE_DELAYED_TENSOR_SCALING ||
                   out.scaling_mode == NVTE_NVFP4_1D_SCALING,
               "nvte_multi_compute_amax: output[", i, "] must be FP8 per-tensor or NVFP4 1D");
    NVTE_CHECK(out.amax.dptr != nullptr || out.columnwise_amax.dptr != nullptr,
               "nvte_multi_compute_amax: output[", i, "] has no amax buffer");
  }

  // Optional noop flag: stays null unless the config carries a noop tensor.
  const float *noop_ptr = nullptr;
  if (config_ != nullptr) {
    const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
    const NVTETensor noop = config_cpp->noop_tensor;
    noop_ptr = reinterpret_cast<float *>(
        (noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
  }

  // Chunk across kMaxTensorsPerBatch launches (single launch in the common 8-expert case).
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input_dtype, IType, {
    for (size_t start = 0; start < num_tensors; start += kMaxTensorsPerBatch) {
      const size_t count = std::min<size_t>(kMaxTensorsPerBatch, num_tensors - start);
      MultiAmaxArgs args = {};
      auto [max_numel, batch_align] = build_batch_args<IType>(inputs, outputs, start, count, args);
      launch_multi_amax_batch<IType>(args, max_numel, batch_align, noop_ptr, stream);
    }
  });  // NOLINT(*)
}
260+
261+
} // anonymous namespace
262+
} // namespace transformer_engine
263+
264+
// Public C entry point for the multi-tensor amax computation. It registers the
// call through NVTE_API_CALL and forwards every argument unchanged to
// multi_compute_amax_impl in the transformer_engine namespace.
void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors,
                             const NVTEQuantizationConfig config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_compute_amax);
  namespace te = transformer_engine;
  te::multi_compute_amax_impl(inputs, outputs, num_tensors, config, stream);
}

transformer_engine/pytorch/csrc/common.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,30 @@ class NVFP4Quantizer : public Quantizer {
369369
*/
370370
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out);
371371

372+
/*! @brief Compute (and D2D fill) local amax only — no cast, no allreduce.
373+
*
374+
* Writes the local amax into out's rowwise and/or columnwise amax
375+
* buffers. Callers are expected to perform a coalesced allreduce
376+
* across the amax reduction group afterwards, then invoke
377+
* quantize_cast_only to finish the cast with the reduced amax.
378+
*/
379+
void compute_amax_only(const TensorWrapper& input, TensorWrapper& out);
380+
381+
/*! @brief Cast to NVFP4 assuming amax already reduced externally.
382+
*
383+
* Skips both local amax compute and the internal amax allreduce.
384+
* Callers must guarantee out's amax buffers already hold the reduced
385+
* amax (e.g. via compute_amax_only + allreduce_coalesced).
386+
*/
387+
void quantize_cast_only(const TensorWrapper& input, TensorWrapper& out,
388+
const std::optional<TensorWrapper>& noop_flag = std::nullopt);
389+
372390
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
373391

374392
private:
375393
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
376-
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
394+
const std::optional<TensorWrapper>& noop_flag, bool compute_amax,
395+
bool skip_amax_reduction = false);
377396
void quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out,
378397
TensorWrapper& rht_output_t_cpp,
379398
QuantizationConfigWrapper& quant_config,

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,20 @@ py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector
329329
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
330330
std::optional<at::Tensor> noop_flag);
331331

332+
// NVFP4-only split-phase quantize: compute amax, coalesce allreduce externally, then cast.
333+
py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer,
334+
const py::object &output);
335+
py::object quantize_cast_only_nvfp4(const at::Tensor &tensor, py::handle quantizer,
336+
const py::object &output, std::optional<at::Tensor> noop_flag);
337+
338+
// NVFP4-only multi-tensor amax: fuses N per-expert (zero_amax + amax + D2D replicate)
339+
// chains into a single pair of kernel launches (one multi-zero + one multi-amax) that
340+
// writes amax into every output's rowwise AND columnwise buffers. Outputs must be
341+
// pre-allocated; amax is written in place, no return.
342+
void compute_multi_amax_nvfp4(const std::vector<at::Tensor> &tensor_list,
343+
std::vector<py::handle> quantizer_list,
344+
const std::vector<py::object> &output_list);
345+
332346
py::object dequantize(const py::handle &input, DType otype);
333347

334348
py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors,

0 commit comments

Comments
 (0)