Skip to content

Commit 4dd4dc9

Browse files
authored
Merge branch 'main' into lora-docs-refactor
2 parents e41bd33 + 6760300 commit 4dd4dc9

40 files changed

Lines changed: 227 additions & 113 deletions

examples/dreambooth/test_dreambooth_lora_sana.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import json
1617
import logging
1718
import os
1819
import sys
1920
import tempfile
2021

2122
import safetensors
2223

24+
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
25+
2326

2427
sys.path.append("..")
2528
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
@@ -204,3 +207,42 @@ def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_mult
204207
run_command(self._launch_args + resume_run_args)
205208

206209
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
210+
211+
def test_dreambooth_lora_sana_with_metadata(self):
212+
lora_alpha = 8
213+
rank = 4
214+
with tempfile.TemporaryDirectory() as tmpdir:
215+
test_args = f"""
216+
{self.script_path}
217+
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
218+
--instance_data_dir={self.instance_data_dir}
219+
--output_dir={tmpdir}
220+
--resolution=32
221+
--train_batch_size=1
222+
--gradient_accumulation_steps=1
223+
--max_train_steps=4
224+
--lora_alpha={lora_alpha}
225+
--rank={rank}
226+
--checkpointing_steps=2
227+
--max_sequence_length 166
228+
""".split()
229+
230+
test_args.extend(["--instance_prompt", ""])
231+
run_command(self._launch_args + test_args)
232+
233+
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
234+
self.assertTrue(os.path.isfile(state_dict_file))
235+
236+
# Check if the metadata was properly serialized.
237+
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
238+
metadata = f.metadata() or {}
239+
240+
metadata.pop("format", None)
241+
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
242+
if raw:
243+
raw = json.loads(raw)
244+
245+
loaded_lora_alpha = raw["transformer.lora_alpha"]
246+
self.assertTrue(loaded_lora_alpha == lora_alpha)
247+
loaded_lora_rank = raw["transformer.r"]
248+
self.assertTrue(loaded_lora_rank == rank)

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
)
5353
from diffusers.optimization import get_scheduler
5454
from diffusers.training_utils import (
55+
_collate_lora_metadata,
5556
cast_training_params,
5657
compute_density_for_timestep_sampling,
5758
compute_loss_weighting_for_sd3,
@@ -323,9 +324,13 @@ def parse_args(input_args=None):
323324
default=4,
324325
help=("The dimension of the LoRA update matrices."),
325326
)
326-
327+
parser.add_argument(
328+
"--lora_alpha",
329+
type=int,
330+
default=4,
331+
help="LoRA alpha to be used for additional scaling.",
332+
)
327333
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
328-
329334
parser.add_argument(
330335
"--with_prior_preservation",
331336
default=False,
@@ -1023,7 +1028,7 @@ def main(args):
10231028
# now we will add new LoRA weights the transformer layers
10241029
transformer_lora_config = LoraConfig(
10251030
r=args.rank,
1026-
lora_alpha=args.rank,
1031+
lora_alpha=args.lora_alpha,
10271032
lora_dropout=args.lora_dropout,
10281033
init_lora_weights="gaussian",
10291034
target_modules=target_modules,
@@ -1039,10 +1044,11 @@ def unwrap_model(model):
10391044
def save_model_hook(models, weights, output_dir):
10401045
if accelerator.is_main_process:
10411046
transformer_lora_layers_to_save = None
1042-
1047+
modules_to_save = {}
10431048
for model in models:
10441049
if isinstance(model, type(unwrap_model(transformer))):
10451050
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
1051+
modules_to_save["transformer"] = model
10461052
else:
10471053
raise ValueError(f"unexpected save model: {model.__class__}")
10481054

@@ -1052,6 +1058,7 @@ def save_model_hook(models, weights, output_dir):
10521058
SanaPipeline.save_lora_weights(
10531059
output_dir,
10541060
transformer_lora_layers=transformer_lora_layers_to_save,
1061+
**_collate_lora_metadata(modules_to_save),
10551062
)
10561063

10571064
def load_model_hook(models, input_dir):
@@ -1507,15 +1514,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15071514
accelerator.wait_for_everyone()
15081515
if accelerator.is_main_process:
15091516
transformer = unwrap_model(transformer)
1517+
modules_to_save = {}
15101518
if args.upcast_before_saving:
15111519
transformer.to(torch.float32)
15121520
else:
15131521
transformer = transformer.to(weight_dtype)
15141522
transformer_lora_layers = get_peft_model_state_dict(transformer)
1523+
modules_to_save["transformer"] = transformer
15151524

15161525
SanaPipeline.save_lora_weights(
15171526
save_directory=args.output_dir,
15181527
transformer_lora_layers=transformer_lora_layers,
1528+
**_collate_lora_metadata(modules_to_save),
15191529
)
15201530

15211531
# Final inference

src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
replace_example_docstring,
4242
)
4343
from ...utils.import_utils import is_transformers_version
44-
from ...utils.torch_utils import randn_tensor
44+
from ...utils.torch_utils import empty_device_cache, randn_tensor
4545
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
4646
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
4747

@@ -267,9 +267,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
267267

268268
if self.device.type != "cpu":
269269
self.to("cpu", silence_dtype_warnings=True)
270-
device_mod = getattr(torch, device.type, None)
271-
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
272-
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
270+
empty_device_cache(device.type)
273271

274272
model_sequence = [
275273
self.text_encoder.text_model,

src/diffusers/pipelines/consisid/consisid_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype):
294294
295295
Parameters:
296296
- model_path: Path to the directory containing model files.
297-
- device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
297+
- device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
298298
- dtype: Data type (e.g., torch.float32) for model inference.
299299
300300
Returns:

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
scale_lora_layers,
3838
unscale_lora_layers,
3939
)
40-
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
40+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
4141
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
4242
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
4343
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1339,7 +1339,7 @@ def __call__(
13391339
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
13401340
self.unet.to("cpu")
13411341
self.controlnet.to("cpu")
1342-
torch.cuda.empty_cache()
1342+
empty_device_cache()
13431343

13441344
if not output_type == "latent":
13451345
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
scale_lora_layers,
3737
unscale_lora_layers,
3838
)
39-
from ...utils.torch_utils import is_compiled_module, randn_tensor
39+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
4040
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
4141
from ..stable_diffusion import StableDiffusionPipelineOutput
4242
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1311,7 +1311,7 @@ def __call__(
13111311
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
13121312
self.unet.to("cpu")
13131313
self.controlnet.to("cpu")
1314-
torch.cuda.empty_cache()
1314+
empty_device_cache()
13151315

13161316
if not output_type == "latent":
13171317
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
scale_lora_layers,
3939
unscale_lora_layers,
4040
)
41-
from ...utils.torch_utils import is_compiled_module, randn_tensor
41+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
4242
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
4343
from ..stable_diffusion import StableDiffusionPipelineOutput
4444
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1500,7 +1500,7 @@ def __call__(
15001500
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
15011501
self.unet.to("cpu")
15021502
self.controlnet.to("cpu")
1503-
torch.cuda.empty_cache()
1503+
empty_device_cache()
15041504

15051505
if not output_type == "latent":
15061506
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
scale_lora_layers,
5252
unscale_lora_layers,
5353
)
54-
from ...utils.torch_utils import is_compiled_module, randn_tensor
54+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
5555
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
5656
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
5757

