Skip to content

Commit 7872c06

Browse files
committed
add an offload utility that can be used as a context manager.
1 parent cd81349 commit 7872c06

3 files changed

Lines changed: 56 additions & 36 deletions

File tree

.github/workflows/pr_tests_gpu.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ on:
1313
- "src/diffusers/loaders/peft.py"
1414
- "tests/pipelines/test_pipelines_common.py"
1515
- "tests/models/test_modeling_common.py"
16+
- "examples/**/*.py"
1617
workflow_dispatch:
1718

1819
concurrency:

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
compute_density_for_timestep_sampling,
5959
compute_loss_weighting_for_sd3,
6060
free_memory,
61+
offload_models,
6162
)
6263
from diffusers.utils import (
6364
check_min_version,
@@ -1364,43 +1365,34 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
13641365
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
13651366
# the redundant encoding.
13661367
if not train_dataset.custom_instance_prompts:
1367-
if args.offload:
1368-
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
1369-
(
1370-
instance_prompt_hidden_states_t5,
1371-
instance_prompt_hidden_states_llama3,
1372-
instance_pooled_prompt_embeds,
1373-
_,
1374-
_,
1375-
_,
1376-
) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
1377-
if args.offload:
1378-
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
1368+
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
1369+
(
1370+
instance_prompt_hidden_states_t5,
1371+
instance_prompt_hidden_states_llama3,
1372+
instance_pooled_prompt_embeds,
1373+
_,
1374+
_,
1375+
_,
1376+
) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
13791377

13801378
# Handle class prompt for prior-preservation.
13811379
if args.with_prior_preservation:
1382-
if args.offload:
1383-
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
1384-
(class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
1385-
compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
1386-
)
1387-
if args.offload:
1388-
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
1380+
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
1381+
(class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
1382+
compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
1383+
)
13891384

13901385
validation_embeddings = {}
13911386
if args.validation_prompt is not None:
1392-
if args.offload:
1393-
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
1394-
(
1395-
validation_embeddings["prompt_embeds_t5"],
1396-
validation_embeddings["prompt_embeds_llama3"],
1397-
validation_embeddings["pooled_prompt_embeds"],
1398-
validation_embeddings["negative_prompt_embeds_t5"],
1399-
validation_embeddings["negative_prompt_embeds_llama3"],
1400-
validation_embeddings["negative_pooled_prompt_embeds"],
1401-
) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
1402-
if args.offload:
1403-
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
1387+
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
1388+
(
1389+
validation_embeddings["prompt_embeds_t5"],
1390+
validation_embeddings["prompt_embeds_llama3"],
1391+
validation_embeddings["pooled_prompt_embeds"],
1392+
validation_embeddings["negative_prompt_embeds_t5"],
1393+
validation_embeddings["negative_prompt_embeds_llama3"],
1394+
validation_embeddings["negative_pooled_prompt_embeds"],
1395+
) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
14041396

14051397
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
14061398
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1581,12 +1573,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15811573
if args.cache_latents:
15821574
model_input = latents_cache[step].sample()
15831575
else:
1584-
if args.offload:
1585-
vae = vae.to(accelerator.device)
1586-
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1576+
with offload_models(vae, device=accelerator.device, offload=args.offload):
1577+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
15871578
model_input = vae.encode(pixel_values).latent_dist.sample()
1588-
if args.offload:
1589-
vae = vae.to("cpu")
1579+
15901580
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
15911581
model_input = model_input.to(dtype=weight_dtype)
15921582

src/diffusers/training_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import gc
44
import math
55
import random
6+
from contextlib import contextmanager
67
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
78

89
import numpy as np
910
import torch
1011

1112
from .models import UNet2DConditionModel
13+
from .pipelines import DiffusionPipeline
1214
from .schedulers import SchedulerMixin
1315
from .utils import (
1416
convert_state_dict_to_diffusers,
@@ -316,6 +318,33 @@ def free_memory():
316318
torch.xpu.empty_cache()
317319

318320

321+
@contextmanager
322+
def offload_models(*modules: torch.nn.Module | DiffusionPipeline, device: str | torch.device, offload: bool = True):
323+
"""
324+
Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
325+
device on exit.
326+
"""
327+
if offload:
328+
is_model = not any(isinstance(m, DiffusionPipeline) for m in modules)
329+
# record where each module was
330+
if is_model:
331+
original_devices = [next(m.parameters()).device for m in modules]
332+
else:
333+
assert len(modules) == 1
334+
original_devices = modules[0].device
335+
# move to target device
336+
for m in modules:
337+
m.to(device)
338+
339+
try:
340+
yield
341+
finally:
342+
if offload:
343+
# move back to original devices
344+
for m, orig_dev in zip(modules, original_devices):
345+
m.to(orig_dev)
346+
347+
319348
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
320349
class EMAModel:
321350
"""

0 commit comments

Comments
 (0)