|
| 1 | +""" |
| 2 | +Hybrid FP8 Operations with correct block_size handling. |
| 3 | +
|
| 4 | +This module provides custom ops that correctly read group_size from per-layer |
| 5 | +metadata for FP8 rowwise and blockwise quantized models. |
| 6 | +
|
| 7 | +The issue: Core ComfyUI's MixedPrecisionOps reads block_size from QUANT_ALGOS |
| 8 | +fallback instead of per-layer metadata, causing wrong block boundaries. |
| 9 | +""" |
| 10 | + |
| 11 | +import json |
| 12 | +import torch |
| 13 | +import logging |
| 14 | +from comfy.ops import manual_cast, cast_bias_weight, uncast_bias_weight |
| 15 | +from comfy.quant_ops import QuantizedTensor, LAYOUTS, QUANT_ALGOS |
| 16 | + |
| 17 | + |
| 18 | +class HybridFP8Ops(manual_cast): |
| 19 | + """ |
| 20 | + Hybrid FP8 operations class that correctly handles block_size from metadata. |
| 21 | +
|
| 22 | + Fixes the core bug where block_size is read from QUANT_ALGOS fallback |
| 23 | + instead of per-layer .comfy_quant metadata. |
| 24 | + """ |
| 25 | + |
| 26 | + class Linear(manual_cast.Linear): |
| 27 | + def __init__(self, *args, **kwargs): |
| 28 | + super().__init__(*args, **kwargs) |
| 29 | + self.scale_weight = None |
| 30 | + self.block_size = None |
| 31 | + self.is_quantized = False |
| 32 | + self.layout_type = None |
| 33 | + self.quant_format = None |
| 34 | + |
| 35 | + def reset_parameters(self): |
| 36 | + return None |
| 37 | + |
| 38 | + def _load_from_state_dict( |
| 39 | + self, |
| 40 | + state_dict, |
| 41 | + prefix, |
| 42 | + local_metadata, |
| 43 | + strict, |
| 44 | + missing_keys, |
| 45 | + unexpected_keys, |
| 46 | + error_msgs, |
| 47 | + ): |
| 48 | + """ |
| 49 | + Custom state dict loading that correctly reads group_size from per-layer metadata. |
| 50 | + """ |
| 51 | + weight_key = prefix + "weight" |
| 52 | + |
| 53 | + # Get weight_scale |
| 54 | + scale = state_dict.pop(prefix + "weight_scale", None) |
| 55 | + |
| 56 | + # Remove input_scale if present (not used for weight dequantization) |
| 57 | + state_dict.pop(prefix + "input_scale", None) |
| 58 | + |
| 59 | + # Parse comfy_quant metadata for layout type and block_size |
| 60 | + comfy_quant_tensor = state_dict.pop(prefix + "comfy_quant", None) |
| 61 | + layer_conf = None |
| 62 | + |
| 63 | + if comfy_quant_tensor is not None: |
| 64 | + try: |
| 65 | + # Decode the comfy_quant tensor to dict |
| 66 | + layer_conf = json.loads(comfy_quant_tensor.numpy().tobytes()) |
| 67 | + self.quant_format = layer_conf.get("format", None) |
| 68 | + # KEY FIX: Read group_size from per-layer metadata! |
| 69 | + self.block_size = layer_conf.get("group_size", None) |
| 70 | + logging.debug( |
| 71 | + f"HybridFP8Ops: Parsed comfy_quant for {prefix}: format={self.quant_format}, group_size={self.block_size}" |
| 72 | + ) |
| 73 | + except Exception as e: |
| 74 | + logging.debug( |
| 75 | + f"HybridFP8Ops: Failed to parse comfy_quant metadata: {e}" |
| 76 | + ) |
| 77 | + |
| 78 | + # Load weight tensor |
| 79 | + weight_tensor = state_dict.pop(weight_key, None) |
| 80 | + |
| 81 | + if weight_tensor is not None: |
| 82 | + # Check if this is an FP8 tensor |
| 83 | + if weight_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: |
| 84 | + self.is_quantized = True |
| 85 | + self.scale_weight = scale |
| 86 | + |
| 87 | + # Determine layout type from format |
| 88 | + if self.quant_format is not None: |
| 89 | + qconfig = QUANT_ALGOS.get(self.quant_format, {}) |
| 90 | + self.layout_type = qconfig.get( |
| 91 | + "comfy_tensor_layout", "TensorCoreFP8Layout" |
| 92 | + ) |
| 93 | + |
| 94 | + # Fallback block_size from QUANT_ALGOS only if not in metadata |
| 95 | + if self.block_size is None: |
| 96 | + self.block_size = qconfig.get("group_size", None) |
| 97 | + else: |
| 98 | + # Infer layout from scale shape |
| 99 | + if scale is not None: |
| 100 | + if scale.ndim == 0 or ( |
| 101 | + scale.ndim == 1 and scale.numel() == 1 |
| 102 | + ): |
| 103 | + self.layout_type = "TensorCoreFP8Layout" |
| 104 | + elif ( |
| 105 | + scale.ndim == 1 |
| 106 | + and scale.numel() == weight_tensor.shape[0] |
| 107 | + ): |
| 108 | + self.layout_type = "RowWiseFP8Layout" |
| 109 | + elif scale.ndim == 2: |
| 110 | + self.layout_type = "BlockWiseFP8Layout" |
| 111 | + # Infer block_size from scale shape |
| 112 | + if self.block_size is None: |
| 113 | + M, N = weight_tensor.shape |
| 114 | + scale_M, scale_N = scale.shape |
| 115 | + if M % scale_M == 0 and N % scale_N == 0: |
| 116 | + self.block_size = M // scale_M |
| 117 | + else: |
| 118 | + self.layout_type = "TensorCoreFP8Layout" |
| 119 | + else: |
| 120 | + self.layout_type = "TensorCoreFP8Layout" |
| 121 | + |
| 122 | + # Check if the layout is registered |
| 123 | + if self.layout_type not in LAYOUTS: |
| 124 | + logging.warning( |
| 125 | + f"HybridFP8Ops: Layout '{self.layout_type}' not registered, using TensorCoreFP8Layout" |
| 126 | + ) |
| 127 | + self.layout_type = "TensorCoreFP8Layout" |
| 128 | + |
| 129 | + # Build layout_params with correct block_size |
| 130 | + layout_params = { |
| 131 | + "scale": scale.to(torch.float32) if scale is not None else None, |
| 132 | + "orig_dtype": torch.bfloat16, # Will be updated in forward |
| 133 | + } |
| 134 | + |
| 135 | + # Add block_size for layouts that need it |
| 136 | + if self.layout_type in [ |
| 137 | + "BlockWiseFP8Layout", |
| 138 | + "BlockWiseINT8Layout", |
| 139 | + ]: |
| 140 | + if self.block_size is not None: |
| 141 | + layout_params["block_size"] = self.block_size |
| 142 | + else: |
| 143 | + # Last resort fallback |
| 144 | + layout_params["block_size"] = 64 |
| 145 | + logging.warning( |
| 146 | + f"HybridFP8Ops: No block_size found for {prefix}, using fallback 64" |
| 147 | + ) |
| 148 | + |
| 149 | + # Create QuantizedTensor |
| 150 | + self.weight = torch.nn.Parameter( |
| 151 | + QuantizedTensor(weight_tensor, self.layout_type, layout_params), |
| 152 | + requires_grad=False, |
| 153 | + ) |
| 154 | + logging.debug( |
| 155 | + f"HybridFP8Ops: Loaded FP8 layer {prefix} with layout={self.layout_type}, block_size={self.block_size}" |
| 156 | + ) |
| 157 | + else: |
| 158 | + # Non-FP8 weight - high-precision layer |
| 159 | + self.is_quantized = False |
| 160 | + self.scale_weight = None |
| 161 | + self.weight = torch.nn.Parameter(weight_tensor, requires_grad=False) |
| 162 | + else: |
| 163 | + missing_keys.append(weight_key) |
| 164 | + |
| 165 | + # Handle bias |
| 166 | + bias_key = prefix + "bias" |
| 167 | + bias_tensor = state_dict.pop(bias_key, None) |
| 168 | + if bias_tensor is not None: |
| 169 | + self.bias = torch.nn.Parameter(bias_tensor, requires_grad=False) |
| 170 | + else: |
| 171 | + self.bias = None |
| 172 | + |
| 173 | + def forward_comfy_cast_weights(self, input): |
| 174 | + """Forward pass with proper FP8 handling.""" |
| 175 | + weight = self.weight |
| 176 | + if isinstance(weight, torch.nn.Parameter): |
| 177 | + weight = weight.data |
| 178 | + |
| 179 | + input_dtype = input.dtype |
| 180 | + |
| 181 | + # Handle QuantizedTensor (triggers dispatch to layout handlers) |
| 182 | + if isinstance(weight, QuantizedTensor): |
| 183 | + # Move to input device if needed |
| 184 | + if weight.device != input.device: |
| 185 | + weight = weight.to(device=input.device) |
| 186 | + |
| 187 | + # Update orig_dtype for dequantization |
| 188 | + if hasattr(weight, "_layout_params"): |
| 189 | + weight._layout_params["orig_dtype"] = input_dtype |
| 190 | + |
| 191 | + bias = self.bias |
| 192 | + if bias is not None: |
| 193 | + bias = bias.to(device=input.device, dtype=input_dtype) |
| 194 | + |
| 195 | + # This triggers QuantizedTensor dispatch -> layout-specific handler |
| 196 | + return torch.nn.functional.linear(input, weight, bias) |
| 197 | + |
| 198 | + # Fallback: dequantize FP8 weight manually if needed |
| 199 | + if self.is_quantized and weight.dtype in [ |
| 200 | + torch.float8_e4m3fn, |
| 201 | + torch.float8_e5m2, |
| 202 | + ]: |
| 203 | + weight = weight.to(device=input.device) |
| 204 | + |
| 205 | + if self.scale_weight is not None: |
| 206 | + scale = self.scale_weight.to(device=input.device) |
| 207 | + weight_dequant = self._dequantize_weight(weight, scale, input_dtype) |
| 208 | + else: |
| 209 | + weight_dequant = weight.to(input_dtype) |
| 210 | + |
| 211 | + bias = self.bias |
| 212 | + if bias is not None: |
| 213 | + bias = bias.to(device=input.device, dtype=input_dtype) |
| 214 | + return torch.nn.functional.linear(input, weight_dequant, bias) |
| 215 | + |
| 216 | + # Standard manual_cast path for non-quantized weights |
| 217 | + weight, bias, offload_stream = cast_bias_weight( |
| 218 | + self, input, offloadable=True |
| 219 | + ) |
| 220 | + out = torch.nn.functional.linear(input, weight, bias) |
| 221 | + uncast_bias_weight(self, weight, bias, offload_stream) |
| 222 | + return out |
| 223 | + |
| 224 | + def _dequantize_weight(self, weight, scale, input_dtype): |
| 225 | + """Dequantize FP8 weight to float. |
| 226 | +
|
| 227 | + Handles: |
| 228 | + - TensorCoreFP8Layout: scalar scale |
| 229 | + - RowWiseFP8Layout: scale shape (M,) |
| 230 | + - BlockWiseFP8Layout: scale shape (M//block_size, N//block_size) |
| 231 | + """ |
| 232 | + M, N = weight.shape |
| 233 | + |
| 234 | + # Scalar scale (tensor-wise) |
| 235 | + if scale.ndim == 0 or (scale.ndim == 1 and scale.numel() == 1): |
| 236 | + return weight.to(input_dtype) * scale.item() |
| 237 | + |
| 238 | + # Row-wise scale |
| 239 | + if scale.ndim == 1 and scale.shape[0] == M: |
| 240 | + scale_broadcast = scale.unsqueeze(1).to( |
| 241 | + device=weight.device, dtype=input_dtype |
| 242 | + ) |
| 243 | + return weight.to(input_dtype) * scale_broadcast |
| 244 | + |
| 245 | + # Block-wise scale |
| 246 | + if scale.ndim == 2 and self.block_size is not None: |
| 247 | + block_size = self.block_size |
| 248 | + if M % block_size == 0 and N % block_size == 0: |
| 249 | + qdata_blocked = weight.reshape( |
| 250 | + M // block_size, block_size, N // block_size, block_size |
| 251 | + ) |
| 252 | + qdata_blocked = qdata_blocked.permute(0, 2, 1, 3) |
| 253 | + scale_broadcast = ( |
| 254 | + scale.unsqueeze(-1) |
| 255 | + .unsqueeze(-1) |
| 256 | + .to(device=weight.device, dtype=input_dtype) |
| 257 | + ) |
| 258 | + dequant = qdata_blocked.to(input_dtype) * scale_broadcast |
| 259 | + return dequant.permute(0, 2, 1, 3).reshape(M, N) |
| 260 | + |
| 261 | + # Fallback: try broadcasting |
| 262 | + logging.warning( |
| 263 | + f"FP8 scale shape {scale.shape} for weight {weight.shape}, using broadcast" |
| 264 | + ) |
| 265 | + return weight.to(input_dtype) * scale.to( |
| 266 | + device=weight.device, dtype=input_dtype |
| 267 | + ) |
| 268 | + |
| 269 | + def forward(self, *args, **kwargs): |
| 270 | + if ( |
| 271 | + self.comfy_cast_weights |
| 272 | + or len(self.weight_function) > 0 |
| 273 | + or len(self.bias_function) > 0 |
| 274 | + ): |
| 275 | + return self.forward_comfy_cast_weights(*args, **kwargs) |
| 276 | + else: |
| 277 | + weight = self.weight |
| 278 | + if isinstance(weight, torch.nn.Parameter): |
| 279 | + weight = weight.data |
| 280 | + |
| 281 | + # FP8 needs our special forward path |
| 282 | + if weight.dtype in [ |
| 283 | + torch.float8_e4m3fn, |
| 284 | + torch.float8_e5m2, |
| 285 | + ] or isinstance(weight, QuantizedTensor): |
| 286 | + return self.forward_comfy_cast_weights(*args, **kwargs) |
| 287 | + return super().forward(*args, **kwargs) |
| 288 | + |
| 289 | + def convert_weight(self, weight, inplace=False, **kwargs): |
| 290 | + """Convert weight for LoRA patching - dequantize FP8.""" |
| 291 | + if isinstance(weight, QuantizedTensor): |
| 292 | + return weight.dequantize() |
| 293 | + |
| 294 | + if ( |
| 295 | + weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] |
| 296 | + and self.scale_weight is not None |
| 297 | + ): |
| 298 | + return self._dequantize_weight(weight, self.scale_weight, torch.float32) |
| 299 | + |
| 300 | + return weight |
| 301 | + |
| 302 | + def set_weight( |
| 303 | + self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs |
| 304 | + ): |
| 305 | + """Set weight after LoRA patching.""" |
| 306 | + if return_weight: |
| 307 | + return weight |
| 308 | + |
| 309 | + if inplace_update: |
| 310 | + self.weight.data.copy_(weight) |
| 311 | + else: |
| 312 | + self.weight = torch.nn.Parameter(weight, requires_grad=False) |
| 313 | + |
| 314 | + # Mark as no longer quantized after patching |
| 315 | + self.is_quantized = False |
| 316 | + self.scale_weight = None |
| 317 | + |
| 318 | + # Normalization layers - use standard manual_cast versions |
| 319 | + class GroupNorm(manual_cast.GroupNorm): |
| 320 | + pass |
| 321 | + |
| 322 | + class LayerNorm(manual_cast.LayerNorm): |
| 323 | + pass |
| 324 | + |
| 325 | + class RMSNorm(manual_cast.RMSNorm): |
| 326 | + pass |
| 327 | + |
| 328 | + # Convolution layers - use standard manual_cast versions |
| 329 | + class Conv1d(manual_cast.Conv1d): |
| 330 | + pass |
| 331 | + |
| 332 | + class Conv2d(manual_cast.Conv2d): |
| 333 | + pass |
| 334 | + |
| 335 | + class Conv3d(manual_cast.Conv3d): |
| 336 | + pass |
| 337 | + |
| 338 | + class ConvTranspose1d(manual_cast.ConvTranspose1d): |
| 339 | + pass |
| 340 | + |
| 341 | + class ConvTranspose2d(manual_cast.ConvTranspose2d): |
| 342 | + pass |
| 343 | + |
| 344 | + class Embedding(manual_cast.Embedding): |
| 345 | + pass |
| 346 | + |
| 347 | + @classmethod |
| 348 | + def conv_nd(cls, dims, *args, **kwargs): |
| 349 | + if dims == 2: |
| 350 | + return cls.Conv2d(*args, **kwargs) |
| 351 | + elif dims == 3: |
| 352 | + return cls.Conv3d(*args, **kwargs) |
| 353 | + else: |
| 354 | + raise ValueError(f"unsupported dimensions: {dims}") |
0 commit comments