Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114
Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114vthumbe1503 wants to merge 59 commits into
Conversation
Route grouped Float8CurrentScalingQuantizer through the existing grouped quantize entry point, prepare per-group current-scaling metadata with existing amax/scale helpers, and add focused tests plus a GB200 bandwidth benchmark. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_5507e814ee50f9ff304a4ce708d19768 Orchestra-Run: run_516e1e26891f4ce7d4cde07147c10862
Use wider vectorized grouped FP8 cast-transpose tiles and vectorized masked stores for rowwise and columnwise outputs. Capture all benchmark modes in a single post-warmup profiler range. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_3d6e33eab11e293d72eb4394bad76a81 Orchestra-Run: run_a6e2c31d5fdf850594f71438e53148da
Route non-MXFP8 grouped-linear bias backward through group_quantize plus grouped dbias while keeping MXFP8 bgrad_group_quantize fusion intact. Add focused zero-row grouped FP8 coverage and a current-scaling GroupedLinear bias-backward regression. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_ab566800d87047635cd27f9e64661abe Orchestra-Run: run_5f9bfef17ccd854232c54d56268ef9e8
Use packed FP8 conversion and reduce columnwise transpose staging register and synchronization overhead in group_cast_fp8_kernel. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_7a830e018ceac8de0018280bd0740a54 Orchestra-Run: run_d2f1df4ffc2265d9cfa5ed01028ee476
Match the grouped FP8 conversion helper's element-count template parameter to Vec's uint32_t parameter so rowwise, columnwise, and activation instantiations can build. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_30c4b6ddb896e5ea3ca5b54731d2c819 Orchestra-Run: run_e95cdbb445943304622b95736f0eca49
Use cached grouped offsets to avoid launching FP8 quantization over unused overallocated rows, permit larger grouped backing buffers when split metadata is present, and tighten full-tile vector paths in the grouped FP8 cast kernel. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_c5db93823dc101838cb1323e283cd6e9 Orchestra-Run: run_063e2e4c724e132612aa5597d6765c9b
Use the FP8 grouped output logical shape when computing the tensor-scaling launch grid so overallocated buffers with active metadata avoid empty tail-row launches while preserving the allocated-shape fallback. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_b4abb47c990404d73142342a19996a3f Orchestra-Run: run_8f09e7b9d7af9754ef505f2e2ce3cf90
Use larger grouped FP8 tiles with 8-warp CTAs and 16-row columnwise store fragments. Treat uniform overallocated FP8 grouped outputs as same-shape wrappers during output reuse so the timed path avoids varying-shape metadata overlaunch. Add overallocated current-scaling coverage for all grouped FP8 direction modes. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_3f98ac9c5b82192ec289d8d2a9816c7f Orchestra-Run: run_83f3b99cc950024cf06ee836337fbf72
Stage columnwise transpose fragments through shared-memory vectors with smaller columnwise row tiles to reduce register pressure and barrier overhead while preserving the larger rowwise-only store path. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_495cc57eef84749103aded403a508d99 Orchestra-Run: run_53e038e90f83186bc6c12cb722c986b5
Add fast grouped FP8 rowwise and full-tile columnwise paths for uniform active groups while preserving the general fallback for varying grouped metadata. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_4c33e88776c8a7148e9da5cc2bae84ea Orchestra-Run: run_2caaff219394eb5d59b7be38ab2bf346
Add a same-shape bidirectional full-tile kernel with wider input vectors and rowwise stores while preserving the existing rowwise-only, columnwise-only, and fallback grouped paths. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_87cec01d94f053b53e3c79377ad379ab Orchestra-Run: run_ed48db00a730a4bf56530d551ecd350e
Route same-shape rowwise+columnwise grouped FP8 tensor-scaling quantization through the compact full-tile transpose schedule instead of the wide dynamic-shared-memory variant, preserving the existing single-direction and fallback paths. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_fdddd228a620039c024b4ecf43f3ab42 Orchestra-Run: run_30a2753eea9c893cb0fadb8233da8ce6
Hint the rowwise stores in the full-tile rowwise+columnwise grouped FP8 path as streaming global stores to reduce cache/writeback pressure without changing single-direction launch geometry. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_bf82020032e68276f4e47c65f62d97ae Orchestra-Run: run_754ea4c864f329c6f2003b413b723c43
Add graph-safe grouped FP8 tensor-scaling metadata, support varying last dimensions, preserve same-shape fast paths, adjust grouped FP8 columnwise allocation by architecture, and expand benchmark/test coverage for the reviewed shape cases. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d104e74844fbc3d3b1a98a8d96d76037 Orchestra-Run: run_1314e997c61ffb92ff7120b0b26f0318
Map varying-last columnwise tiles per group to avoid tile-alignment device errors, expand nonaligned boundary coverage, and restore same-shape benchmark baseline criteria. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_14e0e7973300d26f69550bc0aee21acc Orchestra-Run: run_2f42b8ba138ed8b2b4d9dc90b92caf85
Add grouped FP8 benchmark support for baseline-ref same-session reports and update the benchmark request to enforce same-shape baseline regression checks alongside the per-mode throughput thresholds. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d0cada957a4aafdce9d52be86520e182 Orchestra-Run: run_4da74e9bdb4f4a4c72304a385692b6c9
Update the grouped FP8 benchmark driver so same-session baseline checks out and builds the baseline ref into an isolated PyTorch install, verifies the baseline subprocess loads those shared objects, and preserves the required same-shape baseline comparisons. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_4fd88b172872f547f2f2d0053dce73d1 Orchestra-Run: run_6a44ee0467ffff47d4b278de6127354d
Preserve grouped delayed-FP8 amax metadata and keep unsupported FP8 tensor-scaling quantizers out of the grouped GEMM path. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_2aa8e6bf11ae356f4b34d4540b508031 Orchestra-Run: run_302681098d7f4e05b0ad96450f2d9826
Set NVTE_GROUPED_LINEAR_SINGLE_PARAM inside the targeted state-dict tests so they exercise the gated single grouped parameter path without relying on external environment setup. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_261900f987bdc9397965019983a77c41 Orchestra-Run: run_c6624e34717cbe121b3e0edcf490e3d3
Add a segmented flat rowwise kernel for varying-first grouped FP8 tensor-scaling outputs while preserving the existing same-shape fast path. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_c1b7020b27290318848ef6ac9048dd5f Orchestra-Run: run_5c257b8a5d2e7e4aa95e67aa16436166
Omit the last_dims keyword when absent so the same-session baseline can run against the base extension, and refresh the benchmark request to include direct varying-last current-scaling coverage. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_c20a3c94fdc798e741a469bd7bb9c4df Orchestra-Run: run_457448e6cba80fc63ac72b3db71c5fd0
Dispatch varying-first tensor-scaling work per group to reduce inactive-tail CTAs and offset lookup overhead while preserving same-shape fast paths and graph-safe device metadata handling. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d84e1fefef8641e558df064452f4689b Orchestra-Run: run_a361ca2f93fcec53ddd60dd99f4639e5
Add a no-tail rowwise flat kernel for aligned varying-first grouped FP8 tensor-scaling quantization and keep same-shape and varying-last dispatch isolated. Tighten benchmark profiler timing so post-warmup measured ranges exclude profiler start overhead. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_2e478be1fb38195f36d25c51320dc01f Orchestra-Run: run_9a133a75fa3d98dc3b1a63b0ff4d84af
Write grouped FP8 benchmark reports to a sidecar path by default and label script reports as benchmark_raw_report/v1 so regular 100-iteration measurements are fetched instead of the wrapper command report. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_27770b2e1d490b1a3053244d4b4ce248 Orchestra-Run: run_214052d0c1316e231443d645183a2675
Write the grouped FP8 benchmark JSON once and mirror the completed sidecar to ORCHESTRA_BENCHMARK_RAW_REPORT when running under Orchestra so the benchmark fetch path can parse the emitted measurements. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_b2e2747371204088c8e3f7cf10263164 Orchestra-Run: run_1d4ea38266807c8acb59143ee74ba241
Allow the grouped FP8 benchmark to use ORCHESTRA_BENCHMARK_RAW_REPORT as its primary output so the benchmark wrapper can fetch canonical measurements directly. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_10fdcfef6b70de4676b7843e4bbfac31 Orchestra-Run: run_4ce57df9e86d6d03a26f7aa95ac252cc
Write canonical grouped FP8 benchmark measurements to ORCHESTRA_BENCHMARK_RAW_REPORT in a small schema-shaped payload so the benchmark wrapper can materialize per-mode threshold evidence. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_3e862eebd585c74f2a58497fedea3511 Orchestra-Run: run_3770ab3dbbf51329d0839b3d10a91b5c
Write candidate_results and nonempty measurements into the Orchestra raw report path, and fail fast if the benchmark cannot produce threshold-ready evidence. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_aa587a7b0d35aa9c2b715ec1b7c8bec3 Orchestra-Run: run_b42870e5d5e142a6cbf53bb5a3cafc2e
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…nel for varying all dims - Generalized the splits_to_offsets CUDA kernel to accept last_dims pointers along with strides, templatizing correctly to avoid performance impact on non-varying cases. - Updated nvte_splits_to_offsets and nvte_splits_to_offsets_multi APIs to pass last_dims to the kernel. - Integrated last_dims into create_grouped_tensor across None, Float8, Float8CurrentScaling, Float8Blockwise, MXFP8, and NVFP4 Quantizers. - Updated group_quantize, bgrad_group_quantize, and nvfp4_group_quantize_with_amax to correctly accept and pass last_dims along with precomputed tensor_offsets.
- Reverted all changes to splits_to_offsets_multi and nvte_splits_to_offsets_multi. - Adapted the existing `kernel` in splits_to_offsets.cu using template parameters instead of introducing a separate kernel, preserving original kernel name and layout. - Kept the optimized nvte_splits_to_offsets behavior with varying first and last dims.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…TransformerEngine into current_scaling_group_quant
| constexpr size_t GROUPED_AMAX_MIN_ELTS_PER_BLOCK = 8 * 1024; // ~16KB of bf16 | ||
|
|
||
| // Zero per-tensor amax buffer so the main kernel can use atomicMax updates. | ||
| __launch_bounds__(GROUPED_AMAX_KERNEL_THREADS) __global__ |
There was a problem hiding this comment.
Alternatively, we can clear the contiguous buffer with cudaMemsetAsync() on the same stream before launching the main kernel. This preserves ordering and may avoid the overhead of a separate zeroing kernel.
| // 4-way vectorized load and reduce | ||
| for (; v + 3 * blockDim.x < total_vecs; v += 4 * blockDim.x) { | ||
| IVecT vec0, vec1, vec2, vec3; | ||
| vec0.load_from(base + v * NVEC); |
There was a problem hiding this comment.
Let’s benchmark the LDG path here. If it turns out to be a bottleneck, we may need to switch to TMA.
| s_block_amax[warp_id] = thread_amax; | ||
| } | ||
| __syncthreads(); | ||
| // Reduce amax within the warp |
There was a problem hiding this comment.
We can reuse the warp-level reducer from transformer_engine/common/utils.cuh here to avoid duplicating the reduction logic.
| #pragma unroll | ||
| for (int s = THREADS_PER_WARP / 2; s > 0; s >>= 1) { | ||
| thread_amax = fmaxf(thread_amax, __shfl_xor_sync(0xFFFFFFFFu, thread_amax, s)); | ||
| } |
There was a problem hiding this comment.
We can also reuse the warp-level reducer here.
| vec3.load_from(base + (v + 3 * blockDim.x) * NVEC); | ||
| #pragma unroll | ||
| for (int i = 0; i < NVEC; ++i) { | ||
| acc0 = max_val(acc0, abs_val(vec0.data.elt[i])); |
There was a problem hiding this comment.
If this becomes performance-critical, we could consider reusing helper functions from transformer_engine/common/util/ptx.cuh, such as abs_max_2x(), or adding a custom inline PTX helper if the existing ones do not fit.
| }; | ||
|
|
||
| template <typename IType, typename OType> | ||
| __device__ __forceinline__ void fast_scaled_fp8_cvt_4(const IType *input, OType *output, |
There was a problem hiding this comment.
Could we rewrite this function to reduce code duplication and hardcoding? ptx::mul_cvt_4x is overloaded, so we should be able to use it directly and let the types select the correct overload. If input and output are properly aligned for the corresponding ptx::FPx4<> types, we can reinterpret them as 4-element packed values and avoid duplicating the conversion logic, e.g.
using IType4 = ptx::FPx4<IType>;
using OType4 = ptx::FPx4<OType>;
const IType4& in4x = *reinterpret_cast<const IType4*>(input);
OType4& out4x = *reinterpret_cast<OType4*>(output);
ptx::mul_cvt_4x(out4x, in4x, scale_2x);
| // Varying-first columnwise intentionally falls through to the | ||
| // generic group_cast_fp8_kernel below. The dedicated | ||
| // group_cast_fp8_varying_first_tile_kernel used a grid-stride | ||
| // row loop that spilled to 118 registers/thread (2 blocks/SM, |
There was a problem hiding this comment.
Could we revise this comment to avoid relying on machine-specific benchmark data? The current numbers are tied to a specific machine/build configuration and may not generalize.
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: