Skip to content

Commit 48cc585

Browse files
committed
Formatter and review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>
1 parent 24e9e7f commit 48cc585

8 files changed

Lines changed: 65 additions & 86 deletions

File tree

transformer_engine/common/common.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,18 +339,14 @@ struct Tensor {
339339
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
340340
* as a (D1*D2*...*D(n-1), Dn) matrix.
341341
*/
342-
size_t flat_first_dim() const {
343-
return flat_2d_dims()[0];
344-
}
342+
size_t flat_first_dim() const { return flat_2d_dims()[0]; }
345343

346344
/*! Matrix width after tensor is flattened to 2D
347345
*
348346
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
349347
* as a (D1*D2*...*D(n-1), Dn) matrix.
350348
*/
351-
size_t flat_last_dim() const {
352-
return flat_2d_dims()[1];
353-
}
349+
size_t flat_last_dim() const { return flat_2d_dims()[1]; }
354350
};
355351

356352
struct GroupedTensor {

transformer_engine/common/include/transformer_engine/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ extern "C" {
3232
* \param[in] stream CUDA stream for the operation.
3333
*/
3434
void nvte_load_value_on_device(const void *host_ptr, void *device_ptr, size_t num_bytes,
35-
cudaStream_t stream);
35+
cudaStream_t stream);
3636

3737
/*! \deprecated Use nvte_load_value_on_device instead.
3838
*

transformer_engine/common/swizzle/swizzle.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ constexpr int MXFP8_BLOCK_SIZE = 32;
2525
constexpr int NVFP4_BLOCK_SIZE = 16;
2626

2727
int get_max_dynamic_smem() {
28-
auto query_max_smem = [] () -> int {
28+
auto query_max_smem = []() -> int {
2929
int device{0}, max_smem{0};
3030
NVTE_CHECK_CUDA(cudaGetDevice(&device));
31-
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
31+
NVTE_CHECK_CUDA(
32+
cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
3233
return max_smem;
3334
};
3435
static int cached_val = query_max_smem();

transformer_engine/common/transformer_engine.cpp

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -98,28 +98,23 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
9898
if (t.has_data()) {
9999
constexpr std::array<size_t, 2> block_shape{1, 32};
100100
const std::array<size_t, 2> expected{
101-
DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]),
102-
DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])
103-
};
104-
NVTE_CHECK(t.scale_inv.shape.size() == 2
105-
&& t.scale_inv.shape[0] == expected[0]
106-
&& t.scale_inv.shape[1] == expected[1],
107-
"Tensor \"", name,
108-
"\" has invalid scale_inv shape (expected ", expected,
101+
DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]),
102+
DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])};
103+
NVTE_CHECK(t.scale_inv.shape.size() == 2 && t.scale_inv.shape[0] == expected[0] &&
104+
t.scale_inv.shape[1] == expected[1],
105+
"Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected,
109106
", got ", t.scale_inv.shape, ")");
110107
}
111108
if (t.has_columnwise_data()) {
112109
constexpr std::array<size_t, 2> block_shape{32, 1};
113110
const std::array<size_t, 2> expected{
114-
DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[1]),
115-
DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[0])
116-
};
117-
NVTE_CHECK(t.columnwise_scale_inv.shape.size() == 2
118-
&& t.columnwise_scale_inv.shape[0] == expected[0]
119-
&& t.columnwise_scale_inv.shape[1] == expected[1],
120-
"Tensor \"", name,
121-
"\" has invalid columnwise_scale_inv shape (expected ", expected,
122-
", got ", t.scale_inv.shape, ")");
111+
DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[1]),
112+
DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[0])};
113+
NVTE_CHECK(t.columnwise_scale_inv.shape.size() == 2 &&
114+
t.columnwise_scale_inv.shape[0] == expected[0] &&
115+
t.columnwise_scale_inv.shape[1] == expected[1],
116+
"Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ",
117+
expected, ", got ", t.columnwise_scale_inv.shape, ")");
123118
}
124119
} else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) {
125120
const auto [first_dim, last_dim] = t.flat_2d_dims();
@@ -128,29 +123,24 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
128123
constexpr std::array<size_t, 2> block_shape{1, 16};
129124
constexpr std::array<size_t, 2> block_alignment{128, 4};
130125
const std::array<size_t, 2> expected{
131-
DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]),
132-
DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])
133-
};
134-
NVTE_CHECK(t.scale_inv.shape.size() == 2
135-
&& t.scale_inv.shape[0] == expected[0]
136-
&& t.scale_inv.shape[1] == expected[1],
137-
"Tensor \"", name,
138-
"\" has invalid scale_inv shape (expected ", expected,
126+
DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]),
127+
DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])};
128+
NVTE_CHECK(t.scale_inv.shape.size() == 2 && t.scale_inv.shape[0] == expected[0] &&
129+
t.scale_inv.shape[1] == expected[1],
130+
"Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected,
139131
", got ", t.scale_inv.shape, ")");
140132
}
141133
if (t.has_columnwise_data()) {
142134
constexpr std::array<size_t, 2> block_shape{1, 16};
143135
constexpr std::array<size_t, 2> block_alignment{128, 4};
144136
const std::array<size_t, 2> expected{
145-
DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[0]), block_alignment[0]),
146-
DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[1]), block_alignment[1])
147-
};
148-
NVTE_CHECK(t.columnwise_scale_inv.shape.size() == 2
149-
&& t.columnwise_scale_inv.shape[0] == expected[0]
150-
&& t.columnwise_scale_inv.shape[1] == expected[1],
151-
"Tensor \"", name,
152-
"\" has invalid columnwise_scale_inv shape (expected ", expected,
153-
", got ", t.scale_inv.shape, ")");
137+
DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[0]), block_alignment[0]),
138+
DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[1]), block_alignment[1])};
139+
NVTE_CHECK(t.columnwise_scale_inv.shape.size() == 2 &&
140+
t.columnwise_scale_inv.shape[0] == expected[0] &&
141+
t.columnwise_scale_inv.shape[1] == expected[1],
142+
"Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ",
143+
expected, ", got ", t.columnwise_scale_inv.shape, ")");
154144
}
155145
}
156146
}

transformer_engine/common/util/utils.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
* See LICENSE for license information.
55
************************************************************************/
66

7+
#include <cuda_runtime.h>
78
#include <transformer_engine/utils.h>
89

910
#include <algorithm>
1011
#include <cstring>
1112

12-
#include <cuda_runtime.h>
13-
1413
#include "../common.h"
1514
#include "../util/logging.h"
1615

@@ -27,7 +26,7 @@ union Payload {
2726
};
2827

2928
constexpr size_t block_size = 512;
30-
constexpr size_t num_blocks = DIVUP(Payload::kMaxBytes / Payload::kVectorSize, block_size);
29+
constexpr size_t num_blocks = DIVUP(Payload::kMaxVectors, block_size);
3130

3231
__global__ void __launch_bounds__(block_size) kernel(Payload payload, size_t num_bytes, void *out) {
3332
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -56,7 +55,8 @@ void nvte_load_value_on_device(const void *host_ptr, void *device_ptr, size_t nu
5655

5756
// Check pointers
5857
NVTE_CHECK(host_ptr != nullptr, "Attempting to read ", num_bytes, " bytes from a null pointer.");
59-
NVTE_CHECK(device_ptr != nullptr, "Attempting to write ", num_bytes, " bytes into a null pointer.");
58+
NVTE_CHECK(device_ptr != nullptr, "Attempting to write ", num_bytes,
59+
" bytes into a null pointer.");
6060
NVTE_CHECK(reinterpret_cast<uintptr_t>(device_ptr) % Payload::kVectorSize == 0,
6161
"Device pointer is not aligned to ", Payload::kVectorSize, " bytes.");
6262

@@ -74,6 +74,7 @@ void nvte_load_value_on_device(const void *host_ptr, void *device_ptr, size_t nu
7474

7575
void nvte_convert_pointers_to_tensor(const uint64_t *host_ptrs, NVTETensor output, int64_t count,
7676
cudaStream_t stream) {
77+
NVTE_API_CALL(nvte_convert_pointers_to_tensor);
7778
using namespace transformer_engine;
7879
Tensor *out_tensor = convertNVTETensorCheck(output);
7980
nvte_load_value_on_device(host_ptrs, out_tensor->data.dptr,

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,9 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_
489489
at::Tensor load_data_ptrs_on_device(const std::vector<at::Tensor> &tensors,
490490
const c10::Device &device);
491491

492-
std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_on_device(const std::string &transform_type,
493-
const std::vector<at::Tensor> &tensors,
494-
const c10::Device &device);
495-
492+
std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_on_device(
493+
const std::string &transform_type, const std::vector<at::Tensor> &tensors,
494+
const c10::Device &device);
496495

497496
/***************************************************************************************************
498497
* Support THD format for Context Parallel

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
489489
"Get cublasLt version", py::call_guard<py::gil_scoped_release>());
490490
m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
491491
py::call_guard<py::gil_scoped_release>());
492-
m.def("load_data_ptrs_on_device",
493-
&transformer_engine::pytorch::load_data_ptrs_on_device,
494-
py::arg("tensors"), py::arg("device"),
495-
py::call_guard<py::gil_scoped_release>());
492+
m.def("load_data_ptrs_on_device", &transformer_engine::pytorch::load_data_ptrs_on_device,
493+
py::arg("tensors"), py::arg("device"), py::call_guard<py::gil_scoped_release>());
496494
m.def("transform_and_load_data_ptrs_on_device",
497495
&transformer_engine::pytorch::transform_and_load_data_ptrs_on_device,
498496
py::arg("transform_type"), py::arg("tensors"), py::arg("device"),

transformer_engine/pytorch/csrc/extensions/utils.cpp

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,14 @@ at::Tensor load_data_ptrs_on_device(const std::vector<at::Tensor> &tensors,
3232

3333
// Load pointers on device
3434
nvte_load_value_on_device(ptrs_host.data(), ptrs_device.data_ptr(),
35-
tensors.size() * sizeof(uint64_t),
36-
at::cuda::getCurrentCUDAStream());
35+
tensors.size() * sizeof(uint64_t), at::cuda::getCurrentCUDAStream());
3736

3837
return ptrs_device;
3938
}
4039

41-
std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_on_device(const std::string &transform_type,
42-
const std::vector<at::Tensor> &tensors,
43-
const c10::Device &device) {
40+
std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_on_device(
41+
const std::string &transform_type, const std::vector<at::Tensor> &tensors,
42+
const c10::Device &device) {
4443
const size_t num_tensors = tensors.size();
4544

4645
// Trivial cases
@@ -50,9 +49,8 @@ std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_o
5049
}
5150
if (num_tensors == 0) {
5251
// No input tensors, return tensor with no elements
53-
return {
54-
at::empty({int64_t{0}}, at::TensorOptions().dtype(at::kLong).device(device)),
55-
std::nullopt};
52+
return {at::empty({int64_t{0}}, at::TensorOptions().dtype(at::kLong).device(device)),
53+
std::nullopt};
5654
}
5755

5856
// CUDA stream
@@ -62,9 +60,7 @@ std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_o
6260
const bool uniform_mxfp8_rowwise_swizzle = transform_type == "uniform_mxfp8_rowwise_swizzle";
6361
const bool uniform_mxfp8_colwise_swizzle = transform_type == "uniform_mxfp8_columnwise_swizzle";
6462
const bool uniform_nvfp4_swizzle = transform_type == "uniform_nvfp4_swizzle";
65-
if (uniform_mxfp8_rowwise_swizzle
66-
|| uniform_mxfp8_colwise_swizzle
67-
|| uniform_nvfp4_swizzle) {
63+
if (uniform_mxfp8_rowwise_swizzle || uniform_mxfp8_colwise_swizzle || uniform_nvfp4_swizzle) {
6864
// Tensor format
6965
NVTEScalingMode scaling_mode = NVTE_INVALID_SCALING;
7066
if (uniform_mxfp8_rowwise_swizzle || uniform_mxfp8_colwise_swizzle) {
@@ -76,16 +72,16 @@ std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_o
7672
// Data types
7773
transformer_engine::DType data_dtype, scale_dtype;
7874
switch (scaling_mode) {
79-
case NVTE_MXFP8_1D_SCALING:
80-
data_dtype = transformer_engine::DType::kFloat8E4M3;
81-
scale_dtype = transformer_engine::DType::kFloat8E8M0;
82-
break;
83-
case NVTE_NVFP4_1D_SCALING:
84-
data_dtype = transformer_engine::DType::kFloat4E2M1;
85-
scale_dtype = transformer_engine::DType::kFloat8E4M3;
86-
break;
87-
default:
88-
NVTE_ERROR("Unsupported case.");
75+
case NVTE_MXFP8_1D_SCALING:
76+
data_dtype = transformer_engine::DType::kFloat8E4M3;
77+
scale_dtype = transformer_engine::DType::kFloat8E8M0;
78+
break;
79+
case NVTE_NVFP4_1D_SCALING:
80+
data_dtype = transformer_engine::DType::kFloat4E2M1;
81+
scale_dtype = transformer_engine::DType::kFloat8E4M3;
82+
break;
83+
default:
84+
NVTE_ERROR("Unsupported case.");
8985
}
9086

9187
// Scale shape
@@ -128,8 +124,8 @@ std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_o
128124
for (size_t i = 0; i < num_tensors; ++i) {
129125
inputs_nvte.emplace_back(scaling_mode);
130126
outputs_nvte.emplace_back(scaling_mode);
131-
auto& input_nvte = inputs_nvte.back();
132-
auto& output_nvte = outputs_nvte.back();
127+
auto &input_nvte = inputs_nvte.back();
128+
auto &output_nvte = outputs_nvte.back();
133129
output_nvte.set_with_gemm_swizzled_scales(true);
134130
void *in_scale_ptr = tensors[i].data_ptr();
135131
void *out_scale_ptr = swizzled_scales_dptr + i * swizzled_scales_stride;
@@ -150,28 +146,26 @@ std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_o
150146
std::vector<NVTETensor> inputs_nvte_raw, outputs_nvte_raw;
151147
inputs_nvte_raw.reserve(num_tensors);
152148
outputs_nvte_raw.reserve(num_tensors);
153-
for (auto& t : inputs_nvte) inputs_nvte_raw.push_back(t.data());
154-
for (auto& t : outputs_nvte) outputs_nvte_raw.push_back(t.data());
149+
for (auto &t : inputs_nvte) inputs_nvte_raw.push_back(t.data());
150+
for (auto &t : outputs_nvte) outputs_nvte_raw.push_back(t.data());
155151

156152
// Launch kernel
157153
nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte_raw.data(), outputs_nvte_raw.data(),
158-
inputs_nvte_raw.size(),
159-
stream);
154+
inputs_nvte_raw.size(), stream);
160155

161156
// Collect data pointers
162157
std::vector<uint64_t> ptrs_host;
163158
ptrs_host.reserve(num_tensors);
164159
for (size_t i = 0; i < num_tensors; ++i) {
165-
ptrs_host.push_back(reinterpret_cast<uintptr_t>(swizzled_scales_dptr
166-
+ i * swizzled_scales_stride));
160+
ptrs_host.push_back(
161+
reinterpret_cast<uintptr_t>(swizzled_scales_dptr + i * swizzled_scales_stride));
167162
}
168163

169164
// Load pointers on device
170165
auto ptrs_device = at::empty({static_cast<int64_t>(num_tensors)},
171166
at::TensorOptions().dtype(at::kLong).device(device));
172167
nvte_load_value_on_device(ptrs_host.data(), ptrs_device.data_ptr(),
173-
num_tensors * sizeof(uint64_t),
174-
stream);
168+
num_tensors * sizeof(uint64_t), stream);
175169

176170
return {std::move(ptrs_device), std::move(swizzled_scales)};
177171
}

0 commit comments

Comments
 (0)