Skip to content

Commit f3b97c2

Browse files
authored
Fix out of bounds access in the FP4 dequantize kernel (NVIDIA#2346)
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
1 parent dcaca2a commit f3b97c2

File tree

2 files changed

+7
-30
lines changed

2 files changed

+7
-30
lines changed

transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ __global__ void __launch_bounds__(512)
3939
const size_t x = thread_idx % M;
4040
const size_t y = thread_idx / M;
4141

42+
if (y >= N) {
43+
return;
44+
}
45+
4246
union fp4vec {
4347
uint64_t vec;
4448
fp4e2m1x4 small_vec[4];

transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
import torch
1515

16-
# import transformer_engine_torch as tex
16+
import transformer_engine_torch as tex
1717
from transformer_engine_torch import DType as TE_DType
1818

1919
from ...quantized_tensor import QuantizedTensorStorage, Quantizer
2020

21-
# from ...constants import TE_DType as torch_to_transformer_engine_dtype
21+
from ...constants import TE_DType as torch_to_transformer_engine_dtype
2222
from ...utils import _empty_tensor
2323

2424

@@ -45,34 +45,7 @@ def forward(
4545

4646
# Dequantize row-wise data
4747
if tensor._rowwise_data is not None:
48-
### TODO(tmoon): Debug dequantize kernel and remove unfused impl
49-
# return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
50-
51-
# Tensor properties
52-
shape = list(tensor._rowwise_data.size())
53-
shape[-1] *= 2
54-
device = tensor._rowwise_data.device
55-
56-
# Convert FP4E2M1 values to FP32
57-
data = tensor._rowwise_data.view(torch.uint8).to(torch.int32)
58-
data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape)
59-
data = _fp4_e2m1_vals(device, dtype=torch.float32)[data]
60-
data = data.to(torch.float32).contiguous()
61-
62-
# Convert FP8E4M3 block scales to FP32
63-
block_scales = tensor._rowwise_scale_inv
64-
block_scales = block_scales.reshape(-1, block_scales.size(-1))
65-
block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16]
66-
block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32)
67-
68-
# Convert amax to FP32 tensor scale
69-
tensor_scale = tensor._amax_rowwise / (6.0 * 448.0) # Scale by FP4E2M1 and FP8E4M3 max
70-
71-
# Apply scales
72-
block_data = data.view(-1, 16)
73-
block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1)
74-
75-
return data.to(dtype)
48+
return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
7649

7750
if tensor._columnwise_data is not None:
7851
raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")

0 commit comments

Comments (0)