Skip to content

Commit d7fa445

Browse files
SunMarc, Copilot, and sayakpaul
authored
Remove 8bit device restriction (#12972)
* allow to
* update version
* fix version again
* again
* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* style
* xfail
* add pr

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 7feb4fc commit d7fa445

File tree

3 files changed

+34
-23
lines changed

3 files changed

+34
-23
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,12 +1360,12 @@ def cuda(self, *args, **kwargs):
13601360

13611361
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
13621362
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
1363-
if getattr(self, "is_loaded_in_8bit", False):
1363+
if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
13641364
raise ValueError(
1365-
"Calling `cuda()` is not supported for `8-bit` quantized models. "
1366-
" Please use the model as it is, since the model has already been set to the correct devices."
1365+
"Calling `cuda()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
1366+
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
13671367
)
1368-
elif is_bitsandbytes_version("<", "0.43.2"):
1368+
elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
13691369
raise ValueError(
13701370
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
13711371
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
@@ -1412,17 +1412,16 @@ def to(self, *args, **kwargs):
14121412
)
14131413

14141414
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
1415-
if getattr(self, "is_loaded_in_8bit", False):
1415+
if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
14161416
raise ValueError(
1417-
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
1418-
" model has already been set to the correct devices and casted to the correct `dtype`."
1417+
"Calling `to()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
1418+
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
14191419
)
1420-
elif is_bitsandbytes_version("<", "0.43.2"):
1420+
elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
14211421
raise ValueError(
14221422
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
14231423
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
14241424
)
1425-
14261425
if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
14271426
logger.warning(
14281427
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
deprecate,
6161
is_accelerate_available,
6262
is_accelerate_version,
63+
is_bitsandbytes_version,
6364
is_hpu_available,
6465
is_torch_npu_available,
6566
is_torch_version,
@@ -444,7 +445,10 @@ def module_is_sequentially_offloaded(module):
444445

445446
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
446447

447-
if is_loaded_in_8bit_bnb:
448+
# https://github.com/huggingface/accelerate/pull/3907
449+
if is_loaded_in_8bit_bnb and (
450+
is_bitsandbytes_version("<", "0.48.0") or is_accelerate_version("<", "1.13.0.dev0")
451+
):
448452
return False
449453

450454
return hasattr(module, "_hf_hook") and (
@@ -523,9 +527,10 @@ def module_is_offloaded(module):
523527
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
524528
)
525529

526-
if is_loaded_in_8bit_bnb and device is not None:
530+
if is_loaded_in_8bit_bnb and device is not None and is_bitsandbytes_version("<", "0.48.0"):
527531
logger.warning(
528532
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
533+
"You need to upgrade bitsandbytes to at least 0.48.0"
529534
)
530535

531536
# Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
@@ -542,6 +547,14 @@ def module_is_offloaded(module):
542547
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
543548
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
544549
module.to(device=device)
550+
# added here https://github.com/huggingface/transformers/pull/43258
551+
if (
552+
is_loaded_in_8bit_bnb
553+
and device is not None
554+
and is_transformers_version(">", "4.58.0")
555+
and is_bitsandbytes_version(">=", "0.48.0")
556+
):
557+
module.to(device=device)
545558
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
546559
module.to(device, dtype)
547560

@@ -1223,7 +1236,9 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
12231236

12241237
# This is because the model would already be placed on a CUDA device.
12251238
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
1226-
if is_loaded_in_8bit_bnb:
1239+
if is_loaded_in_8bit_bnb and (
1240+
is_transformers_version("<", "4.58.0") or is_bitsandbytes_version("<", "0.48.0")
1241+
):
12271242
logger.info(
12281243
f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
12291244
)

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -288,31 +288,29 @@ def test_config_from_pretrained(self):
288288
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
289289
self.assertTrue(hasattr(linear.weight, "SCB"))
290290

291+
@require_bitsandbytes_version_greater("0.48.0")
291292
def test_device_and_dtype_assignment(self):
292293
r"""
293294
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
294295
Checks also if other models are casted correctly.
295296
"""
296-
with self.assertRaises(ValueError):
297-
# Tries with `str`
298-
self.model_8bit.to("cpu")
299297

300298
with self.assertRaises(ValueError):
301299
# Tries with a `dtype``
302300
self.model_8bit.to(torch.float16)
303301

304-
with self.assertRaises(ValueError):
305-
# Tries with a `device`
306-
self.model_8bit.to(torch.device(f"{torch_device}:0"))
307-
308302
with self.assertRaises(ValueError):
309303
# Tries with a `device`
310304
self.model_8bit.float()
311305

312306
with self.assertRaises(ValueError):
313-
# Tries with a `device`
307+
# Tries with a `dtype`
314308
self.model_8bit.half()
315309

310+
# This should work with 0.48.0
311+
self.model_8bit.to("cpu")
312+
self.model_8bit.to(torch.device(f"{torch_device}:0"))
313+
316314
# Test if we did not break anything
317315
self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
318316
input_dict_for_transformer = self.get_dummy_inputs()
@@ -837,7 +835,7 @@ def test_serialization_sharded(self):
837835

838836

839837
@require_torch_version_greater_equal("2.6.0")
840-
@require_bitsandbytes_version_greater("0.45.5")
838+
@require_bitsandbytes_version_greater("0.48.0")
841839
class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
842840
@property
843841
def quantization_config(self):
@@ -848,7 +846,7 @@ def quantization_config(self):
848846
)
849847

850848
@pytest.mark.xfail(
851-
reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
849+
reason="Test fails because of a type change when recompiling."
852850
" Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
853851
)
854852
def test_torch_compile(self):
@@ -858,6 +856,5 @@ def test_torch_compile(self):
858856
def test_torch_compile_with_cpu_offload(self):
859857
super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
860858

861-
@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
862859
def test_torch_compile_with_group_offload_leaf(self):
863860
super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)

0 commit comments

Comments (0)