Skip to content

Commit 86bb28b

Browse files
authored
Merge pull request #20 from silveroxides/fix/dynamic-vram-speed
Fix/dynamic vram speed
2 parents 6c01f95 + 352021f commit 86bb28b

2 files changed

Lines changed: 40 additions & 46 deletions

File tree

nodes/loader_nodes.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def INPUT_TYPES(cls):
3939
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
4040
"quant_format": (["auto", "int8", "int8_tensorwise", "float8_e4m3fn", "float8_e4m3fn_blockwise", "float8_e4m3fn_rowwise", "mxfp8", "hybrid_mxfp8", "nvfp4"],),
4141
"kernel_backend": (["pytorch", "triton"],),
42-
"disable_dynamic": ("BOOLEAN", {"default": True}),
42+
"disable_dynamic": ("BOOLEAN", {"default": False}),
4343
},
4444
}
4545

@@ -160,7 +160,7 @@ def INPUT_TYPES(cls):
160160
"unet_name": (folder_paths.get_filename_list("diffusion_models"),),
161161
"quant_format": (["auto", "int8", "int8_tensorwise", "float8_e4m3fn", "float8_e4m3fn_blockwise", "float8_e4m3fn_rowwise", "mxfp8", "hybrid_mxfp8", "nvfp4"],),
162162
"kernel_backend": (["pytorch", "triton"],),
163-
"disable_dynamic": ("BOOLEAN", {"default": True}),
163+
"disable_dynamic": ("BOOLEAN", {"default": False}),
164164
},
165165
}
166166

@@ -265,7 +265,7 @@ def INPUT_TYPES(cls):
265265
"type": (cls.CLIP_TYPES,),
266266
"quant_format": (["auto", "int8", "int8_tensorwise", "float8_e4m3fn", "float8_e4m3fn_blockwise", "float8_e4m3fn_rowwise", "mxfp8", "hybrid_mxfp8", "nvfp4"],),
267267
"kernel_backend": (["pytorch", "triton"],),
268-
"disable_dynamic": ("BOOLEAN", {"default": True}),
268+
"disable_dynamic": ("BOOLEAN", {"default": False}),
269269
},
270270
}
271271

@@ -383,7 +383,7 @@ def INPUT_TYPES(cls):
383383
"type": (cls.CLIP_TYPES,),
384384
"quant_format": (["auto", "int8", "int8_tensorwise", "float8_e4m3fn", "float8_e4m3fn_blockwise", "float8_e4m3fn_rowwise", "mxfp8", "hybrid_mxfp8", "nvfp4"],),
385385
"kernel_backend": (["pytorch", "triton"],),
386-
"disable_dynamic": ("BOOLEAN", {"default": True}),
386+
"disable_dynamic": ("BOOLEAN", {"default": False}),
387387
},
388388
}
389389

@@ -789,7 +789,7 @@ def INPUT_TYPES(cls):
789789
return {
790790
"required": {
791791
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
792-
"disable_dynamic": ("BOOLEAN", {"default": True}),
792+
"disable_dynamic": ("BOOLEAN", {"default": False}),
793793
},
794794
}
795795
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
@@ -808,7 +808,7 @@ def INPUT_TYPES(cls):
808808
return {
809809
"required": {
810810
"unet_name": (folder_paths.get_filename_list("diffusion_models"),),
811-
"disable_dynamic": ("BOOLEAN", {"default": True}),
811+
"disable_dynamic": ("BOOLEAN", {"default": False}),
812812
},
813813
}
814814
RETURN_TYPES = ("MODEL",)
@@ -828,7 +828,7 @@ def INPUT_TYPES(cls):
828828
"required": {
829829
"clip_name": (folder_paths.get_filename_list("text_encoders"),),
830830
"type": (QuantizedCLIPLoader.CLIP_TYPES,),
831-
"disable_dynamic": ("BOOLEAN", {"default": True}),
831+
"disable_dynamic": ("BOOLEAN", {"default": False}),
832832
},
833833
}
834834
RETURN_TYPES = ("CLIP",)
@@ -851,7 +851,7 @@ def INPUT_TYPES(cls):
851851
"text_encoder1": (te_list,),
852852
"text_encoder2": (te_and_ckpt_list,),
853853
"type": (QuantizedDualCLIPLoader.CLIP_TYPES,),
854-
"disable_dynamic": ("BOOLEAN", {"default": True}),
854+
"disable_dynamic": ("BOOLEAN", {"default": False}),
855855
},
856856
}
857857
RETURN_TYPES = ("CLIP",)

unified_ops.py

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -288,17 +288,22 @@ def forward_comfy_cast_weights(self, input):
288288

289289
input_dtype = input.dtype
290290

291-
if isinstance(weight, QuantizedTensor):
292-
if weight.device != input.device:
293-
weight = weight.to(device=input.device)
291+
is_quantized_fast_path = isinstance(weight, QuantizedTensor)
292+
cast_dtype = weight.dtype if is_quantized_fast_path else None
293+
cast_bias_dtype = input_dtype if is_quantized_fast_path else None
294+
295+
weight, bias, offload_stream = cast_bias_weight(
296+
self,
297+
input,
298+
dtype=cast_dtype,
299+
bias_dtype=cast_bias_dtype,
300+
offloadable=True,
301+
)
294302

