Skip to content

Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114

Draft
vthumbe1503 wants to merge 59 commits into
NVIDIA:mainfrom
vthumbe1503:current_scaling_group_quant
Draft

Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114
vthumbe1503 wants to merge 59 commits into
NVIDIA:mainfrom
vthumbe1503:current_scaling_group_quant

Conversation

@vthumbe1503

Copy link
Copy Markdown
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Orchestra and others added 30 commits May 14, 2026 18:42
Route grouped Float8CurrentScalingQuantizer through the existing grouped quantize entry point, prepare per-group current-scaling metadata with existing amax/scale helpers, and add focused tests plus a GB200 bandwidth benchmark.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_5507e814ee50f9ff304a4ce708d19768
Orchestra-Run: run_516e1e26891f4ce7d4cde07147c10862
Use wider vectorized grouped FP8 cast-transpose tiles and vectorized masked stores for rowwise and columnwise outputs. Capture all benchmark modes in a single post-warmup profiler range.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_3d6e33eab11e293d72eb4394bad76a81
Orchestra-Run: run_a6e2c31d5fdf850594f71438e53148da
Route non-MXFP8 grouped-linear bias backward through group_quantize plus grouped dbias while keeping MXFP8 bgrad_group_quantize fusion intact. Add focused zero-row grouped FP8 coverage and a current-scaling GroupedLinear bias-backward regression.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_ab566800d87047635cd27f9e64661abe
Orchestra-Run: run_5f9bfef17ccd854232c54d56268ef9e8
Use packed FP8 conversion and reduce columnwise transpose staging register and synchronization overhead in group_cast_fp8_kernel.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_7a830e018ceac8de0018280bd0740a54
Orchestra-Run: run_d2f1df4ffc2265d9cfa5ed01028ee476
Match the grouped FP8 conversion helper's element-count template parameter to Vec's uint32_t parameter so rowwise, columnwise, and activation instantiations can build.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_30c4b6ddb896e5ea3ca5b54731d2c819
Orchestra-Run: run_e95cdbb445943304622b95736f0eca49
Use cached grouped offsets to avoid launching FP8 quantization over unused overallocated rows, permit larger grouped backing buffers when split metadata is present, and tighten full-tile vector paths in the grouped FP8 cast kernel.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_c5db93823dc101838cb1323e283cd6e9
Orchestra-Run: run_063e2e4c724e132612aa5597d6765c9b
Use the FP8 grouped output logical shape when computing the tensor-scaling launch grid so overallocated buffers with active metadata avoid empty tail-row launches while preserving the allocated-shape fallback.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_b4abb47c990404d73142342a19996a3f
Orchestra-Run: run_8f09e7b9d7af9754ef505f2e2ce3cf90
Use larger grouped FP8 tiles with 8-warp CTAs and 16-row columnwise store fragments. Treat uniform overallocated FP8 grouped outputs as same-shape wrappers during output reuse so the timed path avoids varying-shape metadata overlaunch. Add overallocated current-scaling coverage for all grouped FP8 direction modes.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_3f98ac9c5b82192ec289d8d2a9816c7f
Orchestra-Run: run_83f3b99cc950024cf06ee836337fbf72
Stage columnwise transpose fragments through shared-memory vectors with smaller columnwise row tiles to reduce register pressure and barrier overhead while preserving the larger rowwise-only store path.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_495cc57eef84749103aded403a508d99
Orchestra-Run: run_53e038e90f83186bc6c12cb722c986b5
Add fast grouped FP8 rowwise and full-tile columnwise paths for uniform active groups while preserving the general fallback for varying grouped metadata.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_4c33e88776c8a7148e9da5cc2bae84ea
Orchestra-Run: run_2caaff219394eb5d59b7be38ab2bf346
Add a same-shape bidirectional full-tile kernel with wider input vectors and rowwise stores while preserving the existing rowwise-only, columnwise-only, and fallback grouped paths.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_87cec01d94f053b53e3c79377ad379ab
Orchestra-Run: run_ed48db00a730a4bf56530d551ecd350e
Route same-shape rowwise+columnwise grouped FP8 tensor-scaling quantization through the compact full-tile transpose schedule instead of the wide dynamic-shared-memory variant, preserving the existing single-direction and fallback paths.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_fdddd228a620039c024b4ecf43f3ab42
Orchestra-Run: run_30a2753eea9c893cb0fadb8233da8ce6
Hint the rowwise stores in the full-tile rowwise+columnwise grouped FP8 path as streaming global stores to reduce cache/writeback pressure without changing single-direction launch geometry.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_bf82020032e68276f4e47c65f62d97ae
Orchestra-Run: run_754ea4c864f329c6f2003b413b723c43
Add graph-safe grouped FP8 tensor-scaling metadata, support varying last dimensions, preserve same-shape fast paths, adjust grouped FP8 columnwise allocation by architecture, and expand benchmark/test coverage for the reviewed shape cases.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_d104e74844fbc3d3b1a98a8d96d76037
Orchestra-Run: run_1314e997c61ffb92ff7120b0b26f0318
Map varying-last columnwise tiles per group to avoid tile-alignment device errors, expand nonaligned boundary coverage, and restore same-shape benchmark baseline criteria.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_14e0e7973300d26f69550bc0aee21acc
Orchestra-Run: run_2f42b8ba138ed8b2b4d9dc90b92caf85
Add grouped FP8 benchmark support for baseline-ref same-session reports and update the benchmark request to enforce same-shape baseline regression checks alongside the per-mode throughput thresholds.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_d0cada957a4aafdce9d52be86520e182
Orchestra-Run: run_4da74e9bdb4f4a4c72304a385692b6c9
Update the grouped FP8 benchmark driver so same-session baseline checks out and builds the baseline ref into an isolated PyTorch install, verifies the baseline subprocess loads those shared objects, and preserves the required same-shape baseline comparisons.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_4fd88b172872f547f2f2d0053dce73d1
Orchestra-Run: run_6a44ee0467ffff47d4b278de6127354d
Preserve grouped delayed-FP8 amax metadata and keep unsupported FP8 tensor-scaling quantizers out of the grouped GEMM path.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_2aa8e6bf11ae356f4b34d4540b508031
Orchestra-Run: run_302681098d7f4e05b0ad96450f2d9826
Set NVTE_GROUPED_LINEAR_SINGLE_PARAM inside the targeted state-dict tests so they exercise the gated single grouped parameter path without relying on external environment setup.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_261900f987bdc9397965019983a77c41
Orchestra-Run: run_c6624e34717cbe121b3e0edcf490e3d3
Add a segmented flat rowwise kernel for varying-first grouped FP8 tensor-scaling outputs while preserving the existing same-shape fast path.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_c1b7020b27290318848ef6ac9048dd5f
Orchestra-Run: run_5c257b8a5d2e7e4aa95e67aa16436166
Omit the last_dims keyword when absent so the same-session baseline can run against the base extension, and refresh the benchmark request to include direct varying-last current-scaling coverage.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_c20a3c94fdc798e741a469bd7bb9c4df
Orchestra-Run: run_457448e6cba80fc63ac72b3db71c5fd0
Dispatch varying-first tensor-scaling work per group to reduce inactive-tail CTAs and offset lookup overhead while preserving same-shape fast paths and graph-safe device metadata handling.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_d84e1fefef8641e558df064452f4689b
Orchestra-Run: run_a361ca2f93fcec53ddd60dd99f4639e5
Add a no-tail rowwise flat kernel for aligned varying-first grouped FP8 tensor-scaling quantization and keep same-shape and varying-last dispatch isolated. Tighten benchmark profiler timing so post-warmup measured ranges exclude profiler start overhead.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_2e478be1fb38195f36d25c51320dc01f
Orchestra-Run: run_9a133a75fa3d98dc3b1a63b0ff4d84af
Write grouped FP8 benchmark reports to a sidecar path by default and label script reports as benchmark_raw_report/v1 so regular 100-iteration measurements are fetched instead of the wrapper command report.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_27770b2e1d490b1a3053244d4b4ce248
Orchestra-Run: run_214052d0c1316e231443d645183a2675
Write the grouped FP8 benchmark JSON once and mirror the completed sidecar to ORCHESTRA_BENCHMARK_RAW_REPORT when running under Orchestra so the benchmark fetch path can parse the emitted measurements.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_b2e2747371204088c8e3f7cf10263164
Orchestra-Run: run_1d4ea38266807c8acb59143ee74ba241
Allow the grouped FP8 benchmark to use ORCHESTRA_BENCHMARK_RAW_REPORT as its primary output so the benchmark wrapper can fetch canonical measurements directly.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_10fdcfef6b70de4676b7843e4bbfac31
Orchestra-Run: run_4ce57df9e86d6d03a26f7aa95ac252cc
Write canonical grouped FP8 benchmark measurements to ORCHESTRA_BENCHMARK_RAW_REPORT in a small schema-shaped payload so the benchmark wrapper can materialize per-mode threshold evidence.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_3e862eebd585c74f2a58497fedea3511
Orchestra-Run: run_3770ab3dbbf51329d0839b3d10a91b5c
Write candidate_results and nonempty measurements into the Orchestra raw report path, and fail fast if the benchmark cannot produce threshold-ready evidence.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_aa587a7b0d35aa9c2b715ec1b7c8bec3
Orchestra-Run: run_b42870e5d5e142a6cbf53bb5a3cafc2e
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 and others added 17 commits May 28, 2026 22:42
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…nel for varying all dims

