@@ -288,17 +288,22 @@ def forward_comfy_cast_weights(self, input):
288288
289289 input_dtype = input .dtype
290290
291- if isinstance (weight , QuantizedTensor ):
292- if weight .device != input .device :
293- weight = weight .to (device = input .device )
291+ is_quantized_fast_path = isinstance (weight , QuantizedTensor )
292+ cast_dtype = weight .dtype if is_quantized_fast_path else None
293+ cast_bias_dtype = input_dtype if is_quantized_fast_path else None
294+
295+ weight , bias , offload_stream = cast_bias_weight (
296+ self ,
297+ input ,
298+ dtype = cast_dtype ,
299+ bias_dtype = cast_bias_dtype ,
300+ offloadable = True ,
301+ )
294302
303+ if isinstance (weight , QuantizedTensor ):
295304 if hasattr (weight , "_params" ):
296305 object .__setattr__ (weight ._params , "orig_dtype" , input_dtype )
297306
298- bias = self .bias
299- if bias is not None :
300- bias = bias .to (device = input .device , dtype = input_dtype )
301-
302307 if self .layout_type == "TensorCoreMXFP8Layout" :
303308 input_shape = input .shape
304309 tensor_3d = input .ndim == 3
@@ -314,16 +319,15 @@ def forward_comfy_cast_weights(self, input):
314319 q_input = input
315320
316321 q_input = QuantizedTensor .from_float (q_input , "TensorCoreMXFP8Layout" )
317- output = torch .nn .functional .linear (q_input , weight , bias )
322+ out = torch .nn .functional .linear (q_input , weight , bias )
318323 if tensor_3d :
319- output = output .reshape (input_shape [0 ], input_shape [1 ], - 1 )
324+ out = out .reshape (input_shape [0 ], input_shape [1 ], - 1 )
320325 if input .dtype == torch .float32 :
321- return output .to (torch .float32 )
322- return output
326+ out = out .to (torch .float32 )
323327 else :
324- return torch .nn .functional .linear (input .reshape (input_shape ), weight .dequantize (), bias )
328+ out = torch .nn .functional .linear (input .reshape (input_shape ), weight .dequantize (), bias )
325329
326- if self .layout_type == "TensorCoreNVFP4Layout" :
330+ elif self .layout_type == "TensorCoreNVFP4Layout" :
327331 input_shape = input .shape
328332 tensor_3d = input .ndim == 3
329333
@@ -338,16 +342,15 @@ def forward_comfy_cast_weights(self, input):
338342 q_input = input
339343
340344 q_input = QuantizedTensor .from_float (q_input , "TensorCoreNVFP4Layout" )
341- output = torch .nn .functional .linear (q_input , weight , bias )
345+ out = torch .nn .functional .linear (q_input , weight , bias )
342346 if tensor_3d :
343- output = output .reshape (input_shape [0 ], input_shape [1 ], - 1 )
347+ out = out .reshape (input_shape [0 ], input_shape [1 ], - 1 )
344348 if input .dtype == torch .float32 :
345- return output .to (torch .float32 )
346- return output
349+ out = out .to (torch .float32 )
347350 else :
348- return torch .nn .functional .linear (input .reshape (input_shape ), weight .dequantize (), bias )
351+ out = torch .nn .functional .linear (input .reshape (input_shape ), weight .dequantize (), bias )
349352
350- if self .layout_type in ["TensorCoreFP8Layout" , "TensorCoreFP8E4M3Layout" , "TensorCoreFP8E5M2Layout" ]:
353+ elif self .layout_type in ["TensorCoreFP8Layout" , "TensorCoreFP8E4M3Layout" , "TensorCoreFP8E5M2Layout" ]:
351354 input_shape = input .shape
352355 tensor_3d = input .ndim == 3
353356
@@ -362,30 +365,21 @@ def forward_comfy_cast_weights(self, input):
362365 q_input = input
363366
364367 q_input = QuantizedTensor .from_float (q_input , self .layout_type , scale = getattr (self , 'input_scale' , None ))
365- output = torch .nn .functional .linear (q_input , weight , bias )
368+ out = torch .nn .functional .linear (q_input , weight , bias )
366369 if tensor_3d :
367- output = output .reshape (input_shape [0 ], input_shape [1 ], - 1 )
370+ out = out .reshape (input_shape [0 ], input_shape [1 ], - 1 )
368371 if input .dtype == torch .float32 :
369- return output .to (torch .float32 )
370- return output
372+ out = out .to (torch .float32 )
371373 else :
372- return torch .nn .functional .linear (input .reshape (input_shape ), weight .dequantize (), bias )
374+ out = torch .nn .functional .linear (input .reshape (input_shape ), weight .dequantize (), bias )
373375
374- # Default trigger for QuantizedTensor dispatch -> layout-specific handler
375- return torch .nn .functional .linear (input , weight , bias )
376+ else :
377+ # Default trigger for QuantizedTensor dispatch -> layout-specific handler
378+ out = torch .nn .functional .linear (input , weight , bias )
379+
380+ else :
381+ out = torch .nn .functional .linear (input , weight , bias )
376382
377- # Fallback path if it's not wrapped in QuantizedTensor
378- if self .is_quantized :
379- weight = weight .to (device = input .device )
380-
381- # We strictly avoid dequantizing the full weight here unless we have to,
382- # but since we create QuantizedTensors for everything during load,
383- # this path should barely ever be hit unless the user passes a raw quant tensor.
384- # Just fallback to comfy manual cast.
385- pass
386-
387- weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True )
388- out = torch .nn .functional .linear (input , weight , bias )
389383 uncast_bias_weight (self , weight , bias , offload_stream )
390384 return out
391385
0 commit comments