Commit 9553b79

Merge branch 'main' into fix-memory-address-problem
2 parents b572234 + a185e1a commit 9553b79

11 files changed

Lines changed: 240 additions & 76 deletions

docs/source/en/optimization/fp16.md

Lines changed: 57 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -150,11 +150,63 @@ pipeline(prompt, num_inference_steps=30).images[0]
150150

151151
Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient.
152152

153+
### Dynamic shape compilation
154+
155+
> [!TIP]
156+
> Make sure to always use the nightly version of PyTorch for better support.
157+
158+
`torch.compile` keeps track of input shapes and conditions, and if these are different, it recompiles the model. For example, if a model is compiled on a 1024x1024 resolution image and used on an image with a different resolution, it triggers recompilation.
159+
160+
To avoid this, pass `dynamic=True` to `torch.compile`. The compiler then tries to generate a more dynamic kernel that can be reused when input shapes change, instead of recompiling.
161+
162+
```diff
163+
+ torch.fx.experimental._config.use_duck_shape = False
164+
+ pipeline.unet = torch.compile(
165+
pipeline.unet, fullgraph=True, dynamic=True
166+
)
167+
```
168+
169+
Setting `use_duck_shape=False` tells the compiler not to reuse a single symbolic variable for input sizes that happen to be equal, so each dimension gets its own symbol. For more details, check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
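The effect of these settings on recompilation can be illustrated with a toy cache. This is a simplified sketch, not how `torch.compile` is implemented: `guard_key`, `count_compilations`, and the three mode names are hypothetical, and the model ignores details such as automatic dynamic shapes. It only mimics when an already compiled graph could be reused for a new input shape.

```python
# Toy model of shape guards. This is NOT torch.compile's implementation;
# it only mimics when a cached graph could be reused.

def guard_key(shape, mode):
    """Return the key under which a compiled graph would be cached (hypothetical)."""
    if mode == "static":
        # Specialize on the exact sizes.
        return tuple(shape)
    if mode == "dynamic_duck":
        # Symbolic sizes, but dimensions that were equal share one symbol,
        # so the equality pattern becomes part of the guard.
        first_seen = {}
        return tuple(first_seen.setdefault(size, len(first_seen)) for size in shape)
    if mode == "dynamic_no_duck":
        # One independent symbol per dimension: only the rank matters.
        return len(shape)
    raise ValueError(f"unknown mode: {mode}")


def count_compilations(shapes, mode):
    """Count how many distinct graphs the toy cache would have to build."""
    cache = set()
    for shape in shapes:
        cache.add(guard_key(shape, mode))
    return len(cache)


resolutions = [(1024, 1024), (768, 768), (1024, 768), (512, 640)]
for mode in ("static", "dynamic_duck", "dynamic_no_duck"):
    print(mode, count_compilations(resolutions, mode))
```

In this toy, exact-shape specialization builds four graphs, duck-shaped symbols build two (one for square inputs, one for non-square inputs), and independent symbols build one. That matches the motivation for disabling duck shaping when the same pipeline is called with both square and non-square resolutions.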
170+
171+
Not all models may benefit from dynamic compilation out of the box and may require changes. Refer to this [PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the [`AuraFlowPipeline`] implementation to benefit from dynamic compilation.
172+
173+
Feel free to open an issue if dynamic compilation doesn't work as expected for a Diffusers model.
174+
153175
### Regional compilation
154176

155-
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
156177

157-
[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately.
178+
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently repeated block(s)** of a model, typically a Transformer layer, and reusing the compiled artifacts for every subsequent occurrence.
179+
For many diffusion architectures, this delivers the *same* runtime speed-ups as full-graph compilation while cutting compile time by **8–10x**.
180+
181+
To make this effortless, [`ModelMixin`] exposes the [`ModelMixin.compile_repeated_blocks`] API, a helper that applies `torch.compile` to the sub-modules you designate as repeatable:
182+
183+
```py
184+
# pip install -U diffusers
185+
import torch
186+
from diffusers import StableDiffusionXLPipeline
187+
188+
pipe = StableDiffusionXLPipeline.from_pretrained(
189+
"stabilityai/stable-diffusion-xl-base-1.0",
190+
torch_dtype=torch.float16,
191+
).to("cuda")
192+
193+
# Compile only the repeated Transformer layers inside the UNet
194+
pipe.unet.compile_repeated_blocks(fullgraph=True)
195+
```
196+
197+
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
198+
199+
200+
```py
201+
class MyUNet(ModelMixin):
202+
    _repeated_blocks = ("Transformer2DModel",)  # compiled when compile_repeated_blocks() is called
203+
```
204+
205+
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
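The selection logic behind `compile_repeated_blocks` can be sketched without PyTorch. The snippet below is a simplified stand-in, not the Diffusers implementation: `ToyModule`, `Block`, `Embed`, and `ToyModel` are hypothetical names used in place of `torch.nn.Module` subclasses, and the method returns the number of matched sub-modules purely for illustration.

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module with just enough behavior for the sketch."""

    def __init__(self, *children):
        self.children = list(children)
        self.compile_kwargs = None  # set once compile() has been called

    def modules(self):
        # Yield this module and every descendant, like nn.Module.modules().
        yield self
        for child in self.children:
            yield from child.modules()

    def compile(self, **kwargs):
        # A real module would wrap its forward pass with torch.compile here.
        self.compile_kwargs = kwargs


class Block(ToyModule):
    pass


class Embed(ToyModule):
    pass


class ToyModel(ToyModule):
    _repeated_blocks = ("Block",)

    def compile_repeated_blocks(self, **kwargs):
        repeated_blocks = getattr(self, "_repeated_blocks", None)
        if not repeated_blocks:
            raise ValueError("`_repeated_blocks` is empty.")
        matched = 0
        for submodule in self.modules():
            # Match on the class name string, not on the class object.
            if submodule.__class__.__name__ in repeated_blocks:
                submodule.compile(**kwargs)
                matched += 1
        if matched == 0:
            raise ValueError(f"No sub-module of type {repeated_blocks} found.")
        return matched


model = ToyModel(Embed(), Block(), Block(), Block())
print(model.compile_repeated_blocks(fullgraph=True))
```

Because matching uses Python's `in` operator on class names, `_repeated_blocks` should be a list or tuple. With a bare string, `in` would perform a substring check instead, which is why the one-element tuple above keeps its trailing comma.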
206+
207+
**Relation to Accelerate `compile_regions`** [Accelerate](https://huggingface.co/docs/accelerate/index) provides a separate API, [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78), which takes a fully automatic approach: it walks the module, picks candidate blocks, and compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it leaves fewer knobs when you want to choose which blocks are compiled or adjust compilation flags.
208+
209+
158210

159211
```py
160212
# pip install -U accelerate
@@ -167,6 +219,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
167219
).to("cuda")
168220
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
169221
```
222+
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
223+
170224

171225
### Graph breaks
172226

@@ -241,4 +295,4 @@ An input is projected into three subspaces, represented by the projection matric
241295

242296
```py
243297
pipeline.fuse_qkv_projections()
244-
```
298+
```

src/diffusers/hooks/group_offloading.py

Lines changed: 81 additions & 61 deletions
@@ -137,9 +137,58 @@ def _pinned_memory_tensors(self):
137137
finally:
138138
pinned_dict = None
139139

140+
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
141+
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
142+
if self.record_stream and current_stream is not None:
143+
tensor.data.record_stream(current_stream)
144+
145+
def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
146+
for group_module in self.modules:
147+
for param in group_module.parameters():
148+
source = pinned_memory[param] if pinned_memory else param.data
149+
self._transfer_tensor_to_device(param, source, current_stream)
150+
for buffer in group_module.buffers():
151+
source = pinned_memory[buffer] if pinned_memory else buffer.data
152+
self._transfer_tensor_to_device(buffer, source, current_stream)
153+
154+
for param in self.parameters:
155+
source = pinned_memory[param] if pinned_memory else param.data
156+
self._transfer_tensor_to_device(param, source, current_stream)
157+
158+
for buffer in self.buffers:
159+
source = pinned_memory[buffer] if pinned_memory else buffer.data
160+
self._transfer_tensor_to_device(buffer, source, current_stream)
161+
162+
def _onload_from_disk(self, current_stream):
163+
if self.stream is not None:
164+
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
165+
166+
for key, tensor_obj in self.key_to_tensor.items():
167+
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
168+
169+
with self._pinned_memory_tensors() as pinned_memory:
170+
for key, tensor_obj in self.key_to_tensor.items():
171+
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
172+
173+
self.cpu_param_dict.clear()
174+
175+
else:
176+
onload_device = (
177+
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
178+
)
179+
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
180+
for key, tensor_obj in self.key_to_tensor.items():
181+
tensor_obj.data = loaded_tensors[key]
182+
183+
def _onload_from_memory(self, current_stream):
184+
if self.stream is not None:
185+
with self._pinned_memory_tensors() as pinned_memory:
186+
self._process_tensors_from_modules(pinned_memory, current_stream)
187+
else:
188+
self._process_tensors_from_modules(None, current_stream)
189+
140190
@torch.compiler.disable()
141191
def onload_(self):
142-
r"""Onloads the group of modules to the onload_device."""
143192
torch_accelerator_module = (
144193
getattr(torch, torch.accelerator.current_accelerator().type)
145194
if hasattr(torch, "accelerator")
@@ -177,67 +226,30 @@ def onload_(self):
177226
self.stream.synchronize()
178227

179228
with context:
180-
if self.stream is not None:
181-
with self._pinned_memory_tensors() as pinned_memory:
182-
for group_module in self.modules:
183-
for param in group_module.parameters():
184-
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
185-
if self.record_stream:
186-
param.data.record_stream(current_stream)
187-
for buffer in group_module.buffers():
188-
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
189-
if self.record_stream:
190-
buffer.data.record_stream(current_stream)
191-
192-
for param in self.parameters:
193-
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
194-
if self.record_stream:
195-
param.data.record_stream(current_stream)
196-
197-
for buffer in self.buffers:
198-
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
199-
if self.record_stream:
200-
buffer.data.record_stream(current_stream)
201-
229+
if self.offload_to_disk_path:
230+
self._onload_from_disk(current_stream)
202231
else:
203-
for group_module in self.modules:
204-
for param in group_module.parameters():
205-
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
206-
for buffer in group_module.buffers():
207-
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
208-
209-
for param in self.parameters:
210-
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
211-
212-
for buffer in self.buffers:
213-
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
214-
if self.record_stream:
215-
buffer.data.record_stream(current_stream)
216-
217-
@torch.compiler.disable()
218-
def offload_(self):
219-
r"""Offloads the group of modules to the offload_device."""
220-
if self.offload_to_disk_path:
221-
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
222-
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
223-
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
224-
# we perform a write.
225-
# Check if the file has been saved in this session or if it already exists on disk.
226-
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
227-
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
228-
tensors_to_save = {
229-
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
230-
}
231-
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
232-
233-
# The group is now considered offloaded to disk for the rest of the session.
234-
self._is_offloaded_to_disk = True
235-
236-
# We do this to free up the RAM which is still holding the up tensor data.
237-
for tensor_obj in self.tensor_to_key.keys():
238-
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
239-
return
240-
232+
self._onload_from_memory(current_stream)
233+
234+
def _offload_to_disk(self):
235+
# TODO: we can potentially optimize this code path by checking if _all_ the desired
236+
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
237+
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
238+
# we perform a write.
239+
# Check if the file has been saved in this session or if it already exists on disk.
240+
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
241+
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
242+
tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
243+
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
244+
245+
# The group is now considered offloaded to disk for the rest of the session.
246+
self._is_offloaded_to_disk = True
247+
248+
# We do this to free up the RAM that is still holding the tensor data.
249+
for tensor_obj in self.tensor_to_key.keys():
250+
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
251+
252+
def _offload_to_memory(self):
241253
torch_accelerator_module = (
242254
getattr(torch, torch.accelerator.current_accelerator().type)
243255
if hasattr(torch, "accelerator")
@@ -262,6 +274,14 @@ def offload_(self):
262274
for buffer in self.buffers:
263275
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
264276

277+
@torch.compiler.disable()
278+
def offload_(self):
279+
r"""Offloads the group of modules to the offload_device."""
280+
if self.offload_to_disk_path:
281+
self._offload_to_disk()
282+
else:
283+
self._offload_to_memory()
284+
265285

266286
class GroupOffloadingHook(ModelHook):
267287
r"""

src/diffusers/models/modeling_utils.py

Lines changed: 34 additions & 0 deletions
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
266266
_keep_in_fp32_modules = None
267267
_skip_layerwise_casting_patterns = None
268268
_supports_group_offloading = True
269+
_repeated_blocks = []
269270

270271
def __init__(self):
271272
super().__init__()
@@ -1404,6 +1405,39 @@ def float(self, *args):
14041405
else:
14051406
return super().float(*args)
14061407

1408+
def compile_repeated_blocks(self, *args, **kwargs):
1409+
"""
1410+
Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
1411+
compiling the entire model. This technique, often called **regional compilation** (see the PyTorch recipe
1412+
https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), can reduce end-to-end compile time
1413+
substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`.
1414+
1415+
The set of sub-modules to compile is discovered by the presence of the **`_repeated_blocks`** attribute in the
1416+
model definition. Define this attribute on your model subclass as a list/tuple of class names (strings). Every
1417+
module whose class name matches will be compiled.
1418+
1419+
Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
1420+
positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
1421+
`torch.compile`.
1422+
"""
1423+
repeated_blocks = getattr(self, "_repeated_blocks", None)
1424+
1425+
if not repeated_blocks:
1426+
raise ValueError(
1427+
"`_repeated_blocks` attribute is empty. "
1428+
f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
1429+
)
1430+
has_compiled_region = False
1431+
for submod in self.modules():
1432+
if submod.__class__.__name__ in repeated_blocks:
1433+
submod.compile(*args, **kwargs)
1434+
has_compiled_region = True
1435+
1436+
if not has_compiled_region:
1437+
raise ValueError(
1438+
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
1439+
)
1440+
14071441
@classmethod
14081442
def _load_pretrained_model(
14091443
cls,

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 1 addition & 0 deletions
@@ -407,6 +407,7 @@ class ChromaTransformer2DModel(
407407

408408
_supports_gradient_checkpointing = True
409409
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
410+
_repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
410411
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
411412

412413
@register_to_config

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 0 deletions
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
227227
_supports_gradient_checkpointing = True
228228
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
229229
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
230+
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
230231

231232
@register_to_config
232233
def __init__(

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 6 additions & 0 deletions
@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
870870
"HunyuanVideoPatchEmbed",
871871
"HunyuanVideoTokenRefiner",
872872
]
873+
_repeated_blocks = [
874+
"HunyuanVideoTransformerBlock",
875+
"HunyuanVideoSingleTransformerBlock",
876+
"HunyuanVideoPatchEmbed",
877+
"HunyuanVideoTokenRefiner",
878+
]
873879

874880
@register_to_config
875881
def __init__(

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 1 addition & 0 deletions
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
328328

329329
_supports_gradient_checkpointing = True
330330
_skip_layerwise_casting_patterns = ["norm"]
331+
_repeated_blocks = ["LTXVideoTransformerBlock"]
331332

332333
@register_to_config
333334
def __init__(

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 0 deletions
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
345345
_no_split_modules = ["WanTransformerBlock"]
346346
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
347347
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
348+
_repeated_blocks = ["WanTransformerBlock"]
348349

349350
@register_to_config
350351
def __init__(

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ class conditioning with `class_embed_type` equal to `None`.
167167
_supports_gradient_checkpointing = True
168168
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
169169
_skip_layerwise_casting_patterns = ["norm"]
170+
_repeated_blocks = ["BasicTransformerBlock"]
170171

171172
@register_to_config
172173
def __init__(
