Skip to content

Commit e97a1ad

Browse files
LucasStesayakpaul
andauthored
Enable BitsAndBytes quantization in MPS (#13915)
Fix BitsAndBytes quantization in MPS Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent aed5c64 commit e97a1ad

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, quantization_config, **kwargs):
6161
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
6262

6363
def validate_environment(self, *args, **kwargs):
64-
if not (torch.cuda.is_available() or torch.xpu.is_available()):
64+
if not (torch.cuda.is_available() or torch.xpu.is_available() or torch.mps.is_available()):
6565
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
6666
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
6767
raise ImportError(
@@ -240,6 +240,8 @@ def update_device_map(self, device_map):
240240
if device_map is None:
241241
if torch.xpu.is_available():
242242
current_device = f"xpu:{torch.xpu.current_device()}"
243+
elif torch.mps.is_available():
244+
current_device = "mps"
243245
else:
244246
current_device = f"cuda:{torch.cuda.current_device()}"
245247
device_map = {"": current_device}
@@ -411,6 +413,8 @@ def update_device_map(self, device_map):
411413
if device_map is None:
412414
if torch.xpu.is_available():
413415
current_device = f"xpu:{torch.xpu.current_device()}"
416+
elif torch.mps.is_available():
417+
current_device = "mps"
414418
else:
415419
current_device = f"cuda:{torch.cuda.current_device()}"
416420
device_map = {"": current_device}

0 commit comments

Comments
 (0)