Skip to content

Commit e355c38

Browse files
committed
fix merge conflicts, now things working
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
1 parent 53a41b2 commit e355c38

13 files changed

Lines changed: 812 additions & 66 deletions

File tree

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 217 files

tests/cpp/operator/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_executable(test_operator
2727
test_memset.cu
2828
test_splits_to_offsets.cu
2929
test_multi_cast_transpose.cu
30+
test_multi_tensor_adam_mxfp8.cu
3031
test_multi_padding.cu
3132
test_multi_unpadding.cu
3233
test_causal_softmax.cu
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
/*************************************************************************
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
*
4+
* See LICENSE for license information.
5+
************************************************************************/
6+
7+
#include <cuda_fp8.h>
8+
#include <cuda_runtime.h>
9+
#include <gtest/gtest.h>
10+
11+
#include <algorithm>
12+
#include <cmath>
13+
#include <cstring>
14+
#include <vector>
15+
16+
#include <transformer_engine/cast.h>
17+
#include <transformer_engine/multi_tensor.h>
18+
19+
#include "../test_common.h"
20+
21+
using namespace transformer_engine;
22+
using namespace test;
23+
24+
namespace {
25+
26+
// Copies the first byte of an fp8e4m3 value into a uint8_t and returns it.
uint8_t fp8_to_u8(fp8e4m3 v) {
  uint8_t bits{};
  std::memcpy(&bits, &v, sizeof(bits));
  return bits;
}
31+
32+
// Copies the first byte of an fp8e5m2 value into a uint8_t and returns it.
uint8_t fp8_to_u8(fp8e5m2 v) {
  uint8_t bits{};
  std::memcpy(&bits, &v, sizeof(bits));
  return bits;
}
37+
38+
// Checks the fused MXFP8 Adam entry point against an unfused reference.
//
// Fused path:     nvte_multi_tensor_adam_mxfp8_cuda on (g, p, m, v, q).
// Reference path: nvte_multi_tensor_adam_cuda on copies of the same FP32 state
//                 (p_ref_t, m_ref_t, v_ref_t), followed by nvte_quantize of the
//                 updated reference parameters into q_ref.
//
// Parameters, both moments, the scale-inverse buffers, the raw FP8 payloads and
// the dequantized values of both paths are compared with zero tolerance.
//
// fp8_dtype: element type used for the MXFP8 tensors. Any value other than
//            DType::kFloat8E4M3 is read back on the host as fp8e5m2.
void run_mxfp8_adam_test(DType fp8_dtype) {
  const std::vector<size_t> shape1{64, 128};
  const std::vector<size_t> shape2{32, 64};
  const float lr = 1e-3f;
  const float beta1 = 0.9f;
  const float beta2 = 0.999f;
  const float eps = 1e-8f;
  const int step = 1;
  const int mode = 1;
  const int bias_correction = 1;
  const float weight_decay = 0.0f;

  // Run with 25 tensors > 24[MXFP8_MAX_TENSORS] to check
  // the chunking logic
  const size_t tensor_count = 25;
  std::vector<std::vector<size_t>> shapes;
  shapes.reserve(tensor_count);
  for (size_t i = 0; i < tensor_count; ++i) {
    // Alternate between the two shapes so the batch is not uniform.
    shapes.push_back((i % 2 == 0) ? shape1 : shape2);
  }

  // Eleven names are pushed per tensor in the loop below. Reserving up front
  // means the vector never reallocates while c_str() pointers are being handed
  // out (NOTE(review): whether Tensor retains that pointer is not visible here).
  std::vector<std::string> names;
  names.reserve(tensor_count * 11);
  std::vector<Tensor> g;
  std::vector<Tensor> p;
  std::vector<Tensor> m;
  std::vector<Tensor> v;
  std::vector<Tensor> p_ref_t;
  std::vector<Tensor> m_ref_t;
  std::vector<Tensor> v_ref_t;
  std::vector<Tensor> q_ref;
  std::vector<Tensor> dq;
  std::vector<Tensor> dq_ref;
  std::vector<Tensor> q;
  g.reserve(tensor_count);
  p.reserve(tensor_count);
  m.reserve(tensor_count);
  v.reserve(tensor_count);
  p_ref_t.reserve(tensor_count);
  m_ref_t.reserve(tensor_count);
  v_ref_t.reserve(tensor_count);
  q_ref.reserve(tensor_count);
  dq.reserve(tensor_count);
  dq_ref.reserve(tensor_count);
  q.reserve(tensor_count);

  for (size_t i = 0; i < tensor_count; ++i) {
    const std::vector<size_t> &shape = shapes[i];
    names.push_back("g" + std::to_string(i));
    g.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false);
    names.push_back("p" + std::to_string(i));
    p.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false);
    names.push_back("m" + std::to_string(i));
    m.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false);
    names.push_back("v" + std::to_string(i));
    v.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false);

    // Gradients and parameters get uniform fill; both moments start at zero.
    fillUniform(&g.back());
    fillUniform(&p.back());
    std::fill_n(m.back().rowwise_cpu_dptr<float>(), product(m.back().rowwise_shape()), 0.0f);
    std::fill_n(v.back().rowwise_cpu_dptr<float>(), product(v.back().rowwise_shape()), 0.0f);
    m.back().from_cpu();
    v.back().from_cpu();

    names.push_back("p_ref_" + std::to_string(i));
    p_ref_t.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false);
    names.push_back("m_ref_" + std::to_string(i));
    m_ref_t.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false);
    names.push_back("v_ref_" + std::to_string(i));
    v_ref_t.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false);

    // The reference state starts as a copy of the host-side p/m/v buffers.
    const size_t n = shape[0] * shape[1];
    std::memcpy(p_ref_t.back().rowwise_cpu_dptr<float>(), p.back().rowwise_cpu_dptr<float>(),
                n * sizeof(float));
    std::memcpy(m_ref_t.back().rowwise_cpu_dptr<float>(), m.back().rowwise_cpu_dptr<float>(),
                n * sizeof(float));
    std::memcpy(v_ref_t.back().rowwise_cpu_dptr<float>(), v.back().rowwise_cpu_dptr<float>(),
                n * sizeof(float));
    p_ref_t.back().from_cpu();
    m_ref_t.back().from_cpu();
    v_ref_t.back().from_cpu();

    names.push_back("q_ref_" + std::to_string(i));
    q_ref.emplace_back(names.back().c_str(), shape, fp8_dtype, true, true, NVTE_MXFP8_1D_SCALING);
    q_ref.back().set_with_gemm_swizzled_scales(false);

    names.push_back("dq" + std::to_string(i));
    dq.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false);
    names.push_back("dq_ref_" + std::to_string(i));
    dq_ref.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false);

    names.push_back("q" + std::to_string(i));
    q.emplace_back(names.back().c_str(), shape, fp8_dtype, true, true, NVTE_MXFP8_1D_SCALING);
    q.back().set_with_gemm_swizzled_scales(false);
  }

  // Single-element int32 flag, set to zero and uploaded before the kernels run.
  Tensor noop("noop", std::vector<size_t>{1}, DType::kInt32, true, false);
  int zero = 0;
  std::memcpy(noop.rowwise_cpu_dptr<int>(), &zero, sizeof(int));
  noop.from_cpu();

  // Eight lists, in this order: gradients, params, first moment, second moment,
  // rowwise FP8 data, columnwise FP8 data, rowwise scale-inv, columnwise scale-inv.
  std::vector<std::vector<NVTETensor>> lists(8);
  // Four wrappers are created per tensor; reserving keeps emplace_back from
  // reallocating after handles from earlier wrappers were stored in `lists`.
  std::vector<TensorWrapper> extra_wrappers;
  extra_wrappers.reserve(tensor_count * 4);

  // Parameter names are deliberately distinct from the outer g/p/m/v/q vectors,
  // which this lambda also captures by reference.
  auto add_tensor = [&](Tensor &grad, Tensor &param, Tensor &exp_avg, Tensor &exp_avg_sq,
                        Tensor &quantized) {
    lists[0].push_back(grad.data());
    lists[1].push_back(param.data());
    lists[2].push_back(exp_avg.data());
    lists[3].push_back(exp_avg_sq.data());

    // The pieces of the quantized tensor are passed as separate plain tensors.
    extra_wrappers.emplace_back(quantized.rowwise_dptr(), quantized.rowwise_shape(), fp8_dtype);
    lists[4].push_back(extra_wrappers.back().data());
    extra_wrappers.emplace_back(quantized.columnwise_dptr(), quantized.columnwise_shape(),
                                fp8_dtype);
    lists[5].push_back(extra_wrappers.back().data());
    extra_wrappers.emplace_back(quantized.rowwise_scale_inv_dptr(),
                                quantized.rowwise_scale_inv_shape(), DType::kByte);
    lists[6].push_back(extra_wrappers.back().data());
    extra_wrappers.emplace_back(quantized.columnwise_scale_inv_dptr(),
                                quantized.columnwise_scale_inv_shape(), DType::kByte);
    lists[7].push_back(extra_wrappers.back().data());
  };

  for (size_t i = 0; i < tensor_count; ++i) {
    add_tensor(g[i], p[i], m[i], v[i], q[i]);
  }

  std::vector<NVTETensor *> list_ptrs;
  list_ptrs.reserve(lists.size());
  for (auto &l : lists) {
    list_ptrs.push_back(l.data());
  }

  // Fused path: Adam update plus MXFP8 quantization in one call.
  nvte_multi_tensor_adam_mxfp8_cuda(65536, noop.data(), list_ptrs.data(), lists.size(),
                                    lists[0].size(), static_cast<NVTEDType>(fp8_dtype), lr, beta1,
                                    beta2, eps, step, mode, bias_correction, weight_decay, 0);

  // Reference path: same gradients, separate copies of params and moments.
  std::vector<std::vector<NVTETensor>> ref_lists(4);
  for (size_t i = 0; i < tensor_count; ++i) {
    ref_lists[0].push_back(g[i].data());
    ref_lists[1].push_back(p_ref_t[i].data());
    ref_lists[2].push_back(m_ref_t[i].data());
    ref_lists[3].push_back(v_ref_t[i].data());
  }
  std::vector<NVTETensor *> ref_list_ptrs;
  ref_list_ptrs.reserve(ref_lists.size());
  for (auto &l : ref_lists) {
    ref_list_ptrs.push_back(l.data());
  }

  nvte_multi_tensor_adam_cuda(65536, noop.data(), ref_list_ptrs.data(), ref_lists.size(),
                              ref_lists[0].size(), lr, beta1, beta2, eps, step, mode,
                              bias_correction, weight_decay, 0);

  for (size_t i = 0; i < tensor_count; ++i) {
    nvte_quantize(p_ref_t[i].data(), q_ref[i].data(), 0);
    nvte_dequantize(q[i].data(), dq[i].data(), 0);
    nvte_dequantize(q_ref[i].data(), dq_ref[i].data(), 0);
  }

  // Errors raised while the kernels execute are reported by the next
  // synchronizing call, so the status must be checked rather than dropped.
  const cudaError_t sync_status = cudaDeviceSynchronize();
  ASSERT_EQ(sync_status, cudaSuccess) << cudaGetErrorString(sync_status);

  for (size_t i = 0; i < tensor_count; ++i) {
    q[i].to_cpu();
    p[i].to_cpu();
    m[i].to_cpu();
    v[i].to_cpu();
    q_ref[i].to_cpu();
    dq[i].to_cpu();
    dq_ref[i].to_cpu();
    p_ref_t[i].to_cpu();
    m_ref_t[i].to_cpu();
    v_ref_t[i].to_cpu();
  }

  for (size_t i = 0; i < tensor_count; ++i) {
    // All tensors are compared under the same labels ("p", "m", ...); the trace
    // adds the tensor index to any gtest failure reported in this iteration.
    SCOPED_TRACE("tensor index " + std::to_string(i));

    const Tensor &p_i = p[i];
    const Tensor &m_i = m[i];
    const Tensor &v_i = v[i];
    Tensor &q_i = q[i];
    const Tensor &p_ref_t_i = p_ref_t[i];
    const Tensor &m_ref_t_i = m_ref_t[i];
    const Tensor &v_ref_t_i = v_ref_t[i];
    Tensor &q_ref_i = q_ref[i];

    // FP32 optimizer state of the fused path vs. the reference path.
    compareResults("p", p_i, p_ref_t_i.rowwise_cpu_dptr<float>(), true, 0.0, 0.0, true, 0);
    compareResults("m", m_i, m_ref_t_i.rowwise_cpu_dptr<float>(), true, 0.0, 0.0, true, 0);
    compareResults("v", v_i, v_ref_t_i.rowwise_cpu_dptr<float>(), true, 0.0, 0.0, true, 0);

    const Tensor &dq_i = dq[i];
    const Tensor &dq_ref_i = dq_ref[i];
    compareResults("dequantized", dq_i, dq_ref_i.rowwise_cpu_dptr<float>(), true, 0.0, 0.0, true,
                   0);

    // Scale buffers are indexed as 2D here: element count = data[0] * data[1].
    const size_t rs = q_i.rowwise_scale_inv_shape().data[1];
    const size_t cs = q_i.columnwise_scale_inv_shape().data[1];
    const size_t rowwise_scale_size = q_i.rowwise_scale_inv_shape().data[0] * rs;
    const size_t colwise_scale_size = q_i.columnwise_scale_inv_shape().data[0] * cs;
    compareResults("rowwise_scale", q_i.rowwise_cpu_scale_inv_ptr<uint8_t>(),
                   q_ref_i.rowwise_cpu_scale_inv_ptr<uint8_t>(), rowwise_scale_size, 0.0f);
    compareResults("colwise_scale", q_i.columnwise_cpu_scale_inv_ptr<uint8_t>(),
                   q_ref_i.columnwise_cpu_scale_inv_ptr<uint8_t>(), colwise_scale_size, 0.0f);

    // View the FP8 payloads as raw bytes so they can be compared bit-for-bit.
    uint8_t *row_data = nullptr;
    uint8_t *col_data = nullptr;
    uint8_t *row_data_ref = nullptr;
    uint8_t *col_data_ref = nullptr;
    if (fp8_dtype == DType::kFloat8E4M3) {
      row_data = reinterpret_cast<uint8_t *>(q_i.rowwise_cpu_dptr<fp8e4m3>());
      col_data = reinterpret_cast<uint8_t *>(q_i.columnwise_cpu_dptr<fp8e4m3>());
      row_data_ref = reinterpret_cast<uint8_t *>(q_ref_i.rowwise_cpu_dptr<fp8e4m3>());
      col_data_ref = reinterpret_cast<uint8_t *>(q_ref_i.columnwise_cpu_dptr<fp8e4m3>());
    } else {
      row_data = reinterpret_cast<uint8_t *>(q_i.rowwise_cpu_dptr<fp8e5m2>());
      col_data = reinterpret_cast<uint8_t *>(q_i.columnwise_cpu_dptr<fp8e5m2>());
      row_data_ref = reinterpret_cast<uint8_t *>(q_ref_i.rowwise_cpu_dptr<fp8e5m2>());
      col_data_ref = reinterpret_cast<uint8_t *>(q_ref_i.columnwise_cpu_dptr<fp8e5m2>());
    }
    // The same element count (taken from the rowwise shape) is used for both layouts.
    const size_t data_size = q_i.rowwise_shape().data[0] * q_i.rowwise_shape().data[1];
    compareResults("rowwise_data", row_data, row_data_ref, data_size, 0.0f);
    compareResults("colwise_data", col_data, col_data_ref, data_size, 0.0f);
  }
}
261+
262+
} // namespace
263+
264+
// Runs the shared MXFP8 Adam comparison with E4M3 as the FP8 element type.
TEST(MultiTensorAdamMXFP8, E4M3) {
  run_mxfp8_adam_test(DType::kFloat8E4M3);
}
265+
266+
// Runs the shared MXFP8 Adam comparison with E5M2 as the FP8 element type.
TEST(MultiTensorAdamMXFP8, E5M2) {
  run_mxfp8_adam_test(DType::kFloat8E5M2);
}