303+
if isinstance(weight, QuantizedTensor):
295304
if hasattr(weight, "_params"):
296305
object.__setattr__(weight._params, "orig_dtype", input_dtype)
297306

298-
bias = self.bias
299-
if bias is not None:
300-
bias = bias.to(device=input.device, dtype=input_dtype)
301-
302307
if self.layout_type == "TensorCoreMXFP8Layout":
303308
input_shape = input.shape
304309
tensor_3d = input.ndim == 3
@@ -314,16 +319,15 @@ def forward_comfy_cast_weights(self, input):
314319
q_input = input
315320

316321
q_input = QuantizedTensor.from_float(q_input, "TensorCoreMXFP8Layout")
317-
output = torch.nn.functional.linear(q_input, weight, bias)
322+
out = torch.nn.functional.linear(q_input, weight, bias)
318323
if tensor_3d:
319-
output = output.reshape(input_shape[0], input_shape[1], -1)
324+
out = out.reshape(input_shape[0], input_shape[1], -1)
320325
if input.dtype == torch.float32:
321-
return output.to(torch.float32)
322-
return output
326+
out = out.to(torch.float32)
323327
else:
324-
return torch.nn.functional.linear(input.reshape(input_shape), weight.dequantize(), bias)
328+
out = torch.nn.functional.linear(input.reshape(input_shape), weight.dequantize(), bias)
325329

326-
if self.layout_type == "TensorCoreNVFP4Layout":
330+
elif self.layout_type == "TensorCoreNVFP4Layout":
327331
input_shape = input.shape
328332
tensor_3d = input.ndim == 3
329333

@@ -338,16 +342,15 @@ def forward_comfy_cast_weights(self, input):
338342
q_input = input
339343

340344
q_input = QuantizedTensor.from_float(q_input, "TensorCoreNVFP4Layout")
341-
output = torch.nn.functional.linear(q_input, weight, bias)
345+
out = torch.nn.functional.linear(q_input, weight, bias)
342346
if tensor_3d:
343-
output = output.reshape(input_shape[0], input_shape[1], -1)
347+
out = out.reshape(input_shape[0], input_shape[1], -1)
344348
if input.dtype == torch.float32:
345-
return output.to(torch.float32)
346-
return output
349+
out = out.to(torch.float32)
347350
else:
348-
return torch.nn.functional.linear(input.reshape(input_shape), weight.dequantize(), bias)
351+
out = torch.nn.functional.linear(input.reshape(input_shape), weight.dequantize(), bias)
349352

350-
if self.layout_type in ["TensorCoreFP8Layout", "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout"]:
353+
elif self.layout_type in ["TensorCoreFP8Layout", "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout"]:
351354
input_shape = input.shape
352355
tensor_3d = input.ndim == 3
353356

@@ -362,30 +365,21 @@ def forward_comfy_cast_weights(self, input):
362365
q_input = input
363366

364367
q_input = QuantizedTensor.from_float(q_input, self.layout_type, scale=getattr(self, 'input_scale', None))
365-
output = torch.nn.functional.linear(q_input, weight, bias)
368+
out = torch.nn.functional.linear(q_input, weight, bias)
366369
if tensor_3d:
367-
output = output.reshape(input_shape[0], input_shape[1], -1)
370+
out = out.reshape(input_shape[0], input_shape[1], -1)
368371
if input.dtype == torch.float32:
369-
return output.to(torch.float32)
370-
return output
372+
out = out.to(torch.float32)
371373
else:
372-
return torch.nn.functional.linear(input.reshape(input_shape), weight.dequantize(), bias)
374+
out = torch.nn.functional.linear(input.reshape(input_shape), weight.dequantize(), bias)
373375

374-
# Default trigger for QuantizedTensor dispatch -> layout-specific handler
375-
return torch.nn.functional.linear(input, weight, bias)
376+
else:
377+
# Default trigger for QuantizedTensor dispatch -> layout-specific handler
378+
out = torch.nn.functional.linear(input, weight, bias)
379+
380+
else:
381+
out = torch.nn.functional.linear(input, weight, bias)
376382

377-
# Fallback path if it's not wrapped in QuantizedTensor
378-
if self.is_quantized:
379-
weight = weight.to(device=input.device)
380-
381-
# We strictly avoid dequantizing the full weight here unless we have to,
382-
# but since we create QuantizedTensors for everything during load,
383-
# this path should barely ever be hit unless the user passes a raw quant tensor.
384-
# Just fallback to comfy manual cast.
385-
pass
386-
387-
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
388-
out = torch.nn.functional.linear(input, weight, bias)
389383
uncast_bias_weight(self, weight, bias, offload_stream)
390384
return out
391385

0 commit comments

Comments
 (0)