Skip to content

Commit 1a8e9ee

Browse files
authored
Merge pull request #3 from silveroxides/fp8-loading
Fp8 blockwise loading
2 parents c6af2c9 + e99c705 commit 1a8e9ee

4 files changed

Lines changed: 699 additions & 160 deletions

File tree

fp8_ops.py

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
"""
2+
Hybrid FP8 Operations with correct block_size handling.
3+
4+
This module provides custom ops that correctly read group_size from per-layer
5+
metadata for FP8 rowwise and blockwise quantized models.
6+
7+
The issue: Core ComfyUI's MixedPrecisionOps reads block_size from QUANT_ALGOS
8+
fallback instead of per-layer metadata, causing wrong block boundaries.
9+
"""
10+
11+
import json
12+
import torch
13+
import logging
14+
from comfy.ops import manual_cast, cast_bias_weight, uncast_bias_weight
15+
from comfy.quant_ops import QuantizedTensor, LAYOUTS, QUANT_ALGOS
16+
17+
18+
class HybridFP8Ops(manual_cast):
19+
"""
20+
Hybrid FP8 operations class that correctly handles block_size from metadata.
21+
22+
Fixes the core bug where block_size is read from QUANT_ALGOS fallback
23+
instead of per-layer .comfy_quant metadata.
24+
"""
25+
26+
class Linear(manual_cast.Linear):
27+
def __init__(self, *args, **kwargs):
28+
super().__init__(*args, **kwargs)
29+
self.scale_weight = None
30+
self.block_size = None
31+
self.is_quantized = False
32+
self.layout_type = None
33+
self.quant_format = None
34+
35+
def reset_parameters(self):
36+
return None
37+
38+
def _load_from_state_dict(
39+
self,
40+
state_dict,
41+
prefix,
42+
local_metadata,
43+
strict,
44+
missing_keys,
45+
unexpected_keys,
46+
error_msgs,
47+
):
48+
"""
49+
Custom state dict loading that correctly reads group_size from per-layer metadata.
50+
"""
51+
weight_key = prefix + "weight"
52+
53+
# Get weight_scale
54+
scale = state_dict.pop(prefix + "weight_scale", None)
55+
56+
# Remove input_scale if present (not used for weight dequantization)
57+
state_dict.pop(prefix + "input_scale", None)
58+
59+
# Parse comfy_quant metadata for layout type and block_size
60+
comfy_quant_tensor = state_dict.pop(prefix + "comfy_quant", None)
61+
layer_conf = None
62+
63+
if comfy_quant_tensor is not None:
64+
try:
65+
# Decode the comfy_quant tensor to dict
66+
layer_conf = json.loads(comfy_quant_tensor.numpy().tobytes())
67+
self.quant_format = layer_conf.get("format", None)
68+
# KEY FIX: Read group_size from per-layer metadata!
69+
self.block_size = layer_conf.get("group_size", None)
70+
logging.debug(
71+
f"HybridFP8Ops: Parsed comfy_quant for {prefix}: format={self.quant_format}, group_size={self.block_size}"
72+
)
73+
except Exception as e:
74+
logging.debug(
75+
f"HybridFP8Ops: Failed to parse comfy_quant metadata: {e}"
76+
)
77+
78+
# Load weight tensor
79+
weight_tensor = state_dict.pop(weight_key, None)
80+
81+
if weight_tensor is not None:
82+
# Check if this is an FP8 tensor
83+
if weight_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
84+
self.is_quantized = True
85+
self.scale_weight = scale
86+
87+
# Determine layout type from format
88+
if self.quant_format is not None:
89+
qconfig = QUANT_ALGOS.get(self.quant_format, {})
90+
self.layout_type = qconfig.get(
91+
"comfy_tensor_layout", "TensorCoreFP8Layout"
92+
)
93+
94+
# Fallback block_size from QUANT_ALGOS only if not in metadata
95+
if self.block_size is None:
96+
self.block_size = qconfig.get("group_size", None)
97+
else:
98+
# Infer layout from scale shape
99+
if scale is not None:
100+
if scale.ndim == 0 or (
101+
scale.ndim == 1 and scale.numel() == 1
102+
):
103+
self.layout_type = "TensorCoreFP8Layout"
104+
elif (
105+
scale.ndim == 1
106+
and scale.numel() == weight_tensor.shape[0]
107+
):
108+
self.layout_type = "RowWiseFP8Layout"
109+
elif scale.ndim == 2:
110+
self.layout_type = "BlockWiseFP8Layout"
111+
# Infer block_size from scale shape
112+
if self.block_size is None:
113+
M, N = weight_tensor.shape
114+
scale_M, scale_N = scale.shape
115+
if M % scale_M == 0 and N % scale_N == 0:
116+
self.block_size = M // scale_M
117+
else:
118+
self.layout_type = "TensorCoreFP8Layout"
119+
else:
120+
self.layout_type = "TensorCoreFP8Layout"
121+
122+
# Check if the layout is registered
123+
if self.layout_type not in LAYOUTS:
124+
logging.warning(
125+
f"HybridFP8Ops: Layout '{self.layout_type}' not registered, using TensorCoreFP8Layout"
126+
)
127+
self.layout_type = "TensorCoreFP8Layout"
128+
129+
# Build layout_params with correct block_size
130+
layout_params = {
131+
"scale": scale.to(torch.float32) if scale is not None else None,
132+
"orig_dtype": torch.bfloat16, # Will be updated in forward
133+
}
134+
135+
# Add block_size for layouts that need it
136+
if self.layout_type in [
137+
"BlockWiseFP8Layout",
138+
"BlockWiseINT8Layout",
139+
]:
140+
if self.block_size is not None:
141+
layout_params["block_size"] = self.block_size
142+
else:
143+
# Last resort fallback
144+
layout_params["block_size"] = 64
145+
logging.warning(
146+
f"HybridFP8Ops: No block_size found for {prefix}, using fallback 64"
147+
)
148+
149+
# Create QuantizedTensor
150+
self.weight = torch.nn.Parameter(
151+
QuantizedTensor(weight_tensor, self.layout_type, layout_params),
152+
requires_grad=False,
153+
)
154+
logging.debug(
155+
f"HybridFP8Ops: Loaded FP8 layer {prefix} with layout={self.layout_type}, block_size={self.block_size}"
156+
)
157+
else:
158+
# Non-FP8 weight - high-precision layer
159+
self.is_quantized = False
160+
self.scale_weight = None
161+
self.weight = torch.nn.Parameter(weight_tensor, requires_grad=False)
162+
else:
163+
missing_keys.append(weight_key)
164+
165+
# Handle bias
166+
bias_key = prefix + "bias"
167+
bias_tensor = state_dict.pop(bias_key, None)
168+
if bias_tensor is not None:
169+
self.bias = torch.nn.Parameter(bias_tensor, requires_grad=False)
170+
else:
171+
self.bias = None
172+
173+
def forward_comfy_cast_weights(self, input):
174+
"""Forward pass with proper FP8 handling."""
175+
weight = self.weight
176+
if isinstance(weight, torch.nn.Parameter):
177+
weight = weight.data
178+
179+
input_dtype = input.dtype
180+
181+
# Handle QuantizedTensor (triggers dispatch to layout handlers)
182+
if isinstance(weight, QuantizedTensor):
183+
# Move to input device if needed
184+
if weight.device != input.device:
185+
weight = weight.to(device=input.device)
186+
187+
# Update orig_dtype for dequantization
188+
if hasattr(weight, "_layout_params"):
189+
weight._layout_params["orig_dtype"] = input_dtype
190+
191+
bias = self.bias
192+
if bias is not None:
193+
bias = bias.to(device=input.device, dtype=input_dtype)
194+
195+
# This triggers QuantizedTensor dispatch -> layout-specific handler
196+
return torch.nn.functional.linear(input, weight, bias)
197+
198+
# Fallback: dequantize FP8 weight manually if needed
199+
if self.is_quantized and weight.dtype in [
200+
torch.float8_e4m3fn,
201+
torch.float8_e5m2,
202+
]:
203+
weight = weight.to(device=input.device)
204+
205+
if self.scale_weight is not None:
206+
scale = self.scale_weight.to(device=input.device)
207+
weight_dequant = self._dequantize_weight(weight, scale, input_dtype)
208+
else:
209+
weight_dequant = weight.to(input_dtype)
210+
211+
bias = self.bias
212+
if bias is not None:
213+
bias = bias.to(device=input.device, dtype=input_dtype)
214+
return torch.nn.functional.linear(input, weight_dequant, bias)
215+
216+
# Standard manual_cast path for non-quantized weights
217+
weight, bias, offload_stream = cast_bias_weight(
218+
self, input, offloadable=True
219+
)
220+
out = torch.nn.functional.linear(input, weight, bias)
221+
uncast_bias_weight(self, weight, bias, offload_stream)
222+
return out
223+
224+
def _dequantize_weight(self, weight, scale, input_dtype):
225+
"""Dequantize FP8 weight to float.
226+
227+
Handles:
228+
- TensorCoreFP8Layout: scalar scale
229+
- RowWiseFP8Layout: scale shape (M,)
230+
- BlockWiseFP8Layout: scale shape (M//block_size, N//block_size)
231+
"""
232+
M, N = weight.shape
233+
234+
# Scalar scale (tensor-wise)
235+
if scale.ndim == 0 or (scale.ndim == 1 and scale.numel() == 1):
236+
return weight.to(input_dtype) * scale.item()
237+
238+
# Row-wise scale
239+
if scale.ndim == 1 and scale.shape[0] == M:
240+
scale_broadcast = scale.unsqueeze(1).to(
241+
device=weight.device, dtype=input_dtype
242+
)
243+
return weight.to(input_dtype) * scale_broadcast
244+
245+
# Block-wise scale
246+
if scale.ndim == 2 and self.block_size is not None:
247+
block_size = self.block_size
248+
if M % block_size == 0 and N % block_size == 0:
249+
qdata_blocked = weight.reshape(
250+
M // block_size, block_size, N // block_size, block_size
251+
)
252+
qdata_blocked = qdata_blocked.permute(0, 2, 1, 3)
253+
scale_broadcast = (
254+
scale.unsqueeze(-1)
255+
.unsqueeze(-1)
256+
.to(device=weight.device, dtype=input_dtype)
257+
)
258+
dequant = qdata_blocked.to(input_dtype) * scale_broadcast
259+
return dequant.permute(0, 2, 1, 3).reshape(M, N)
260+
261+
# Fallback: try broadcasting
262+
logging.warning(
263+
f"FP8 scale shape {scale.shape} for weight {weight.shape}, using broadcast"
264+
)
265+
return weight.to(input_dtype) * scale.to(
266+
device=weight.device, dtype=input_dtype
267+
)
268+
269+
def forward(self, *args, **kwargs):
270+
if (
271+
self.comfy_cast_weights
272+
or len(self.weight_function) > 0
273+
or len(self.bias_function) > 0
274+
):
275+
return self.forward_comfy_cast_weights(*args, **kwargs)
276+
else:
277+
weight = self.weight
278+
if isinstance(weight, torch.nn.Parameter):
279+
weight = weight.data
280+
281+
# FP8 needs our special forward path
282+
if weight.dtype in [
283+
torch.float8_e4m3fn,
284+
torch.float8_e5m2,
285+
] or isinstance(weight, QuantizedTensor):
286+
return self.forward_comfy_cast_weights(*args, **kwargs)
287+
return super().forward(*args, **kwargs)
288+
289+
def convert_weight(self, weight, inplace=False, **kwargs):
290+
"""Convert weight for LoRA patching - dequantize FP8."""
291+
if isinstance(weight, QuantizedTensor):
292+
return weight.dequantize()
293+
294+
if (
295+
weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
296+
and self.scale_weight is not None
297+
):
298+
return self._dequantize_weight(weight, self.scale_weight, torch.float32)
299+
300+
return weight
301+
302+
def set_weight(
303+
self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs
304+
):
305+
"""Set weight after LoRA patching."""
306+
if return_weight:
307+
return weight
308+
309+
if inplace_update:
310+
self.weight.data.copy_(weight)
311+
else:
312+
self.weight = torch.nn.Parameter(weight, requires_grad=False)
313+
314+
# Mark as no longer quantized after patching
315+
self.is_quantized = False
316+
self.scale_weight = None
317+
318+
# Normalization layers - use standard manual_cast versions
319+
class GroupNorm(manual_cast.GroupNorm):
320+
pass
321+
322+
class LayerNorm(manual_cast.LayerNorm):
323+
pass
324+
325+
class RMSNorm(manual_cast.RMSNorm):
326+
pass
327+
328+
# Convolution layers - use standard manual_cast versions
329+
class Conv1d(manual_cast.Conv1d):
330+
pass
331+
332+
class Conv2d(manual_cast.Conv2d):
333+
pass
334+
335+
class Conv3d(manual_cast.Conv3d):
336+
pass
337+
338+
class ConvTranspose1d(manual_cast.ConvTranspose1d):
339+
pass
340+
341+
class ConvTranspose2d(manual_cast.ConvTranspose2d):
342+
pass
343+
344+
class Embedding(manual_cast.Embedding):
345+
pass
346+
347+
@classmethod
348+
def conv_nd(cls, dims, *args, **kwargs):
349+
if dims == 2:
350+
return cls.Conv2d(*args, **kwargs)
351+
elif dims == 3:
352+
return cls.Conv3d(*args, **kwargs)
353+
else:
354+
raise ValueError(f"unsupported dimensions: {dims}")

0 commit comments

Comments
 (0)