tests/cpp/test_common.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,16 @@ class Tensor {
200200
return tensor_.get_columnwise_data().data_ptr;
201201
}
202202

203+
void *rowwise_scale_inv_dptr() const {
204+
NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
205+
return tensor_.get_rowwise_scale_inv().data_ptr;
206+
}
207+
208+
void *columnwise_scale_inv_dptr() const {
209+
NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!");
210+
return tensor_.get_columnwise_scale_inv().data_ptr;
211+
}
212+
203213
template <typename T>
204214
T *rowwise_cpu_dptr() const {
205215
NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");

tests/pytorch/distributed/run_fsdp2_fused_adam.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ def get_recipe_from_string(recipe):
3636
SEQ_LEN = 32
3737
BATCH_PER_RANK = 2
3838
NUM_STEPS = 3
39+
LOCAL_RANK = None
3940

41+
def dist_print(msg):
    """Print ``msg`` only when the module-level ``LOCAL_RANK`` equals 0.

    Nothing is printed while ``LOCAL_RANK`` is still ``None`` or holds any
    other rank.
    """
    if LOCAL_RANK != 0:
        return
    print(msg)
4044

4145
def save_custom_attrs(module):
4246
custom_attrs = {}
@@ -151,6 +155,8 @@ def test_fused_adam_fp8_master_weights(recipe=None):
151155
- Training loop completes without error
152156
- DTensor wrapping and QuantizedTensor local tensors are preserved
153157
"""
158+
global LOCAL_RANK
159+
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
154160
world_size, _, device = _setup()
155161

156162
model = _build_model(fp8_init=True, recipe=recipe)
@@ -183,7 +189,7 @@ def test_fused_adam_fp8_master_weights(recipe=None):
183189
loss = F.mse_loss(output, target)
184190
loss.backward()
185191
optimizer.step()
186-
192+
dist_print(f"Step {step} completed with loss {loss.item()}")
187193
# Verify optimizer states
188194
for param in model.parameters():
189195
state = optimizer.state[param]

tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,6 @@ def test_fsdp2_dcp_output_parity_async(fp_recipe):
224224
@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
225225
def test_fsdp2_safetensors_fp32_export(fp_recipe):
226226
"""Export FP32 model from optimizer master weights to safetensors."""
227-
if fp_recipe == "MXFP8BlockScaling":
228-
pytest.xfail(
229-
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
230-
"MXFP8 quantized tensors, causing illegal memory access"
231-
)
232227
_run_fused_adam_test("safetensors_fp32_export", fp_recipe)
233228

234229

transformer_engine/common/include/transformer_engine/multi_tensor.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,41 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
149149
const float weight_decay, const NVTEDType fp8_dtype,
150150
cudaStream_t stream);
151151

152+
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
153+
* when model parameters are in MXFP8 precision.
154+
*
155+
* The update is applied to FP32 master parameters, then the master
156+
* parameters are quantized to MXFP8 rowwise and columnwise data
157+
* (both are always required).
158+
*
159+
* \warning This API is **experimental** and subject to change.
160+
*
161+
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
162+
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
163+
* \param[in,out] tensor_lists 2D array of tensors (read and updated in place) with 8 lists in order:
164+
* (0) gradients, (1) FP32 master params, (2) first moment,
165+
* (3) second moment, (4) rowwise MXFP8 data,
166+
* (5) columnwise MXFP8 data, (6) rowwise scale-inv,
167+
* (7) columnwise scale-inv.
168+
* \param[in] num_tensor_lists Size (dim0) of tensor_lists. Must be 8.
169+
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
170+
* \param[in] fp8_dtype MXFP8 element type for quantization (E4M3/E5M2).
171+
* \param[in] lr Learning rate.
172+
* \param[in] beta1 Coefficient for first moment of gradient.
173+
* \param[in] beta2 Coefficient for second moment of gradient.
174+
* \param[in] epsilon Term added to the denominator for numerical stability.
175+
* \param[in] step Iteration counter.
176+
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
177+
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
178+
* \param[in] weight_decay L2 penalty for weight decay.
179+
* \param[in] stream CUDA stream used for this operation.
180+
*/
181+
void nvte_multi_tensor_adam_mxfp8_cuda(
182+
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
183+
const size_t num_tensor_lists, const size_t num_tensors_per_list, const NVTEDType fp8_dtype,
184+
const float lr, const float beta1, const float beta2, const float epsilon, const int step,
185+
const int mode, const int bias_correction, const float weight_decay, cudaStream_t stream);
186+
152187
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
153188
* with CUDA graph support and LR scheduling.
154189
*

0 commit comments

Comments
 (0)