Skip to content

Commit 784fa62

Browse files
Use device_map="auto" in single file tests to support large models on limited GPU memory (#13816)
* fix flux tests OOM on 24G GPU

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert wrong change

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 0cc1cdb commit 784fa62

3 files changed

Lines changed: 13 additions & 7 deletions

File tree

tests/models/testing_utils/single_file.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -107,8 +107,8 @@ def teardown_method(self):
107107
backend_empty_cache(torch_device)
108108

109109
def test_single_file_model_config(self):
110-
pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
111-
single_file_kwargs = {"device": torch_device}
110+
pretrained_kwargs = {"device_map": "auto", **self.pretrained_model_kwargs}
111+
single_file_kwargs = {"device_map": "auto"}
112112

113113
if self.torch_dtype:
114114
pretrained_kwargs["torch_dtype"] = self.torch_dtype
@@ -127,8 +127,8 @@ def test_single_file_model_config(self):
127127
)
128128

129129
def test_single_file_model_parameters(self):
130-
pretrained_kwargs = {"device_map": str(torch_device), **self.pretrained_model_kwargs}
131-
single_file_kwargs = {"device": torch_device}
130+
pretrained_kwargs = {"device_map": "auto", **self.pretrained_model_kwargs}
131+
single_file_kwargs = {"device_map": "auto"}
132132

133133
if self.torch_dtype:
134134
pretrained_kwargs["torch_dtype"] = self.torch_dtype
@@ -259,7 +259,7 @@ def test_checkpoint_variant_loading(self):
259259
backend_empty_cache(torch_device)
260260

261261
def test_single_file_loading_with_device_map(self):
262-
single_file_kwargs = {"device_map": torch_device}
262+
single_file_kwargs = {"device_map": "auto"}
263263

264264
if self.torch_dtype:
265265
single_file_kwargs["torch_dtype"] = self.torch_dtype

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -346,6 +346,10 @@ def alternate_ckpt_paths(self):
346346
def pretrained_model_name_or_path(self):
347347
return "black-forest-labs/FLUX.1-dev"
348348

349+
@property
350+
def torch_dtype(self):
351+
return torch.bfloat16
352+
349353

350354
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
351355
"""BitsAndBytes quantization tests for Flux Transformer."""

tests/single_file/test_model_flux_transformer_single_file.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515

1616
import gc
1717

18+
import torch
19+
1820
from diffusers import (
1921
FluxTransformer2DModel,
2022
)
@@ -38,9 +40,9 @@ class TestFluxTransformer2DModelSingleFile(SingleFileModelTesterMixin):
3840
repo_id = "black-forest-labs/FLUX.1-dev"
3941
subfolder = "transformer"
4042

41-
def test_device_map_cuda(self):
43+
def test_device_map_auto(self):
4244
backend_empty_cache(torch_device)
43-
model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")
45+
model = self.model_class.from_single_file(self.ckpt_path, device_map="auto", torch_dtype=torch.bfloat16)
4446

4547
del model
4648
gc.collect()

0 commit comments

Comments
 (0)