@@ -153,11 +153,6 @@ def __init__(
153153 else :
154154 self .bias = None
155155
156- # Blackwell/SM120 currently misbehaves with the fused AWQ kernels, so
157- # those devices rebuild one dense compatibility weight per module.
158- self ._sm120_compat_weight : torch .Tensor | None = None
159- self ._sm120_compat_weight_device : tuple [torch .device , torch .dtype ] | None = None
160-
161156 def forward (self , x : torch .Tensor ):
162157 if not awq_runtime_available ():
163158 raise ModuleNotFoundError ("AWQ torch.ops kernels are not properly installed. Error: " + awq_runtime_error ())
@@ -180,11 +175,6 @@ def forward(self, x: torch.Tensor):
180175
181176 self ._ensure_runtime_buffers (device = inputs .device , dtype = inputs .dtype )
182177
183- # Route SM120 devices through a compatibility implementation until the
184- # fused decode/prefill kernels are fixed for Blackwell.
185- if self ._use_sm120_compat_path (inputs .device ):
186- return self ._sm120_compat_forward (x = x , inputs = inputs , input_dtype = input_dtype )
187-
188178 zeros = self ._runtime_zeros ()
189179 if inputs_dim == 3 and batch_size < 8 and n_tokens == 1 :
190180 out = awq_fast_gemv_forward_decode (
@@ -212,84 +202,9 @@ def forward(self, x: torch.Tensor):
212202
213203 return out
214204
215- def _use_sm120_compat_path (self , device : torch .device ) -> bool :
216- """Enable the SM120 compatibility path on Blackwell-class CUDA devices."""
217-
218- if device .type != "cuda" :
219- return False
220- major , _minor = torch .cuda .get_device_capability (device )
221- return major >= 12
222-
223- def _sm120_compat_forward (
224- self ,
225- * ,
226- x : torch .Tensor ,
227- inputs : torch .Tensor ,
228- input_dtype : torch .dtype ,
229- ) -> torch .Tensor :
230- """Run a dense compatibility matmul for SM120 until fused kernels are stable."""
231-
232- out_shape = inputs .shape [:- 1 ] + (self .out_features ,)
233- weight = self ._sm120_compat_dense_weight (device = inputs .device , dtype = inputs .dtype )
234- out = inputs .reshape (- 1 , inputs .shape [- 1 ]).matmul (weight ).reshape (out_shape )
235-
236- if input_dtype != torch .float16 :
237- out = out .to (dtype = input_dtype )
238-
239- out = out + self .bias if self .bias is not None else out
240-
241- if self .adapter :
242- out = self .adapter .apply (x = x , out = out )
243-
244- return out
245-
246- def _sm120_compat_dense_weight (self , * , device : torch .device , dtype : torch .dtype ) -> torch .Tensor :
247- """Cache one dense AWQ weight matrix per device/dtype for the SM120 path."""
248-
249- cache_key = (device , dtype )
250- if self ._sm120_compat_weight is not None and self ._sm120_compat_weight_device == cache_key :
251- return self ._sm120_compat_weight
252-
253- intweight = self ._unpack_reference_intweight (device = device )
254-
255- num_groups = max (1 , (self .in_features + self .group_size - 1 ) // self .group_size )
256- scales = self .scales .transpose (0 , 1 )[:, :num_groups ].to (device = device , dtype = dtype )
257- zeros = self ._runtime_zeros ().transpose (0 , 1 )[:, :num_groups ].to (device = device , dtype = dtype )
258-
259- scales = scales .repeat_interleave (self .group_size , dim = 1 )[:, : self .in_features ]
260- zeros = zeros .repeat_interleave (self .group_size , dim = 1 )[:, : self .in_features ]
261-
262- weight = (intweight .to (dtype = dtype ) * scales + zeros ).transpose (0 , 1 ).contiguous ()
263- self ._sm120_compat_weight = weight
264- self ._sm120_compat_weight_device = cache_key
265- return weight
266-
267- def _unpack_reference_intweight (self , * , device : torch .device ) -> torch .Tensor :
268- """Invert the GEMV_FAST int16 packing so SM120 can rebuild dense weights."""
269-
270- packed = self .qweight .to (device = device , dtype = torch .int32 )
271- unpacked = torch .stack (
272- [
273- torch .bitwise_and (torch .bitwise_right_shift (packed , shift ), 0xF )
274- for shift in (0 , 4 , 8 , 12 )
275- ],
276- dim = - 1 ,
277- )
278- unpacked = unpacked .view (packed .shape [0 ], packed .shape [1 ] // 64 , 4 , 64 )
279- unpacked = unpacked .permute (0 , 2 , 1 , 3 ).contiguous ()
280- unpacked = unpacked .view (packed .shape [0 ] * 4 , self .in_features )
281- unpacked = unpacked .view (packed .shape [0 ] * 4 , self .in_features // 32 , 4 , 2 , 4 )
282- unpacked = unpacked .permute (0 , 1 , 2 , 4 , 3 ).contiguous ()
283- unpacked = unpacked .view (packed .shape [0 ] * 4 , self .in_features // 32 , 32 )
284- unpacked = unpacked .view (packed .shape [0 ] * 4 , self .in_features // 32 , 4 , 4 , 2 )
285- unpacked = unpacked .permute (0 , 1 , 3 , 2 , 4 ).contiguous ()
286- return unpacked .view (self .out_features , self .in_features )
287-
288205 def _ensure_runtime_buffers (self , * , device : torch .device , dtype : torch .dtype ):
289206 if self .qweight .device != device or not self .qweight .is_contiguous ():
290207 self .qweight = self .qweight .to (device = device ).contiguous ()
291- self ._sm120_compat_weight = None
292- self ._sm120_compat_weight_device = None
293208
294209 zeros = self ._runtime_zeros ()
295210 if zeros .device != device or zeros .dtype != dtype or not zeros .is_contiguous ():
@@ -300,13 +215,9 @@ def _ensure_runtime_buffers(self, *, device: torch.device, dtype: torch.dtype):
300215 self .scaled_zeros = zeros
301216 else :
302217 raise ValueError (f"Unsupported zeros buffer: { self .zeros_name } " )
303- self ._sm120_compat_weight = None
304- self ._sm120_compat_weight_device = None
305218
306219 if self .scales .device != device or self .scales .dtype != dtype or not self .scales .is_contiguous ():
307220 self .scales = self .scales .to (device = device , dtype = dtype ).contiguous ()
308- self ._sm120_compat_weight = None
309- self ._sm120_compat_weight_device = None
310221
311222 if self .bias is not None and (
312223 self .bias .device != device or self .bias .dtype != dtype or not self .bias .is_contiguous ()
0 commit comments