@@ -108,22 +108,24 @@ def lightllm_per_token_group_quant_fp8(
108108def per_token_group_quant_fp8 (
109109 x : torch .Tensor ,
110110 group_size : int ,
111- x_q : torch .Tensor ,
112- x_s : torch .Tensor = None ,
113111 eps : float = 1e-10 ,
114112 dtype : torch .dtype = torch .float8_e4m3fn ,
115113 column_major_scales : bool = False ,
116114 scale_tma_aligned : bool = False ,
117115 alloc_func : Callable = torch .empty ,
118116):
117+ x_q = alloc_func (x .shape , device = x .device , dtype = dtype )
118+ x_s = None
119119 # Adapted from
120120 # https://github.com/sgl-project/sglang/blob/7e257cd666c0d639626487987ea8e590da1e9395/python/sglang/srt/layers/quantization/fp8_kernel.py#L290
121121 if HAS_SGL_KERNEL :
122122 finfo = torch .finfo (dtype )
123123 fp8_max , fp8_min = finfo .max , finfo .min
124+
125+ # 创建scale张量
124126 if column_major_scales :
125127 if scale_tma_aligned :
126- # aligned to 4 * sizeof(float)
128+ # 对齐到4 * sizeof(float)
127129 aligned_size = (x .shape [- 2 ] + 3 ) // 4 * 4
128130 x_s = alloc_func (
129131 x .shape [:- 2 ] + (x .shape [- 1 ] // group_size , aligned_size ),
@@ -137,16 +139,24 @@ def per_token_group_quant_fp8(
137139 dtype = torch .float32 ,
138140 ).permute (- 1 , - 2 )
139141 else :
140- if x_s is None :
141- x_s = alloc_func (
142- x .shape [:- 1 ] + (x .shape [- 1 ] // group_size ,),
143- device = x .device ,
144- dtype = torch .float32 ,
145- )
142+ x_s = alloc_func (
143+ x .shape [:- 1 ] + (x .shape [- 1 ] // group_size ,),
144+ device = x .device ,
145+ dtype = torch .float32 ,
146+ )
147+
148+ # 使用SGL kernel进行量化
146149 sgl_ops .sgl_per_token_group_quant_fp8 (x , x_q , x_s , group_size , 1e-10 , fp8_min , fp8_max , False )
147150 else :
151+ # 使用LightLLM kernel进行量化
152+ x_s = alloc_func (
153+ x .shape [:- 1 ] + (x .shape [- 1 ] // group_size ,),
154+ device = x .device ,
155+ dtype = torch .float32 ,
156+ )
148157 lightllm_per_token_group_quant_fp8 (x , group_size , x_q , x_s , eps = 1e-10 , dtype = torch .float8_e4m3fn )
149-
158+ if column_major_scales and scale_tma_aligned :
159+ x_s = tma_align_input_scale (x_s )
150160 return x_q , x_s
151161
152162
@@ -237,9 +247,9 @@ def test_tma_align():
237247 m = 576
238248 k = 8192
239249 x = torch .randn ((m , k // 128 ), dtype = torch .float32 ).cuda ()
250+
240251 for _ in range (10 ):
241252 x_padded = tma_align_input_scale (x )
242- print (x_padded .shape )
243253 import time
244254
245255 torch .cuda .synchronize ()
@@ -255,11 +265,9 @@ def test_tma_align():
255265def test_per_token_group_quant_fp8 ():
256266 group_size = 128
257267 x = torch .randn ((1024 , 8192 ), dtype = torch .bfloat16 ).cuda ()
258-
259- x_q = torch .randn ((1024 , 8192 )).cuda ().to (torch .float8_e4m3fn )
260268 # x_s = torch.randn((1024, 8192 // group_size), dtype=torch.float32).cuda()
261269 # x_s = torch.randn((8192 // group_size, 1024 + 10), dtype=torch.float32).cuda().t()
262- _ , x_s = per_token_group_quant_fp8 (x , group_size , x_q , None , column_major_scales = True )
270+ x_q , x_s = per_token_group_quant_fp8 (x , group_size , column_major_scales = True , scale_tma_aligned = True )
263271 x_s = x_s [:1024 ]
264272 th_x_q , th_x_s = torch_quant (x , group_size )
265273 print ("th_x_s - x_s" , torch .abs (th_x_s - x_s .reshape (- 1 )).max ())
@@ -268,4 +276,4 @@ def test_per_token_group_quant_fp8():
268276
269277if __name__ == "__main__" :
270278 test_per_token_group_quant_fp8 ()
271- # test_tma_align()
279+ test_tma_align ()
0 commit comments