-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy path__init__.py
More file actions
328 lines (279 loc) · 12 KB
/
__init__.py
File metadata and controls
328 lines (279 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
"""
ComfyUI-QuantOps: Extended Quantization Layouts for ComfyUI
This custom node extends ComfyUI's quantization system with additional layouts:
- INT8 blockwise (with optional Triton acceleration)
- INT8 tensorwise (uses torch._int_mm with dynamic activation quant)
- Row-wise and Block-wise FP8 variants
- MXFP8 and hybrid MXFP8 layouts from comfy-kitchen (registered when importable)
All layouts are lazy-loaded to avoid import errors when optional dependencies
(like Triton) are not installed.
"""
import logging
# =============================================================================
# Module-level state for comfy-kitchen backend integration
# =============================================================================
# Module-level flags describing which comfy-kitchen features were detected.
# Both start out False; within this module only _setup_comfy_kitchen_backends()
# rebinds them (via its `global` declaration).
_CK_AVAILABLE = _CK_TRITON_AVAILABLE = False


def is_ck_triton_available() -> bool:
    """Report whether the comfy-kitchen Triton backend is usable.

    Returns:
        The flag cached during backend setup. It stays ``False`` until setup
        has run and found the ``triton`` backend both available and enabled.
    """
    return _CK_TRITON_AVAILABLE
# =============================================================================
# Backend Setup
# =============================================================================
def _setup_comfy_kitchen_backends():
    """Prepare comfy-kitchen for use by QuantOps.

    Two best-effort steps are performed:

    1. Re-enable the comfy-kitchen ``triton`` backend (ComfyUI disables it by
       default). Then record in ``_CK_TRITON_AVAILABLE`` whether comfy-kitchen
       reports it as available and not disabled.
    2. Delegate to ``_register_quantops_backend()`` so the QuantOps kernels are
       registered as custom comfy-kitchen backends.

    If ``comfy_kitchen`` cannot be imported, both module flags are set to
    ``False`` and neither step runs.
    """
    global _CK_AVAILABLE, _CK_TRITON_AVAILABLE

    try:
        import comfy_kitchen as ck
    except ImportError:
        logging.debug("ComfyUI-QuantOps: comfy-kitchen not available")
        _CK_AVAILABLE = _CK_TRITON_AVAILABLE = False
        return
    _CK_AVAILABLE = True

    # Step 1: ComfyUI switches triton off, so turn it back on and then ask
    # comfy-kitchen whether it actually came up.
    try:
        ck.enable_backend("triton")
        status = ck.list_backends().get("triton", {})
        usable = status.get("available") and not status.get("disabled")
        if not usable:
            reason = status.get("unavailable_reason", "unknown")
            logging.info(f"ComfyUI-QuantOps: comfy-kitchen triton unavailable: {reason}")
            _CK_TRITON_AVAILABLE = False
        else:
            _CK_TRITON_AVAILABLE = True
            logging.info("ComfyUI-QuantOps: Enabled comfy-kitchen triton backend")
    except Exception as err:
        logging.warning(f"ComfyUI-QuantOps: Failed to enable ck triton backend: {err}")
        _CK_TRITON_AVAILABLE = False

    # Step 2: expose our own kernels to comfy-kitchen dispatch.
    _register_quantops_backend()
def _register_quantops_backend():
    """Register the QuantOps Triton kernels with the comfy-kitchen registry.

    This lets comfy-kitchen dispatch route supported ops to the QuantOps
    INT8/FP8 kernels. The INT8 and FP8 backends are loaded and registered
    independently. A failure while preparing one of them therefore does not
    prevent the other from being registered. Example failures are a kernel
    module that fails to import, or a torch build without
    ``torch.float8_e4m3fn``.

    All failures are logged (debug for missing dependencies or a rejected
    registration, warning otherwise) and swallowed.
    """
    try:
        from comfy_kitchen.registry import registry
    except ImportError as e:
        logging.debug(f"ComfyUI-QuantOps: Could not register backends (missing deps): {e}")
        return
    except Exception as e:
        logging.warning(f"ComfyUI-QuantOps: Backend registration failed: {e}")
        return

    # (registry name, label used in log messages, loader returning (module, capabilities))
    backends = (
        ("quantops_int8", "INT8", _load_int8_backend),
        ("quantops_fp8", "FP8", _load_fp8_backend),
    )
    for name, label, load in backends:
        try:
            module, capabilities = load()
        except ImportError as e:
            logging.debug(f"ComfyUI-QuantOps: Could not register {label} backend (missing deps): {e}")
            continue
        except Exception as e:
            logging.warning(f"ComfyUI-QuantOps: {label} backend registration failed: {e}")
            continue

        try:
            registry.register(name=name, module=module, capabilities=capabilities)
            logging.info(f"ComfyUI-QuantOps: Registered {name} backend")
        except Exception as e:
            logging.debug(f"ComfyUI-QuantOps: Could not register {label} backend: {e}")


def _load_int8_backend():
    """Import the INT8 kernel module and build its comfy-kitchen constraints.

    Returns:
        ``(module, capabilities)`` suitable for ``registry.register``.

    Raises:
        ImportError: if torch, ``comfy_kitchen.constraints`` or the kernel
            module cannot be imported.
    """
    import torch
    from comfy_kitchen.constraints import (
        DivisibleBy,
        ExactDims,
        FunctionConstraints,
        ParamConstraint,
    )

    from .kernels import int8_kernels

    cuda_devices = frozenset({"cuda"})
    standard_floats = frozenset({torch.float32, torch.float16, torch.bfloat16})

    capabilities = {
        "act_quant": FunctionConstraints(
            params={
                "x": ParamConstraint(
                    dtypes=standard_floats,
                    # Last dim must be divisible by the 128-element block size.
                    shape_rules=(DivisibleBy(-1, 128),),
                ),
            },
            default_devices=cuda_devices,
        ),
        "act_dequant": FunctionConstraints(
            params={
                "x": ParamConstraint(dtypes=frozenset({torch.int8})),
                "s": ParamConstraint(dtypes=frozenset({torch.float32})),
            },
            default_devices=cuda_devices,
        ),
        "weight_quant": FunctionConstraints(
            params={
                "x": ParamConstraint(
                    dtypes=standard_floats,
                    shape_rules=(ExactDims(2),),
                ),
            },
            default_devices=cuda_devices,
        ),
        "weight_dequant": FunctionConstraints(
            params={
                "x": ParamConstraint(dtypes=frozenset({torch.int8})),
                "s": ParamConstraint(dtypes=frozenset({torch.float32})),
            },
            default_devices=cuda_devices,
        ),
    }
    return int8_kernels, capabilities


def _load_fp8_backend():
    """Import the FP8 kernel module and build its comfy-kitchen constraints.

    Returns:
        ``(module, capabilities)`` suitable for ``registry.register``.

    Raises:
        ImportError: if torch, ``comfy_kitchen.constraints`` or the kernel
            module cannot be imported.
        AttributeError: if this torch build does not expose
            ``torch.float8_e4m3fn``.
    """
    import torch
    from comfy_kitchen.constraints import FunctionConstraints, ParamConstraint

    from .kernels import fp8_kernels

    cuda_devices = frozenset({"cuda"})
    standard_floats = frozenset({torch.float32, torch.float16, torch.bfloat16})

    def gemm_constraints():
        # Both FP8 GEMM entry points share one signature constraint:
        # `a`/`b` must be float8_e4m3fn and `a_s`/`b_s` must be float32.
        # A fresh object is built per op, matching the original construction.
        return FunctionConstraints(
            params={
                "a": ParamConstraint(dtypes=frozenset({torch.float8_e4m3fn})),
                "b": ParamConstraint(dtypes=frozenset({torch.float8_e4m3fn})),
                "a_s": ParamConstraint(dtypes=frozenset({torch.float32})),
                "b_s": ParamConstraint(dtypes=frozenset({torch.float32})),
            },
            default_devices=cuda_devices,
        )

    capabilities = {
        "fp8_act_quant": FunctionConstraints(
            params={
                "x": ParamConstraint(dtypes=standard_floats),
            },
            default_devices=cuda_devices,
        ),
        "fp8_gemm_blockwise": gemm_constraints(),
        "fp8_gemm_rowwise": gemm_constraints(),
    }
    return fp8_kernels, capabilities
# =============================================================================
# Layout Registration
# =============================================================================
def _register_layouts():
    """Register QuantOps layouts with ComfyUI and add matching QUANT_ALGOS entries.

    Layout classes go through ``comfy.quant_ops.register_layout_class``.
    Algorithm descriptors are added to ``QUANT_ALGOS`` with ``setdefault``, so
    any entry that already exists is left untouched. Optional comfy-kitchen
    layouts (tensorwise INT8, MXFP8, hybrid MXFP8) are registered only if they
    can be imported. Any other failure is logged and swallowed.

    The closing log line lists only the layouts that were actually
    registered by this call.
    """
    try:
        from comfy.quant_ops import QUANT_ALGOS, register_layout_class
        import torch

        # Importing our layouts also registers their operation handlers.
        from .quant_layouts.int8_layout import BlockWiseINT8Layout
        from .quant_layouts.fp8_variants import RowWiseFP8Layout, BlockWiseFP8Layout

        # Layout names successfully registered, reported at the end. This
        # replaces a hard-coded list that could claim optional layouts which
        # failed to import.
        registered = []

        def register(name, cls):
            register_layout_class(name, cls)
            registered.append(name)

        register("BlockWiseINT8Layout", BlockWiseINT8Layout)
        register("RowWiseFP8Layout", RowWiseFP8Layout)
        register("BlockWiseFP8Layout", BlockWiseFP8Layout)

        # Tensorwise INT8 from comfy_kitchen
        try:
            from comfy_kitchen.tensor.int8 import TensorWiseINT8Layout
            register("TensorWiseINT8Layout", TensorWiseINT8Layout)
            # Imported only for its side effect: our patch for per-channel scale support.
            from .quant_layouts import tensorwise_int8_layout  # noqa: F401
            logging.info("ComfyUI-QuantOps: Registered TensorWiseINT8Layout")
        except ImportError:
            logging.debug("ComfyUI-QuantOps: TensorWiseINT8Layout not available")

        QUANT_ALGOS.setdefault(
            "int8_tensorwise",
            {
                "storage_t": torch.int8,
                "parameters": {"weight_scale", "input_scale"},  # Keep input_scale if checkpoints have it
                "comfy_tensor_layout": "TensorWiseINT8Layout",  # Must match the registered class name
            },
        )
        QUANT_ALGOS.setdefault(
            "int8_blockwise",
            {
                "storage_t": torch.int8,
                "parameters": {"weight_scale", "input_scale"},
                "comfy_tensor_layout": "BlockWiseINT8Layout",
                "group_size": 128,
                "asymmetric_layout": True,
            },
        )
        QUANT_ALGOS.setdefault(
            "float8_e4m3fn_rowwise",
            {
                "storage_t": torch.float8_e4m3fn,
                "parameters": {"weight_scale", "input_scale"},
                "comfy_tensor_layout": "RowWiseFP8Layout",
            },
        )
        QUANT_ALGOS.setdefault(
            "float8_e4m3fn_blockwise",
            {
                "storage_t": torch.float8_e4m3fn,
                "parameters": {"weight_scale", "input_scale"},
                "comfy_tensor_layout": "BlockWiseFP8Layout",
                "group_size": 64,
            },
        )

        # MXFP8 from comfy_kitchen
        try:
            from comfy_kitchen.tensor import TensorCoreMXFP8Layout
            register("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
            logging.info("ComfyUI-QuantOps: Registered TensorCoreMXFP8Layout")
        except ImportError:
            logging.debug("ComfyUI-QuantOps: TensorCoreMXFP8Layout not available")
        QUANT_ALGOS.setdefault(
            "mxfp8",
            {
                "storage_t": torch.float8_e4m3fn,
                "parameters": {"weight_scale"},
                "comfy_tensor_layout": "TensorCoreMXFP8Layout",
                "group_size": 32,
            },
        )

        # Hybrid MXFP8 from comfy_kitchen
        try:
            from comfy_kitchen.tensor import HybridMXFP8Layout
            register("HybridMXFP8Layout", HybridMXFP8Layout)
            logging.info("ComfyUI-QuantOps: Registered HybridMXFP8Layout")
        except ImportError:
            logging.debug("ComfyUI-QuantOps: HybridMXFP8Layout not available")
        QUANT_ALGOS.setdefault(
            "hybrid_mxfp8",
            {
                "storage_t": torch.float8_e4m3fn,
                "parameters": {"weight_scale", "weight_scalar"},
                "comfy_tensor_layout": "HybridMXFP8Layout",
                "group_size": 32,
            },
        )

        # NVFP4: the layout itself is not registered here (per the original
        # note, ComfyUI core does this); only add the QUANT_ALGOS entry if missing.
        QUANT_ALGOS.setdefault(
            "nvfp4",
            {
                "storage_t": torch.uint8,
                "parameters": {"weight_scale", "weight_scale_2"},
                "comfy_tensor_layout": "TensorCoreNVFP4Layout",
                "group_size": 16,
            },
        )

        logging.info(f"ComfyUI-QuantOps: Registered layouts: {registered}")
    except Exception as e:
        logging.error(f"ComfyUI-QuantOps: Failed to register layouts: {e}")
# =============================================================================
# Module Initialization
# =============================================================================
# Setup backends first (enables ck triton, registers our kernels).
# Backend setup deliberately runs before layout registration. NOTE(review):
# presumably the layout modules imported by _register_layouts() may consult
# is_ck_triton_available(); TODO confirm before reordering.
# Both helpers catch and log the failures they anticipate (missing optional
# dependencies, registration errors) instead of raising them.
_setup_comfy_kitchen_backends()
# Register layouts and QUANT_ALGOS entries with ComfyUI.
_register_layouts()
# Import nodes for ComfyUI discovery. This import is NOT guarded, so an error
# inside .nodes.loader_nodes makes the whole package import fail.
from .nodes.loader_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
# Public names of this package: the node mappings ComfyUI looks for, plus the
# Triton-availability query.
__all__ = [
    "NODE_CLASS_MAPPINGS",
    "NODE_DISPLAY_NAME_MAPPINGS",
    "is_ck_triton_available",
]