@@ -1858,7 +1858,7 @@ def denoising_value_valid(dnv):
18581858
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
18591859
self.unet.to("cpu")
18601860
self.controlnet.to("cpu")
1861-
torch.cuda.empty_cache()
1861+
empty_device_cache()
18621862

18631863
if not output_type == "latent":
18641864
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1465,7 +1465,11 @@ def __call__(
14651465

14661466
# Relevant thread:
14671467
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1468-
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
1468+
if (
1469+
torch.cuda.is_available()
1470+
and (is_unet_compiled and is_controlnet_compiled)
1471+
and is_torch_higher_equal_2_1
1472+
):
14691473
torch._inductor.cudagraph_mark_step_begin()
14701474
# expand the latents if we are doing classifier free guidance
14711475
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
scale_lora_layers,
5454
unscale_lora_layers,
5555
)
56-
from ...utils.torch_utils import is_compiled_module, randn_tensor
56+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
5757
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
5858
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
5959

@@ -921,7 +921,7 @@ def prepare_latents(
921921
# Offload text encoder if `enable_model_cpu_offload` was enabled
922922
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
923923
self.text_encoder_2.to("cpu")
924-
torch.cuda.empty_cache()
924+
empty_device_cache()
925925

926926
image = image.to(device=device, dtype=dtype)
927927

@@ -1632,7 +1632,7 @@ def __call__(
16321632
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
16331633
self.unet.to("cpu")
16341634
self.controlnet.to("cpu")
1635-
torch.cuda.empty_cache()
1635+
empty_device_cache()
16361636

16371637
if not output_type == "latent":
16381638
# make sure the VAE is in float32 mode, as it overflows in float16

0 commit comments

Comments
 (0)