- Generalized the splits_to_offsets CUDA kernel to accept last_dims pointers along with strides, templatizing correctly to avoid performance impact on non-varying cases.
- Updated nvte_splits_to_offsets and nvte_splits_to_offsets_multi APIs to pass last_dims to the kernel.
- Integrated last_dims into create_grouped_tensor across None, Float8, Float8CurrentScaling, Float8Blockwise, MXFP8, and NVFP4 Quantizers.
- Updated group_quantize, bgrad_group_quantize, and nvfp4_group_quantize_with_amax to correctly accept and pass last_dims along with precomputed tensor_offsets.
- Reverted all changes to splits_to_offsets_multi and nvte_splits_to_offsets_multi.
- Adapted the existing `kernel` in splits_to_offsets.cu using template parameters instead of introducing a separate kernel, preserving original kernel name and layout.
- Kept the optimized nvte_splits_to_offsets behavior with varying first and last dims.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title Current Scaling Group Quantization Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize Jun 10, 2026
vthumbe1503 and others added 4 commits June 10, 2026 23:27
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…TransformerEngine into current_scaling_group_quant
constexpr size_t GROUPED_AMAX_MIN_ELTS_PER_BLOCK = 8 * 1024; // ~16KB of bf16

// Zero per-tensor amax buffer so the main kernel can use atomicMax updates.
__launch_bounds__(GROUPED_AMAX_KERNEL_THREADS) __global__

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we can clear the contiguous buffer with cudaMemsetAsync() on the same stream before launching the main kernel. This preserves ordering and may avoid the overhead of a separate zeroing kernel.

