Commit 9e5a847
Optimize function that loads pointers on GPU (#3001)
* Remove unnecessary heap allocations
Avoid constructing temporary std::vector when converting NVTEBasicTensor to SimpleTensor. Avoid string operations in multi-tensor swizzle. Avoid temporary std::vector when checking scale tensors.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Avoid heap allocation in Tensor::flat_first_dim/flat_last_dim
Tensor::shape() returns a std::vector<size_t> by value, allocating
on the heap. flat_first_dim and flat_last_dim only need to walk
the dims, so the allocation was pure overhead in hot paths.
Introduce Tensor::compute_shape() returning an NVTEShape (fixed
inline buffer, no heap) as the single source of truth for the
format-dependent shape logic. shape() is now a thin std::vector
wrapper around it for callers that want a vector; flat_first_dim
and flat_last_dim call compute_shape() directly.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
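The idea in the entry above can be sketched roughly as follows. This is an illustration, not the code from this commit: `InlineShape`, `kMaxDims`, and the free functions are hypothetical stand-ins, and the sketch assumes that the flattened first dim is the product of all dims except the last and that an empty shape flattens to 1 x 1.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical capacity; the real NVTEShape layout is not shown in this commit message.
constexpr std::size_t kMaxDims = 15;

// Fixed inline buffer: copying or returning this by value never touches the heap.
struct InlineShape {
  std::array<std::size_t, kMaxDims> dims{};
  std::size_t ndim = 0;
};

// Product of every dim except the last (assumed semantics); 1 for 0-D and 1-D shapes.
inline std::size_t flat_first_dim(const InlineShape& s) {
  assert(s.ndim <= kMaxDims);
  std::size_t prod = 1;
  for (std::size_t i = 0; i + 1 < s.ndim; ++i) {
    prod *= s.dims[i];
  }
  return prod;
}

// Last dim, or 1 for an empty shape (assumed semantics).
inline std::size_t flat_last_dim(const InlineShape& s) {
  assert(s.ndim <= kMaxDims);
  return s.ndim == 0 ? 1 : s.dims[s.ndim - 1];
}
```

Both helpers only walk the inline array, so a caller on a hot path pays for a small stack copy rather than a std::vector allocation.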
* Add Tensor::flat_2d_dims() to compute both matrix dims in one pass
flat_first_dim() and flat_last_dim() each called compute_shape()
independently. flat_2d_dims() computes both in a single pass; the
scalar helpers now delegate to it.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Use flat_2d_dims() throughout common lib
Replace all paired flat_first_dim() + flat_last_dim() calls on the
same tensor with a single flat_2d_dims() call. Saves one compute_shape()
per tensor in CheckScaleTensorShape, the multi-tensor swizzle loop, and
various cast/GEMM dispatch paths.
Also adds reserve() to the local vectors in
nvte_multi_tensor_swizzle_scaling_factors to avoid reallocation.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
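The effect of the reserve() calls mentioned above can be shown with a small standalone sketch. `count_reallocations` is a hypothetical helper written for this illustration and is unrelated to the library code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Counts how often the vector's capacity changes while n elements are appended.
std::size_t count_reallocations(std::size_t n, bool use_reserve) {
  std::vector<const void*> ptrs;
  if (use_reserve) {
    ptrs.reserve(n);  // one up-front allocation
    assert(ptrs.capacity() >= n);
  }
  std::size_t reallocations = 0;
  std::size_t cap = ptrs.capacity();
  for (std::size_t i = 0; i < n; ++i) {
    ptrs.push_back(nullptr);
    if (ptrs.capacity() != cap) {  // capacity changed, so the buffer was reallocated
      ++reallocations;
      cap = ptrs.capacity();
    }
  }
  return reallocations;
}
```

With `use_reserve` set, the loop never reallocates; without it, the number of reallocations depends on the standard library's growth policy.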
* Generalize API for CUDA-Graph-safe copy to GPU.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Dedup swizzle logic in get_device_pointer_for_data_and_scales
Replace the inline swizzle implementation with a call to
multi_tensor_swizzle_scales_for_gemm, which has identical logic
(16B-aligned contiguous output buffer, TensorWrapper construction,
nvte_multi_tensor_swizzle_scaling_factors kernel). Swizzled pointers
are read back from the updated TensorWrappers after the call.
Add reserve() to vectors in multi_tensor_swizzle_scales_for_gemm_impl
now that this function is on the hot path for get_device_pointer_for_data_and_scales.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Make separate functions for loading data_ptrs and for swizzling + loading data_ptrs.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Change function name to nvte_load_value_on_device
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fix code review issues before opening PR
- Use size_t in kernel tail loop (was int64_t)
- Zero-initialize Payload before memcpy (Payload{})
- Rename Payload members to kMaxBytes/kVectorSize/kMaxVectors (linter)
- Consistent at::empty shape pattern: {static_cast<int64_t>(N)}
- Drop intermediate swizzled_scales_bytes variable
- Add comment explaining uniform-stride assumption in
transform_and_load_data_ptrs_on_device
- Rename sfb_buffer -> _sfb_buffer (keepalive, not directly used)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
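The "zero-initialize Payload before memcpy" item above can be sketched as follows. The member name `kMaxBytes` comes from the list, but its value, the struct layout, and the `make_payload` helper are assumptions made for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <stdexcept>

// Hypothetical fixed-size buffer that is copied by value; 64 is an assumed size.
struct Payload {
  static constexpr std::size_t kMaxBytes = 64;
  alignas(16) unsigned char bytes[kMaxBytes];
};

// Packs n bytes from src into a Payload. Bytes past n are guaranteed to be zero.
Payload make_payload(const void* src, std::size_t n) {
  assert(src != nullptr || n == 0);
  if (n > Payload::kMaxBytes) {
    throw std::length_error("payload too large");
  }
  Payload p{};  // value-initialization zeroes the whole buffer
  if (n > 0) {
    std::memcpy(p.bytes, src, n);
  }
  return p;
}
```

Without the `Payload{}` initialization, the bytes beyond `n` would be indeterminate, which matters whenever the whole struct is copied or compared as a unit.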
* Formatter and review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Add Shape class wrapping NVTEShape
Provides a std::vector<size_t>-like interface around NVTEShape without
heap allocation, used as the return type of Tensor::shape() in place of
the previous std::vector. Disambiguate cute::Shape from
transformer_engine::Shape in the hadamard_transform kernels.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Make SimpleTensor stack-allocatable
Store shape in Shape class rather than std::vector.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Make Shape conversion constructors explicit
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Make conversion from Shape to std::vector explicit
Signed-off-by: Tim Moon <tmoon@nvidia.com>
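The Shape-related entries above (a vector-like interface over a fixed buffer, with explicit conversions in both directions) might look roughly like this. Everything here lives in a hypothetical `sketch` namespace; `RawShape` only stands in for NVTEShape, and the member set is a guess at what "vector-like" covers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace sketch {

constexpr std::size_t kMaxDims = 15;  // assumed capacity

// Hypothetical stand-in for a C struct with an inline dims array.
struct RawShape {
  std::size_t data[kMaxDims];
  std::size_t ndim;
};

class Shape {
 public:
  Shape() : raw_{} {}

  // Explicit: building a Shape from a vector is a visible copy, never a silent one.
  explicit Shape(const std::vector<std::size_t>& v) : raw_{} {
    if (v.size() > kMaxDims) throw std::length_error("too many dims");
    std::copy(v.begin(), v.end(), raw_.data);
    raw_.ndim = v.size();
  }

  // Explicit: callers must opt in to the heap allocation a vector implies.
  explicit operator std::vector<std::size_t>() const {
    return std::vector<std::size_t>(begin(), end());
  }

  std::size_t size() const { return raw_.ndim; }
  bool empty() const { return raw_.ndim == 0; }
  std::size_t operator[](std::size_t i) const {
    assert(i < raw_.ndim);
    return raw_.data[i];
  }
  std::size_t back() const {
    assert(!empty());
    return raw_.data[raw_.ndim - 1];
  }
  const std::size_t* begin() const { return raw_.data; }
  const std::size_t* end() const { return raw_.data + raw_.ndim; }

 private:
  RawShape raw_;  // inline storage, no heap
};

}  // namespace sketch
```

Making both conversions explicit means that any place that still pays for a std::vector has to say so with a cast, which makes leftover allocations easy to find in review.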
* Add batched NVTETensor create/destroy
Expose nvte_create_tensors and nvte_destroy_tensors so multi-tensor
callers can amortize the TensorAllocator mutex across N tensors
instead of locking once per call. nvte_destroy_tensors was already
defined internally but not declared in the public header.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
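The mutex amortization described above can be illustrated with a toy allocator. `HandlePool` and its integer handles are hypothetical; the sketch only shows why one batched call takes the lock once where N single calls take it N times, and says nothing about how TensorAllocator is actually implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <vector>

class HandlePool {
 public:
  // Single create: one lock acquisition per handle.
  int create() {
    std::lock_guard<std::mutex> lock(mu_);
    ++lock_count_;
    return allocate_locked();
  }

  // Batched create: one lock acquisition for all n handles.
  void create_many(int* out, std::size_t n) {
    assert(out != nullptr || n == 0);
    std::lock_guard<std::mutex> lock(mu_);
    ++lock_count_;
    for (std::size_t i = 0; i < n; ++i) out[i] = allocate_locked();
  }

  // Batched destroy: handles go back on the free list under a single lock.
  void destroy_many(const int* handles, std::size_t n) {
    assert(handles != nullptr || n == 0);
    std::lock_guard<std::mutex> lock(mu_);
    ++lock_count_;
    for (std::size_t i = 0; i < n; ++i) free_list_.push_back(handles[i]);
  }

  // Instrumentation for the example only; read it from a single thread.
  std::size_t lock_count() const { return lock_count_; }

 private:
  int allocate_locked() {
    if (!free_list_.empty()) {
      int h = free_list_.back();
      free_list_.pop_back();
      return h;
    }
    return next_id_++;
  }

  std::mutex mu_;
  std::vector<int> free_list_;
  int next_id_ = 0;
  std::size_t lock_count_ = 0;
};
```

Creating and destroying eight handles through the batched calls costs two lock acquisitions in this sketch, versus sixteen through per-handle calls.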
* Use batched NVTETensor allocator in transform_and_load_data_ptrs_on_device
The uniform swizzle path constructed 2N TensorWrappers and then
extracted their raw NVTETensors into separate vectors. Replace with a
single 2N nvte_create_tensors call into one contiguous buffer (inputs
in the first half, outputs in the second), an RAII guard for
nvte_destroy_tensors, and a local set_param lambda for the setters.
Drops the separate pack pass and reduces the allocator mutex
acquisitions from 4N to 2 per call.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Expand usage of batched NVTETensor allocator
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Use string_view in tensor checking functions
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Tweak function names
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Pass std::string_view by value in Check*Tensor helpers
string_view is already a (ptr, len) reference — passing by const-ref
adds an indirection without benefit. Matches the C++ Core Guidelines
F.16 recommendation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
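A small sketch of the by-value convention described above, assuming C++17. `check_not_null` is a hypothetical helper rather than one of the Check*Tensor functions, and the error text is made up.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <string_view>

// name is taken by value: a string_view is just a pointer and a length.
void check_not_null(const void* ptr, std::string_view name) {
  assert(!name.empty());
  if (ptr == nullptr) {
    // The std::string is built only on the failure path.
    std::string msg = "Tensor ";
    msg.append(name.data(), name.size());
    msg += " has a null data pointer";
    throw std::invalid_argument(msg);
  }
}
```

On the success path the caller can pass a string literal and no std::string is ever constructed.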
* Review suggestions from @ptrendx
Expand internal usage of Shape class. Zero-initialize in Shape::resize. Make sure dynamic smem querying is per-device. Reuse logic for batched and single tensor alloc/dealloc.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Add MultiTensorWrapper for batched NVTETensor allocation
Thin RAII wrapper around a batched nvte_create_tensors /
nvte_destroy_tensors pair, with operator[], data(), iteration, and
implicit conversion to NVTETensor* for multi-tensor C APIs. Replaces
the ad-hoc DestroyGuard struct used at each call site in
recipe.cpp, swizzle.cpp, and utils.cpp.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
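The RAII pattern described above could be sketched as follows. `Handle`, `fake_create_tensors`, `fake_destroy_tensors`, and `MultiHandle` are all hypothetical stand-ins; the real `MultiTensorWrapper` interface and the signatures of `nvte_create_tensors` / `nvte_destroy_tensors` are not shown in this commit message. The sketch stores handles in a std::vector for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Handle = void*;  // stand-in for an opaque C handle

static std::size_t g_live = 0;  // live-handle counter, for the example only

// Hypothetical batched C-style API.
void fake_create_tensors(Handle* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) { out[i] = new int(0); ++g_live; }
}
void fake_destroy_tensors(Handle* handles, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) { delete static_cast<int*>(handles[i]); --g_live; }
}

class MultiHandle {
 public:
  explicit MultiHandle(std::size_t n) : handles_(n, nullptr) {
    if (n > 0) fake_create_tensors(handles_.data(), n);
  }
  ~MultiHandle() {
    if (!handles_.empty()) fake_destroy_tensors(handles_.data(), handles_.size());
  }
  MultiHandle(const MultiHandle&) = delete;
  MultiHandle& operator=(const MultiHandle&) = delete;
  MultiHandle(MultiHandle&& other) noexcept : handles_(std::move(other.handles_)) {
    other.handles_.clear();  // the moved-from object owns nothing
  }

  std::size_t size() const { return handles_.size(); }
  Handle* data() { return handles_.data(); }
  Handle& operator[](std::size_t i) {
    assert(i < handles_.size());
    return handles_[i];
  }
  Handle* begin() { return handles_.data(); }
  Handle* end() { return handles_.data() + handles_.size(); }

  // Implicit conversion so the object can be passed straight to a C API taking Handle*.
  operator Handle*() { return handles_.data(); }

 private:
  std::vector<Handle> handles_;
};
```

Because the destructor makes the batched destroy call, early returns and exceptions at a call site cannot leak handles, which is the job a per-site guard struct would otherwise do.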
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ace2a96 commit 9e5a847
36 files changed
Lines changed: 724 additions & 496 deletions
File tree
- transformer_engine
- common
- cast
- dispatch
- fp8
- mxfp8
- nvfp4
- specialized
- comm_gemm
- gemm
- hadamard_transform
- include/transformer_engine
- normalization
- layernorm
- rmsnorm
- swizzle
- transpose
- util
- pytorch
- csrc
- extensions
- ops/fused