Skip to content

Optimize function that loads pointers on GPU#3001

Open
timmoon10 wants to merge 30 commits into
NVIDIA:mainfrom
timmoon10:tmoon/optimize-get_device_pointer_for_data_and_scales
Open

Optimize function that loads pointers on GPU#3001
timmoon10 wants to merge 30 commits into
NVIDIA:mainfrom
timmoon10:tmoon/optimize-get_device_pointer_for_data_and_scales

Conversation

@timmoon10
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 commented May 16, 2026

Description

tex.get_device_pointer_for_data_and_scales has two problems:

  1. It has significant CPU overhead (see [PyTorch] Reduce CPU overhead in grouped MLP block #2897). In a representative benchmark on a GB200, it takes ~70 us per call.
  2. The meaning is extremely unintuitive. The most natural interpretation is that it takes a FP8/MXFP8/NVFP4 tensor and returns pointers as two int s. But actually it takes the buffers from multiple MXFP8/NVFP4 tensors (all assumed to have the same shape), swizzles the scaling factors, and transfers the pointers to a GPU array in a CUDA Graph-friendly way.

This PR makes several optimizations to reduce CPU overhead, mostly by avoiding heap allocations and mutex acquisition. I've also attempted to make the functionality more general and logical:

  • nvte_copy_host_to_device_via_kernel_args: A general function for copying a small amount of data to GPU in a CUDA Graph-friendly way. Unlike nvte_convert_pointers_to_tensor, it makes no assumptions that the data is a list of pointers.
  • tex.copy_data_ptrs_to_device: Takes a list of tensors and puts their data pointers into a GPU buffer.
  • tex.transform_and_copy_data_ptrs_to_device: Performs a user-specified transform on a list of tensors and puts the resulting data pointers into a GPU buffer. Currently it only supports scale swizzles on uniformly shaped tensors, but the transform names help make the contracts explicit.

With these changes, per-call CPU runtime has dropped from 70 us to 31 us on a GB200 node.

This is progress toward #2897.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring
  • Performance optimization

Changes

  • Add transformer_engine::Tensor::flat_2d_dims to compute first and last dims simultaneously
  • Generalize and rename nvte_load_value_on_device
  • Refactor and rename tex.load_data_ptrs_on_device and tex.transform_and_load_data_ptrs_on_device
  • Add internal wrapper class for NVTEShape with similar API as std::vector
  • Remove heap allocations in transformer_engine::SimpleTensor
  • Remove heap allocations in transformer_engine::Tensor shape functions
  • Add batched tensor allocation and deallocation to reduce mutex overhead
  • Avoid heap allocations in tensor checking functions

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 8 commits May 15, 2026 01:35
Avoid constructing temporary std::vector when converting NVTEBasicTensor to SimpleTensor. Avoid string operations in multi-tensor swizzle. Avoid temporary std::vector when checking scale tensors.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Tensor::shape() returns a std::vector<size_t> by value, allocating
on the heap. flat_first_dim and flat_last_dim only need to walk
the dims, so the allocation was pure overhead in hot paths.

Introduce Tensor::compute_shape() returning an NVTEShape (fixed
inline buffer, no heap) as the single source of truth for the
format-dependent shape logic. shape() is now a thin std::vector
wrapper around it for callers that want a vector; flat_first_dim
and flat_last_dim call compute_shape() directly.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
flat_first_dim() and flat_last_dim() each called compute_shape()
independently. flat_2d_dims() computes both in a single pass; the
scalar helpers now delegate to it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Replace all paired flat_first_dim() + flat_last_dim() calls on the
same tensor with a single flat_2d_dims() call. Saves one compute_shape()
per tensor in CheckScaleTensorShape, the multi-tensor swizzle loop, and
various cast/GEMM dispatch paths.

Also adds reserve() to the local vectors in
nvte_multi_tensor_swizzle_scaling_factors to avoid reallocation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Replace the inline swizzle implementation with a call to
multi_tensor_swizzle_scales_for_gemm, which has identical logic
(16B-aligned contiguous output buffer, TensorWrapper construction,
nvte_multi_tensor_swizzle_scaling_factors kernel). Swizzled pointers
are read back from the updated TensorWrappers after the call.

Add reserve() to vectors in multi_tensor_swizzle_scales_for_gemm_impl
now that this function is on the hot path for get_device_pointer_for_data_and_scales.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 16, 2026

Greptile Summary

This PR refactors and optimizes the function that loads tensor pointers onto the GPU, achieving a 2× reduction in per-call CPU latency (~70 µs → ~31 µs). The optimizations eliminate heap allocations and mutex acquisitions by replacing std::vector<size_t> shapes with a stack-allocated Shape class wrapping NVTEShape, adding a batched tensor allocator, and generalizing the host-to-device copy kernel to work with arbitrary byte buffers rather than only pointer arrays.

  • Introduces Shape (stack-allocated wrapper for NVTEShape with a std::vector-like interface), refactors SimpleTensor/Tensor shape members, and adds flat_2d_dims() to compute both flat dimensions in one pass.
  • Splits tex.get_device_pointer_for_data_and_scales into two cleaner, more general functions — tex.copy_data_ptrs_to_device and tex.transform_and_copy_data_ptrs_to_device — with explicit transform-type strings replacing boolean flags.
  • Adds nvte_create_tensors / nvte_destroy_tensors batch APIs and a MultiTensorWrapper RAII helper to amortize mutex acquisition cost across multiple tensor allocations.

Confidence Score: 5/5

The refactoring is safe to merge — the core math for scale shape validation is provably equivalent to the old code, the new kernel correctly handles partial-vector tail bytes, and all existing call sites use uniformly-shaped tensors.

The changes are a well-scoped performance refactoring with no regressions in the core compute paths. The only non-trivial concerns are style-level: the per-tensor index was dropped from batch-validation error messages, and the uniform shape contract in transform_and_copy_data_ptrs_to_device is documented by naming convention but not enforced at runtime. Neither affects correctness for the current callers.

transformer_engine/pytorch/csrc/extensions/utils.cpp — the uniform-shape assumption for scale tensors should ideally be validated at runtime rather than relying solely on the transform-type name to communicate the contract.

Important Files Changed

Filename Overview
transformer_engine/common/common.h Introduces stack-allocated Shape class wrapping NVTEShape with a std::vector-like interface; replaces std::vector<size_t> shape members in SimpleTensor; adds flat_2d_dims() to Tensor; converts Check* function signatures to std::string_view. Changes are well-formed.
transformer_engine/common/transformer_engine.cpp Adds batched TensorAllocator::Allocate(mode, out, N) / delegates single-tensor Free to batch Free; refactors scale-shape validation to use flat_2d_dims() and std::array; adds nvte_create_tensors / nvte_destroy_tensors. Math for expected scale shapes is equivalent to the old code.
transformer_engine/common/util/utils.cu Replaces the pointer-specific write kernel with a general nvte_copy_host_to_device_via_kernel that copies arbitrary byte buffers via kernel arguments; deprecated nvte_convert_pointers_to_tensor now delegates to it. The new kernel correctly handles partial-vector tail bytes and chunks payloads into 2 KiB batches.
transformer_engine/pytorch/csrc/extensions/utils.cpp Splits old get_device_pointer_for_data_and_scales into copy_data_ptrs_to_device and transform_and_copy_data_ptrs_to_device; the new implementation uses the first tensor's scale shape for all tensors without validating that remaining tensors share the same shape.
transformer_engine/pytorch/csrc/extensions/swizzle.cpp Replaces per-tensor TensorWrapper vector with a single MultiTensorWrapper batch allocation; caches output dtype/shape before the kernel launch to avoid re-reading from NVTETensors. Logic is equivalent to the previous code.
transformer_engine/common/include/transformer_engine/transformer_engine.h Adds nvte_create_tensors, nvte_destroy_tensors API declarations and MultiTensorWrapper RAII class with correct move semantics and null-safe destructor.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Updates call sites from deprecated get_device_pointer_for_data_and_scales to separate copy_data_ptrs_to_device and transform_and_copy_data_ptrs_to_device calls; swizzled-scale buffer is retained via keepalive variables as before.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Same call-site migration as forward pass, using uniform_mxfp8_columnwise_swizzle for the backward weight-access pattern. Keepalive buffers are correctly named and retained.
transformer_engine/common/swizzle/swizzle.cu Uses flat_2d_dims() to compute both dimensions in one call; removes per-tensor index from error message names, losing specificity when batch validation fails.

Sequence Diagram

sequenceDiagram
    participant PY as Python (forward/backward MLP)
    participant TEX as tex (pybind)
    participant CPP as C++ utils.cpp
    participant KERN as CUDA Kernels

    PY->>TEX: copy_data_ptrs_to_device(data_tensors, device)
    TEX->>CPP: collect uint64_t ptrs on host
    CPP->>KERN: nvte_copy_host_to_device_via_kernel (via kernel args)
    KERN-->>CPP: ptrs written to device buffer
    CPP-->>PY: ptrs_device (at::Tensor)

    PY->>TEX: transform_and_copy_data_ptrs_to_device("uniform_mxfp8_rowwise_swizzle", scales, device)
    TEX->>CPP: create MultiTensorWrapper (batch alloc, 1 mutex acquire)
    CPP->>KERN: nvte_multi_tensor_swizzle_scaling_factors
    KERN-->>CPP: swizzled_scales written to device buffer
    CPP->>KERN: nvte_copy_host_to_device_via_kernel (swizzled scale ptrs)
    KERN-->>CPP: ptrs written to device buffer
    CPP-->>PY: (ptrs_device, swizzled_scales buffer)

    PY->>KERN: GEMM kernel (reads data via ptrs_device, scales via swizzled scale ptrs)
Loading

Reviews (10): Last reviewed commit: "Merge branch 'main' into tmoon/optimize-..." | Re-trigger Greptile

Comment thread transformer_engine/common/transformer_engine.cpp Outdated
Comment thread transformer_engine/common/util/utils.cu
Comment thread transformer_engine/common/common.h Outdated
Comment thread transformer_engine/common/util/utils.cu Outdated
NVTE_CHECK(data_tensors[0].is_cuda(), "data_tensors must be on CUDA.");
const auto device = data_tensors[0].device();
auto stream = at::cuda::getCurrentCUDAStream();
std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_on_device(
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not committed to this name. I based it on std::transform. I suppose "map" would be more Python-focused, but that sounds worse.

Comment thread transformer_engine/common/util/utils.cu
Comment thread transformer_engine/common/transformer_engine.cpp Outdated
timmoon10 and others added 3 commits May 16, 2026 11:49
- Use size_t in kernel tail loop (was int64_t)
- Zero-initialize Payload before memcpy (Payload{})
- Rename Payload members to kMaxBytes/kVectorSize/kMaxVectors (linter)
- Consistent at::empty shape pattern: {static_cast<int64_t>(N)}
- Drop intermediate swizzled_scales_bytes variable
- Add comment explaining uniform-stride assumption in
  transform_and_load_data_ptrs_on_device
- Rename sfb_buffer -> _sfb_buffer (keepalive, not directly used)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 force-pushed the tmoon/optimize-get_device_pointer_for_data_and_scales branch from 7946e5d to 48cc585 Compare May 16, 2026 11:53
@timmoon10
Copy link
Copy Markdown
Member Author

/te-ci

@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented May 18, 2026

Seems a lot of those changes would basically not be needed if we did not use the std::vector in Tensor/SimpleTensor and just used NVTEShape everywhere - this would effectively make SimpleTensor and NVTEBasicTensor the same thing (we could even do the constructor in the public header, just behing the if cplusplus guard).

Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning up the APIs. Looks much nicer now. CPU overheads being caused by heap allocations of shape, makes me wonder whether we should revive this PR to standardize on NVTEShape yo avoid back and forth between vector<size_t> and NVTE_Shape

Comment thread transformer_engine/common/util/utils.cu Outdated
Comment on lines +505 to +508
fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_load_data_ptrs_on_device(
"uniform_mxfp8_columnwise_swizzle",
[w._columnwise_scale_inv for w in grouped_fc2_weight],
swizzle=True,
rowwise=False,
data_dtype=grouped_fc2_weight[0]._fp8_dtype,
device,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other optimization can be to load both fc1 and fc2 data and scale inv togegther at the start of backward. I am hoping it wouldnt make the code ugly.

Comment thread transformer_engine/common/common.h Outdated
SimpleTensor() : SimpleTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32) {}
SimpleTensor &operator=(const NVTEBasicTensor &tensor) {
dptr = tensor.data_ptr;
shape.assign(tensor.shape.data, tensor.shape.data + tensor.shape.ndim);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So when you say heap allocations being done redundantly again and again. Do you mean the vector to NVTEShape conversions?

I rememember this problem being observed even with a basic te linear profiling. And I hadnt gotten this PR merged.
#2514

which essentially standadizes to use NVTEShape everywhere instead of using vector at all to avoid bouncing back and forth between the two allocations. Maybe it might be worth to revive the PR?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, assigning an NVTEBasicTensor to a SimpleTensor would trigger the constructor and then the move operator. This would allocate an std::vector, move it, and deallocate the old std::vector.

One other approach I was thinking about was implementing a Shape class that wraps around NVTEShape and has a similar API as std::vector. That way we can keep the nice ergonomics, while avoiding heap allocations.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had tried your other approach in the 2514 PR above, but eventually had removed it due to some complications. I have refer back to my notes on why it didnt work out for me.

But here is the commit that reverted it
b599776

I had called it NVTEShapeWrapper and implemented all the vector based APIs.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One complication I do remember was to change a lot of attention interfaces to have NVTEShapeWrapper instead of using vector.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've found that adding a cast operator to std::vector helps reduce the number of places we need to change the interfaces.

timmoon10 and others added 4 commits May 20, 2026 01:25
Provides a std::vector<size_t>-like interface around NVTEShape without
heap allocation, used as the return type of Tensor::shape() in place of
the previous std::vector. Disambiguate cute::Shape from
transformer_engine::Shape in the hadamard_transform kernels.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Store shape in Shape class rather than std::vector.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Copy link
Copy Markdown
Member Author

/te-ci

timmoon10 and others added 7 commits May 21, 2026 01:31
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Expose nvte_create_tensors and nvte_destroy_tensors so multi-tensor
callers can amortize the TensorAllocator mutex across N tensors
instead of locking once per call. nvte_destroy_tensors was already
defined internally but not declared in the public header.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
…evice

The uniform swizzle path constructed 2N TensorWrappers and then
extracted their raw NVTETensors into separate vectors. Replace with a
single 2N nvte_create_tensors call into one contiguous buffer (inputs
in the first half, outputs in the second), an RAII guard for
nvte_destroy_tensors, and a local set_param lambda for the setters.
Drops the separate pack pass and reduces the allocator mutex
acquisitions from 4N to 2 per call.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Copy link
Copy Markdown
Member Author

/te-ci

Comment thread transformer_engine/common/util/utils.cu Outdated
transformer_engine::DType data_dtype, scale_dtype;
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
data_dtype = transformer_engine::DType::kFloat8E4M3;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to hardcode data_dtype = kFloat8E4M3 here?

Copy link
Copy Markdown
Member Author

@timmoon10 timmoon10 May 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't actually access the fp8e4m3 values when swizzling, this is a fake configuration so the tensor passes validation checks.

Comment thread transformer_engine/common/common.h Outdated
Comment thread transformer_engine/common/util/utils.cu Outdated
Comment thread transformer_engine/common/util/utils.cu
timmoon10 and others added 3 commits May 26, 2026 17:41
Signed-off-by: Tim Moon <tmoon@nvidia.com>
string_view is already a (ptr, len) reference — passing by const-ref
adds an indirection without benefit. Matches the C++ Core Guidelines
F.16 recommendation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Copy link
Copy Markdown
Member Author

/te-ci

Comment thread transformer_engine/common/normalization/layernorm/ln_api.cpp Outdated
return max_smem;
return max_smem;
};
static int cached_val = query_max_smem();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nice but wrong - we need to be able to support this value per GPU rather than a single global, similarly to the other CUDA properties that we have in the CUDA runtime helpers (also, this function should live together with those other functions).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to handle the multi-device case.

check_scale_inv_shapes);
CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]",
check_scale_inv_shapes);
CheckInputTensor(*input[i], "scaling_factor_input", check_scale_inv_shapes);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is making the error quality worse. I think the better approach would be to have a version of this function that still performs this string concatenation, but does so only when we actually want to raise an error.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could implement the Check*Tensor functions with variadic templates and unpack the args in NVTE_CHECK. Alternatively we could pass in a closure that creates the name. These seem overkill though. Both approaches seem excessive for what should be a quick sanity check.

