Skip to content

Commit 8a5f5fb

Browse files
authored
Merge pull request #24 from silveroxides/fix/mixed-format-loading
Fix mixed dtype loading and fix loading of text encoders (especially i…
2 parents 6f18916 + dc9f832 commit 8a5f5fb

10 files changed

Lines changed: 430 additions & 1645 deletions

File tree

bnb4bit_ops.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch.nn.functional as F
2323
import logging
2424
from comfy.ops import manual_cast, cast_bias_weight, uncast_bias_weight
25+
from unifiedefficientloader import tensor_to_dict
2526

2627

2728
# NF4 (Normal Float 4-bit) quantization table
@@ -68,13 +69,6 @@
6869
], dtype=torch.float32)
6970

7071

71-
def tensor_to_dict(tensor_data: torch.Tensor) -> dict:
72-
"""Convert a uint8 tensor containing JSON bytes back to a dictionary."""
73-
byte_data = bytes(tensor_data.cpu().tolist())
74-
json_str = byte_data.decode("utf-8")
75-
return json.loads(json_str)
76-
77-
7872
def get_quant_map(quant_type: str, device: torch.device) -> torch.Tensor:
7973
"""Get the quantization codebook for NF4 or FP4."""
8074
if quant_type == "nf4":

comfy_quant_helpers.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

fp8_ops.py

Lines changed: 0 additions & 525 deletions
This file was deleted.

int8_ops.py

Lines changed: 0 additions & 481 deletions
This file was deleted.

nodes/loader_nodes.py

Lines changed: 192 additions & 164 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ classifiers = [
2020
"Programming Language :: Python :: 3.12",
2121
"Topic :: Scientific/Engineering :: Artificial Intelligence",
2222
]
23+
dependencies = [
24+
"unifiedefficientloader>=0.2.0"
25+
]
2326

2427
[project.urls]
2528
Repository = "https://github.com/silveroxides/ComfyUI-QuantOps"

quant_layouts/int8_layout.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,12 @@ def _int8_gemm_pytorch_fallback(
409409
b_fp32 = b_blocked.to(torch.float32) * b_scale_broadcast
410410
b_fp32 = b_fp32.permute(0, 2, 1, 3).reshape(N, K)
411411

412+
# Bias may arrive in bfloat16 when previous layers (e.g. TensorWiseINT8Layout)
413+
# output bfloat16 and cast_bias_weight matches the input dtype.
414+
# The INT8 fallback computes in float32, so bias must match.
415+
if bias is not None and bias.dtype != a_fp32.dtype:
416+
bias = bias.to(a_fp32.dtype)
417+
412418
output = torch.nn.functional.linear(a_fp32, b_fp32, bias)
413419
return output
414420

unified_ops.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414
from comfy.ops import manual_cast, cast_bias_weight, uncast_bias_weight
1515
from comfy.quant_ops import QuantizedTensor, QUANT_ALGOS, get_layout_class
1616
from comfy.model_patcher import LowVramPatch
17+
from unifiedefficientloader import tensor_to_dict
1718

1819
# Try to import INT8 layouts
1920
try:
20-
from .quant_layouts.int8_layout import BlockWiseINT8Layout
21+
from comfy_kitchen.tensor.int8 import BlockWiseINT8Layout
2122
_HAS_INT8_LAYOUT = True
2223
except ImportError:
23-
_HAS_INT8_LAYOUT = False
24-
logging.warning("INT8 blockwise layout not available")
24+
try:
25+
from .quant_layouts.int8_layout import BlockWiseINT8Layout
26+
_HAS_INT8_LAYOUT = True
27+
except ImportError:
28+
_HAS_INT8_LAYOUT = False
29+
logging.warning("INT8 blockwise layout not available")
2530

2631
try:
2732
from comfy_kitchen.tensor.int8 import TensorWiseINT8Layout
@@ -31,22 +36,6 @@
3136
logging.warning("INT8 tensorwise layout not available from comfy_kitchen")
3237

3338

34-
def tensor_to_dict(tensor_data: torch.Tensor) -> dict:
35-
"""
36-
Convert a torch.uint8 tensor containing JSON bytes to a dictionary.
37-
"""
38-
try:
39-
if tensor_data.dtype == torch.uint8:
40-
byte_data = bytes(tensor_data.tolist())
41-
json_str = byte_data.decode("utf-8")
42-
return json.loads(json_str)
43-
else:
44-
return {}
45-
except Exception as e:
46-
logging.debug(f"Failed to parse comfy_quant metadata using tensor_to_dict: {e}")
47-
return {}
48-
49-
5039
class UnifiedQuantOps(manual_cast):
5140
"""
5241
Unified operations class that handles INT8, FP8, MXFP8, and NVFP4 formats.

utils/__init__.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
"""QuantOps utilities."""
2-
from .safetensors_loader import (
3-
MemoryEfficientSafeOpen,
4-
load_fp8_state_dict,
5-
get_layer_metadata,
6-
)
2+
from unifiedefficientloader import UnifiedSafetensorsLoader, tensor_to_dict
3+
from .safetensors_loader import extract_quantization_metadata, detect_quant_format, _is_scale_tensor
74

85
__all__ = [
9-
"MemoryEfficientSafeOpen",
10-
"load_fp8_state_dict",
11-
"get_layer_metadata",
6+
"UnifiedSafetensorsLoader",
7+
"tensor_to_dict",
8+
"extract_quantization_metadata",
9+
"detect_quant_format",
10+
"_is_scale_tensor",
1211
]

0 commit comments

Comments
 (0)