Commit 9e5a847

timmoon10, claude, and pre-commit-ci[bot] authored
Optimize function that loads pointers on GPU (#3001)
* Remove unnecessary heap allocations

  Avoid constructing temporary std::vector when converting NVTEBasicTensor to SimpleTensor. Avoid string operations in multi-tensor swizzle. Avoid temporary std::vector when checking scale tensors.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid heap allocation in Tensor::flat_first_dim/flat_last_dim

  Tensor::shape() returns a std::vector<size_t> by value, allocating on the heap. flat_first_dim and flat_last_dim only need to walk the dims, so the allocation was pure overhead in hot paths.

  Introduce Tensor::compute_shape() returning an NVTEShape (fixed inline buffer, no heap) as the single source of truth for the format-dependent shape logic. shape() is now a thin std::vector wrapper around it for callers that want a vector; flat_first_dim and flat_last_dim call compute_shape() directly.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>
  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Tensor::flat_2d_dims() to compute both matrix dims in one pass

  flat_first_dim() and flat_last_dim() each called compute_shape() independently. flat_2d_dims() computes both in a single pass; the scalar helpers now delegate to it.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use flat_2d_dims() throughout common lib

  Replace all paired flat_first_dim() + flat_last_dim() calls on the same tensor with a single flat_2d_dims() call. Saves one compute_shape() per tensor in CheckScaleTensorShape, the multi-tensor swizzle loop, and various cast/GEMM dispatch paths.

  Also adds reserve() to the local vectors in nvte_multi_tensor_swizzle_scaling_factors to avoid reallocation.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Generalize API for CUDA-Graph-safe copy to GPU.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Dedup swizzle logic in get_device_pointer_for_data_and_scales

  Replace the inline swizzle implementation with a call to multi_tensor_swizzle_scales_for_gemm, which has identical logic (16B-aligned contiguous output buffer, TensorWrapper construction, nvte_multi_tensor_swizzle_scaling_factors kernel). Swizzled pointers are read back from the updated TensorWrappers after the call.

  Add reserve() to vectors in multi_tensor_swizzle_scales_for_gemm_impl now that this function is on the hot path for get_device_pointer_for_data_and_scales.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make separate functions for load data_ptrs and swizzle + load data_ptrs.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change function name to nvte_load_value_on_device

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix code review issues before opening PR

  - Use size_t in kernel tail loop (was int64_t)
  - Zero-initialize Payload before memcpy (Payload{})
  - Rename Payload members to kMaxBytes/kVectorSize/kMaxVectors (linter)
  - Consistent at::empty shape pattern: {static_cast<int64_t>(N)}
  - Drop intermediate swizzled_scales_bytes variable
  - Add comment explaining uniform-stride assumption in transform_and_load_data_ptrs_on_device
  - Rename sfb_buffer -> _sfb_buffer (keepalive, not directly used)

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Formatter and review suggestions from @greptile-apps

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add Shape class wrapping NVTEShape

  Provides a std::vector<size_t>-like interface around NVTEShape without heap allocation, used as the return type of Tensor::shape() in place of the previous std::vector. Disambiguate cute::Shape from transformer_engine::Shape in the hadamard_transform kernels.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make SimpleTensor stack-allocatable

  Store shape in Shape class rather than std::vector.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Make Shape conversion constructors explicit

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make conversion from Shape to std::vector explicit

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add batched NVTETensor create/destroy

  Expose nvte_create_tensors and nvte_destroy_tensors so multi-tensor callers can amortize the TensorAllocator mutex across N tensors instead of locking once per call. nvte_destroy_tensors was already defined internally but not declared in the public header.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use batched NVTETensor allocator in transform_and_load_data_ptrs_on_device

  The uniform swizzle path constructed 2N TensorWrappers and then extracted their raw NVTETensors into separate vectors. Replace with a single 2N nvte_create_tensors call into one contiguous buffer (inputs in the first half, outputs in the second), an RAII guard for nvte_destroy_tensors, and a local set_param lambda for the setters. Drops the separate pack pass and reduces the allocator mutex acquisitions from 4N to 2 per call.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expand usage of batched NVTETensor allocator

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use string_view in tensor checking functions

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Tweak function names

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Pass std::string_view by value in Check*Tensor helpers

  string_view is already a (ptr, len) reference; passing by const-ref adds an indirection without benefit. Matches the C++ Core Guidelines F.16 recommendation.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @ptrendx

  Expand internal usage of Shape class. Zero-initialize in Shape::resize. Make sure dynamic smem querying is per-device. Reuse logic for batched and single tensor alloc/dealloc.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add MultiTensorWrapper for batched NVTETensor allocation

  Thin RAII wrapper around a batched nvte_create_tensors / nvte_destroy_tensors pair, with operator[], data(), iteration, and implicit conversion to NVTETensor* for multi-tensor C APIs. Replaces the ad-hoc DestroyGuard struct used at each call site in recipe.cpp, swizzle.cpp, and utils.cpp.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ace2a96 commit 9e5a847
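
Illustrative note: the commit message describes a thin RAII wrapper around a batched create/destroy pair, so that the allocator mutex is taken once per batch rather than once per tensor. The sketch below shows that pattern only. `Handle`, `ToyAllocator`, and `BatchedHandles` are hypothetical stand-ins invented for this example; the real `nvte_create_tensors`, `nvte_destroy_tensors`, and `MultiTensorWrapper` signatures are not shown in this excerpt and may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <vector>

// Hypothetical opaque handle type (NOT the real NVTETensor).
using Handle = int*;

// Hypothetical allocator guarded by one mutex. It counts lock acquisitions
// so the amortization is observable.
struct ToyAllocator {
  std::mutex mu;
  std::size_t lock_count = 0;
  std::size_t live = 0;

  // Create n handles under a single lock acquisition.
  void create(Handle* out, std::size_t n) {
    std::lock_guard<std::mutex> guard(mu);
    ++lock_count;
    for (std::size_t i = 0; i < n; ++i) out[i] = new int(0);
    live += n;
  }

  // Destroy n handles under a single lock acquisition.
  void destroy(Handle* handles, std::size_t n) {
    std::lock_guard<std::mutex> guard(mu);
    ++lock_count;
    for (std::size_t i = 0; i < n; ++i) {
      delete handles[i];
      handles[i] = nullptr;
    }
    live -= n;
  }
};

// RAII owner of a batch of handles: one create call in the constructor,
// one destroy call in the destructor.
class BatchedHandles {
 public:
  BatchedHandles(ToyAllocator& alloc, std::size_t n) : alloc_(alloc), handles_(n, nullptr) {
    if (n > 0) alloc_.create(handles_.data(), n);
  }
  ~BatchedHandles() {
    if (!handles_.empty()) alloc_.destroy(handles_.data(), handles_.size());
  }
  BatchedHandles(const BatchedHandles&) = delete;
  BatchedHandles& operator=(const BatchedHandles&) = delete;

  Handle& operator[](std::size_t i) { return handles_[i]; }
  Handle* data() { return handles_.data(); }
  std::size_t size() const { return handles_.size(); }
  // Implicit conversion so the batch can be passed to C-style multi-handle APIs.
  operator Handle*() { return handles_.data(); }

 private:
  ToyAllocator& alloc_;
  std::vector<Handle> handles_;
};
```

In this toy model a batch of any size costs two lock acquisitions in total (one on construction, one on destruction), which is the kind of saving the commit message attributes to the batched allocator.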

36 files changed

Lines changed: 724 additions & 496 deletions

transformer_engine/common/cast/dispatch/quantize.cuh

Lines changed: 3 additions & 6 deletions
@@ -98,8 +98,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
  CheckOutputTensor(*output_tensor, "output", false);

  // Choose kernel
- int32_t rows = input_tensor->flat_first_dim();
- int32_t cols = input_tensor->flat_last_dim();
+ const auto [rows, cols] = input_tensor->flat_2d_dims();
  auto dtype = input_tensor->dtype();
  const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4;
  const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled;
@@ -260,8 +259,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
  CheckOutputTensor(*output_tensor, "output", false);

  // Choose kernel
- int32_t rows = grad_tensor->flat_first_dim();
- int32_t cols = grad_tensor->flat_last_dim();
+ const auto [rows, cols] = grad_tensor->flat_2d_dims();
  auto dtype = grad_tensor->dtype();
  const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled;
  NVTE_CHECK(nvfp4_use_4over6 || output_tensor->nvfp4_e4m3_max == 448,
@@ -396,8 +394,7 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou
  // output list here is allowed to have empty tensor

  // Choose kernel
- int32_t rows = input_tensor->flat_first_dim();
- int32_t cols = input_tensor->flat_last_dim();
+ const auto [rows, cols] = input_tensor->flat_2d_dims();
  auto dtype = input_tensor->dtype();

  const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled;
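
Illustrative note: the hunks above replace a `flat_first_dim()` / `flat_last_dim()` pair with one `flat_2d_dims()` call unpacked by a structured binding. The sketch below shows one way such a helper could compute both values in a single walk over the dims. It is a minimal sketch under assumptions: `ToyShape`, its capacity, and the flattening convention (first = product of all dims except the last, last = final dim, both 1 for an empty shape) are assumptions made for this example and are not taken from the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <utility>

// Hypothetical fixed-capacity shape (a stand-in for NVTEShape; the capacity is assumed).
struct ToyShape {
  static constexpr std::size_t kMaxDims = 15;
  std::size_t data[kMaxDims] = {};
  std::size_t ndim = 0;

  ToyShape() = default;
  ToyShape(std::initializer_list<std::size_t> dims) {
    assert(dims.size() <= kMaxDims);
    for (std::size_t d : dims) data[ndim++] = d;
  }
};

// Compute both flattened matrix dims in one pass over the shape.
// Assumed convention: {product of all dims but the last, last dim}; {1, 1} if empty.
inline std::pair<std::size_t, std::size_t> flat_2d_dims(const ToyShape& shape) {
  if (shape.ndim == 0) return {1, 1};
  std::size_t first = 1;
  for (std::size_t i = 0; i + 1 < shape.ndim; ++i) first *= shape.data[i];
  return {first, shape.data[shape.ndim - 1]};
}

// Call-site pattern mirroring the diff: one call, two names.
inline std::size_t num_elements(const ToyShape& shape) {
  const auto [rows, cols] = flat_2d_dims(shape);
  return rows * cols;
}
```

Returning a pair keeps the two values tied to a single shape computation, so a caller cannot accidentally recompute the shape twice.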

transformer_engine/common/cast/fp8/quantize_fp8.cuh

Lines changed: 2 additions & 3 deletions
@@ -391,8 +391,7 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
  using namespace quantize_2D_kernel;
  checkCuDriverContext(stream);

- const size_t rows = input.flat_first_dim();
- const size_t cols = input.flat_last_dim();
+ const auto [rows, cols] = input.flat_2d_dims();
  const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y);
  const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X);
  const size_t blocks_Y = chunks_Y;
@@ -406,7 +405,7 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T

  if constexpr (IS_DBIAS) {
  NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
- NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
+ NVTE_CHECK(dbias->data.shape == Shape{cols}, "Wrong shape of DBias.");
  NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");

  if (workspace->data.dptr == nullptr) {
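
Illustrative note: the hunk above swaps a temporary `std::vector<size_t>{cols}` for `Shape{cols}` in a shape comparison. The sketch below shows how a fixed-capacity shape type can support that comparison without touching the heap. `InlineShape` and its capacity are hypothetical names and values chosen for this example; the actual `Shape` class wraps `NVTEShape` and its interface is not shown in this excerpt.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <vector>

// Hypothetical vector-like shape with an inline buffer (no heap allocation).
class InlineShape {
 public:
  static constexpr std::size_t kMaxDims = 15;  // assumed capacity

  InlineShape() = default;
  InlineShape(std::initializer_list<std::size_t> dims) {
    assert(dims.size() <= kMaxDims);
    std::copy(dims.begin(), dims.end(), dims_);
    ndim_ = dims.size();
  }

  std::size_t size() const { return ndim_; }
  bool empty() const { return ndim_ == 0; }
  std::size_t operator[](std::size_t i) const { return dims_[i]; }
  const std::size_t* begin() const { return dims_; }
  const std::size_t* end() const { return dims_ + ndim_; }

  // Element-wise comparison over the used dims only.
  friend bool operator==(const InlineShape& a, const InlineShape& b) {
    return a.ndim_ == b.ndim_ && std::equal(a.begin(), a.end(), b.begin());
  }
  friend bool operator!=(const InlineShape& a, const InlineShape& b) { return !(a == b); }

  // Explicit conversion, so a heap allocation only happens when asked for.
  explicit operator std::vector<std::size_t>() const {
    return std::vector<std::size_t>(begin(), end());
  }

 private:
  std::size_t dims_[kMaxDims] = {};
  std::size_t ndim_ = 0;
};
```

Making the conversion to `std::vector` explicit mirrors the intent stated in the commit message: callers that still want a vector must opt in visibly.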

transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh

Lines changed: 1 addition & 2 deletions
@@ -261,8 +261,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
  const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
  const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1;

- const size_t rows = input.flat_first_dim();
- const size_t cols = input.flat_last_dim();
+ const auto [rows, cols] = input.flat_2d_dims();
  const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y);
  const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X);

transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh

Lines changed: 1 addition & 1 deletion
@@ -867,7 +867,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
  NVTE_CHECK(dbias->data.dtype == input->dtype(),
  "DBias must have the same type as input_tensor.");

- std::vector<size_t> expected_shape_dbias_tensor = {num_tensors, last_logical_dim};
+ Shape expected_shape_dbias_tensor = {num_tensors, last_logical_dim};
  NVTE_CHECK(dbias->data.shape == expected_shape_dbias_tensor, "Wrong shape of DBias.");

  NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");

transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh

Lines changed: 2 additions & 3 deletions
@@ -578,8 +578,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
  constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT);

  // Tensor dimensions
- const size_t rows = input.flat_first_dim();
- const size_t cols = input.flat_last_dim();
+ const auto [rows, cols] = input.flat_2d_dims();

  // Tensor chunk handled by each CUDA block
  constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64;
@@ -622,7 +621,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,

  if constexpr (IS_DBIAS) {
  NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input.");
- NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
+ NVTE_CHECK(dbias->data.shape == Shape{cols}, "Wrong shape of DBias.");
  NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");

  if (workspace->data.dptr == nullptr) {

transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh

Lines changed: 1 addition & 2 deletions
@@ -95,8 +95,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
  const int e4m3_max = input.nvfp4_e4m3_max;

  constexpr int FP4_BLOCK_SIZE = 16;
- const size_t N = input.flat_first_dim();
- const size_t M = input.flat_last_dim();
+ const auto [N, M] = input.flat_2d_dims();

  NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
  FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");

transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh

Lines changed: 2 additions & 3 deletions
@@ -783,8 +783,7 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop,

  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");

- const size_t rows = input.flat_first_dim();
- const size_t cols = input.flat_last_dim();
+ const auto [rows, cols] = input.flat_2d_dims();

  NVTE_CHECK(rows % 32 == 0,
  "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
@@ -835,7 +834,7 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop,
  Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
  NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
  "RNG state should contain 2 64-bit values.");
- NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
+ NVTE_CHECK(rng_state_te_tensor.data.shape == Shape{2},
  "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
  rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
  }

transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh

Lines changed: 3 additions & 5 deletions
@@ -121,8 +121,7 @@ inline void compute_rowwise_amax(const Tensor &input, const Tensor *noop, Tensor
  #if FP4_TYPE_SUPPORTED
  using namespace rowwise_amax_kernel;

- const size_t rows = input.flat_first_dim();
- const size_t cols = input.flat_last_dim();
+ const auto [rows, cols] = input.flat_2d_dims();
  NVTE_CHECK(cols % ROWWISE_AMAX_SF_VEC_SIZE == 0,
  "Row-scaled NVFP4 quantization requires last dim divisible by ",
  ROWWISE_AMAX_SF_VEC_SIZE, ".");
@@ -1359,8 +1358,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
  "Transposed scaling tensor must be allocated");
  }

- const size_t rows = input.flat_first_dim();
- const size_t cols = input.flat_last_dim();
+ const auto [rows, cols] = input.flat_2d_dims();

  NVTE_CHECK(rows % 32 == 0,
  "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
@@ -1391,7 +1389,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
  Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
  NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
  "RNG state should contain 2 64-bit values.");
- NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
+ NVTE_CHECK(rng_state_te_tensor.data.shape == Shape{2},
  "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
  rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
  }

transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh

Lines changed: 2 additions & 3 deletions
@@ -718,8 +718,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop,
  "Transposed scaling tensor must be allocated");
  }

- const size_t rows = input.flat_first_dim();
- const size_t cols = input.flat_last_dim();
+ const auto [rows, cols] = input.flat_2d_dims();

  NVTE_CHECK(rows % 32 == 0,
  "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
@@ -750,7 +749,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop,
  Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
  NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
  "RNG state should contain 2 64-bit values.");
- NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
+ NVTE_CHECK(rng_state_te_tensor.data.shape == Shape{2},
  "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
  rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
  }

transformer_engine/common/comm_gemm/comm_gemm.cpp

Lines changed: 9 additions & 18 deletions
@@ -130,12 +130,9 @@ int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) {
  void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
  const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
  bool transb) {
- const auto a0 = a->flat_first_dim();
- const auto a1 = a->flat_last_dim();
- const auto b0 = b->flat_first_dim();
- const auto b1 = b->flat_last_dim();
- const auto d0 = d->flat_first_dim();
- const auto d1 = d->flat_last_dim();
+ const auto [a0, a1] = a->flat_2d_dims();
+ const auto [b0, b1] = b->flat_2d_dims();
+ const auto [d0, d1] = d->flat_2d_dims();

  if (transa) {
  NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1);
@@ -169,12 +166,9 @@ void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
  void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
  const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
  bool transb) {
- const auto a0 = a->flat_first_dim();
- const auto a1 = a->flat_last_dim();
- const auto b0 = b->flat_first_dim();
- const auto b1 = b->flat_last_dim();
- const auto d0 = d->flat_first_dim();
- const auto d1 = d->flat_last_dim();
+ const auto [a0, a1] = a->flat_2d_dims();
+ const auto [b0, b1] = b->flat_2d_dims();
+ const auto [d0, d1] = d->flat_2d_dims();

  if (transa) {
  NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
@@ -213,12 +207,9 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
  void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
  const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
  bool transb) {
- const auto a0 = a->flat_first_dim();
- const auto a1 = a->flat_last_dim();
- const auto b0 = b->flat_first_dim();
- const auto b1 = b->flat_last_dim();
- const auto d0 = d->flat_first_dim();
- const auto d1 = d->flat_last_dim();
+ const auto [a0, a1] = a->flat_2d_dims();
+ const auto [b0, b1] = b->flat_2d_dims();
+ const auto [d0, d1] = d->flat_2d_dims();

  if (transa) {
  NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