// 4-way vectorized load and reduce
for (; v + 3 * blockDim.x < total_vecs; v += 4 * blockDim.x) {
IVecT vec0, vec1, vec2, vec3;
vec0.load_from(base + v * NVEC);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let’s benchmark the LDG path here. If it turns out to be a bottleneck, we may need to switch to TMA.

s_block_amax[warp_id] = thread_amax;
}
__syncthreads();
// Reduce amax within the warp

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can reuse the warp-level reducer from transformer_engine/common/utils.cuh here to avoid duplicating the reduction logic.

#pragma unroll
for (int s = THREADS_PER_WARP / 2; s > 0; s >>= 1) {
thread_amax = fmaxf(thread_amax, __shfl_xor_sync(0xFFFFFFFFu, thread_amax, s));
}

@Oleg-Goncharov Oleg-Goncharov Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also reuse the warp-level reducer here.

vec3.load_from(base + (v + 3 * blockDim.x) * NVEC);
#pragma unroll
for (int i = 0; i < NVEC; ++i) {
acc0 = max_val(acc0, abs_val(vec0.data.elt[i]));

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this becomes performance-critical, we could consider reusing helper functions from transformer_engine/common/util/ptx.cuh, such as abs_max_2x(), or adding a custom inline PTX helper if the existing ones do not fit.

};

template <typename IType, typename OType>
__device__ __forceinline__ void fast_scaled_fp8_cvt_4(const IType *input, OType *output,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we rewrite this function to reduce code duplication and hardcoding? ptx::mul_cvt_4x is overloaded, so we should be able to use it directly and let the types select the correct overload. If input and output are properly aligned for the corresponding ptx::FPx4<> types, we can reinterpret them as 4-element packed values and avoid duplicating the conversion logic, e.g.

using IType4 = ptx::FPx4<IType>;
using OType4 = ptx::FPx4<OType>;

const IType4& in4x = *reinterpret_cast<const IType4*>(input);
OType4& out4x = *reinterpret_cast<OType4*>(output);

ptx::mul_cvt_4x(out4x, in4x, scale_2x);

// Varying-first columnwise intentionally falls through to the
// generic group_cast_fp8_kernel below. The dedicated
// group_cast_fp8_varying_first_tile_kernel used a grid-stride
// row loop that spilled to 118 registers/thread (2 blocks/SM,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we revise this comment to avoid relying on machine-specific benchmark data? The current numbers are tied to a specific machine/build configuration and may not generalize.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants