Commit 2af2e9d

Fix quant_state None on AMD GPUs by caching quant_state_dict at load time
1 parent 7013f86 commit 2af2e9d

2 files changed: 11 additions & 8 deletions


README.md

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ git clone https://github.com/mengqin/ComfyUI-UnetBnbModelLoader ComfyUI/custom_n
 .\python_embeded\python.exe -s -m pip install -r .\ComfyUI\custom_nodes\ComfyUI-UnetBnbModelLoader\requirements.txt
 ```
 
-Because this plugin relies on bitsandbytes, we are unable to support macOS and AMD GPUs.
+Because this plugin relies on bitsandbytes, we are unable to support macOS.
 
 ## Usage
 

ops.py

Lines changed: 10 additions & 7 deletions
@@ -55,7 +55,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
             data=weight_data, quantized_stats=quant_state_dict, device=device
         )
         self.weight = bnb_param
-
+        self._bnb_quant_state_dict = quant_state_dict
+
         for k in bnb_state_dict.keys():
             state_dict.pop(k)
             if k in unexpected_keys: unexpected_keys.remove(k)
@@ -94,9 +95,10 @@ def forward(self, x):
         if getattr(self, "is_bnb_quantized", lambda : False)():
             if not patches_for_this_layer:
                 bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
-                return bnb.matmul_4bit(
-                    x, self.weight.t(), bias=bias, quant_state=getattr(self.weight, "quant_state", None)
-                ).to(x.dtype)
+                qs = getattr(self.weight, "quant_state", None)
+                if qs is None and hasattr(self, "_bnb_quant_state_dict"):
+                    qs = bnb.functional.QuantState.from_dict(self._bnb_quant_state_dict, device=x.device)
+                return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=qs).to(x.dtype)
 
             try:
                 base_w = self.weight.to(x.device)
@@ -113,9 +115,10 @@ def forward(self, x):
 
             if weight_final_fp32 is None:
                 bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
-                return bnb.matmul_4bit(
-                    x, self.weight.t(), bias=bias, quant_state=getattr(self.weight, "quant_state", None)
-                ).to(x.dtype)
+                qs = getattr(self.weight, "quant_state", None)
+                if qs is None and hasattr(self, "_bnb_quant_state_dict"):
+                    qs = bnb.functional.QuantState.from_dict(self._bnb_quant_state_dict, device=x.device)
+                return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=qs).to(x.dtype)
 
             weight_final = comfy.float.stochastic_rounding(weight_final_fp32, x.dtype)
             bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
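What the change does: `_load_from_state_dict` now keeps the raw quantization stats in `self._bnb_quant_state_dict`, and `forward` rebuilds a `QuantState` from that dict whenever `weight.quant_state` comes back `None` (the failure mode the commit message reports on AMD GPUs). Below is a minimal sketch of that fallback pulled out of the layer for clarity; `QuantState.from_dict` and `matmul_4bit` are real bitsandbytes APIs, while `resolve_quant_state` and its arguments are illustrative names, not part of this repo:

```python
import bitsandbytes as bnb

def resolve_quant_state(layer, device):
    """Return a usable QuantState for a 4-bit layer (hypothetical helper)."""
    # Normal path: bitsandbytes attached the state to the Params4bit weight.
    qs = getattr(layer.weight, "quant_state", None)
    if qs is None and hasattr(layer, "_bnb_quant_state_dict"):
        # Fallback added by this commit: rebuild the state on the target
        # device from the stats dict cached in _load_from_state_dict.
        qs = bnb.functional.QuantState.from_dict(
            layer._bnb_quant_state_dict, device=device
        )
    return qs

# Usage mirrors the forward() calls in the diff:
#   out = bnb.matmul_4bit(x, layer.weight.t(), bias=bias,
#                         quant_state=resolve_quant_state(layer, x.device)).to(x.dtype)
```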