Comment thread transformer_engine/common/common.h
Comment thread transformer_engine/common/common.h Outdated
Comment thread transformer_engine/common/transformer_engine.cpp
Comment thread transformer_engine/pytorch/csrc/extensions/recipe.cpp Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/swizzle.cpp Outdated
transformer_engine::DType scale_dtype;
if (is_fp8_dtype(data_dtype)) {
// Swizzle scales for GEMM, with uniform tensor sizes
const bool uniform_mxfp8_rowwise_swizzle = transform_type == "uniform_mxfp8_rowwise_swizzle";
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to make it possibly the fastest while still using strings, we should start the strings with the things that are different, rather than start with the "uniform_" part that will be the same in every case (and therefore the loop that this will become will need to go over that part every time no matter what).

timmoon10 and others added 4 commits May 26, 2026 22:21
Expand internal usage of Shape class. Zero-initialize in Shape::resize. Make sure dynamic smem querying is per-device. Reuse logic for batched and single tensor alloc/dealloc.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Thin RAII wrapper around a batched nvte_create_tensors /
nvte_destroy_tensors pair, with operator[], data(), iteration, and
implicit conversion to NVTETensor* for multi-tensor C APIs. Replaces
the ad-hoc DestroyGuard struct used at each call site in
recipe.cpp, swizzle.cpp, and utils.cpp.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Copy link
Copy Markdown
Member Author

/te-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants