1313
import torch
1515
-# import transformer_engine_torch as tex
+import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
1818
from ...quantized_tensor import QuantizedTensorStorage, Quantizer
2020
-# from ...constants import TE_DType as torch_to_transformer_engine_dtype
+from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ...utils import _empty_tensor
2323
2424
@@ -45,34 +45,7 @@ def forward(

        # Dequantize row-wise data
        if tensor._rowwise_data is not None:
-            ### TODO(tmoon): Debug dequantize kernel and remove unfused impl
-            # return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
-
-            # Tensor properties
-            shape = list(tensor._rowwise_data.size())
-            shape[-1] *= 2
-            device = tensor._rowwise_data.device
-
-            # Convert FP4E2M1 values to FP32
-            data = tensor._rowwise_data.view(torch.uint8).to(torch.int32)
-            data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape)
-            data = _fp4_e2m1_vals(device, dtype=torch.float32)[data]
-            data = data.to(torch.float32).contiguous()
-
-            # Convert FP8E4M3 block scales to FP32
-            block_scales = tensor._rowwise_scale_inv
-            block_scales = block_scales.reshape(-1, block_scales.size(-1))
-            block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16]
-            block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32)
-
-            # Convert amax to FP32 tensor scale
-            tensor_scale = tensor._amax_rowwise / (6.0 * 448.0)  # Scale by FP4E2M1 and FP8E4M3 max
-
-            # Apply scales
-            block_data = data.view(-1, 16)
-            block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1)
-
-            return data.to(dtype)
+            return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])

        if tensor._columnwise_data is not None:
            raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")
0 commit comments