Skip to content

Commit 3b1ccd7

Browse files
authored
Merge branch 'main' into cp-fixes-attn-backends
2 parents 0c35ed4 + 5851928 commit 3b1ccd7

9 files changed

Lines changed: 389 additions & 7 deletions

File tree

docs/source/en/api/models/controlnet.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,21 @@ url = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/m
3333
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
3434
```
3535

36+
## Loading from Control LoRA
37+
38+
Control-LoRA was introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora). It adds low-rank, parameter-efficient fine-tuning to ControlNet, offering a more efficient and compact way to bring model control to a wider variety of consumer GPUs.
39+
40+
```py
41+
from diffusers import ControlNetModel, UNet2DConditionModel
42+
43+
lora_id = "stabilityai/control-lora"
44+
lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
45+
46+
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16).to("cuda")
47+
controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16)
48+
controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)
49+
```
50+
3651
## ControlNetModel
3752

3853
[[autodoc]] ControlNetModel
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Control-LoRA inference example
2+
3+
Control-LoRA was introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora). It adds low-rank, parameter-efficient fine-tuning to ControlNet, offering a more efficient and compact way to bring model control to a wider variety of consumer GPUs.
4+
5+
## Installing the dependencies
6+
7+
Before running the scripts, make sure to install the library's training dependencies:
8+
9+
**Important**
10+
11+
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
12+
```bash
13+
git clone https://github.com/huggingface/diffusers
14+
cd diffusers
15+
pip install .
16+
```
17+
18+
Then `cd` into the example folder and run:
19+
```bash
20+
pip install -r requirements.txt
21+
```
22+
23+
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
24+
25+
```bash
26+
accelerate config
27+
```
28+
29+
## Inference on SDXL
30+
31+
[stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) provides a set of Control-LoRA weights for SDXL. Here we use the `canny` condition to generate an image from a text prompt and a reference image.
32+
33+
```bash
34+
python control_lora.py
35+
```
36+
37+
## Acknowledgements
38+
39+
- [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora)
40+
- [comfyanonymous/ControlNet-v1-1_fp16_safetensors](https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors)
41+
- [HighCWu/control-lora-v2](https://github.com/HighCWu/control-lora-v2)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import cv2
2+
import numpy as np
3+
import torch
4+
from PIL import Image
5+
6+
from diffusers import (
7+
AutoencoderKL,
8+
ControlNetModel,
9+
StableDiffusionXLControlNetPipeline,
10+
UNet2DConditionModel,
11+
)
12+
from diffusers.utils import load_image, make_image_grid
13+
14+
15+
pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
16+
lora_id = "stabilityai/control-lora"
17+
lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
18+
19+
unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.bfloat16).to("cuda")
20+
controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16)
21+
controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)
22+
23+
prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
24+
negative_prompt = "low quality, bad quality, sketches"
25+
26+
image = load_image(
27+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
28+
)
29+
30+
controlnet_conditioning_scale = 1.0 # recommended for good generalization
31+
32+
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.bfloat16)
33+
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
34+
pipe_id,
35+
unet=unet,
36+
controlnet=controlnet,
37+
vae=vae,
38+
torch_dtype=torch.bfloat16,
39+
safety_checker=None,
40+
).to("cuda")
41+
42+
image = np.array(image)
43+
image = cv2.Canny(image, 100, 200)
44+
image = image[:, :, None]
45+
image = np.concatenate([image, image, image], axis=2)
46+
image = Image.fromarray(image)
47+
48+
images = pipe(
49+
prompt,
50+
negative_prompt=negative_prompt,
51+
image=image,
52+
controlnet_conditioning_scale=controlnet_conditioning_scale,
53+
num_images_per_prompt=4,
54+
).images
55+
56+
final_image = [image] + images
57+
grid = make_image_grid(final_image, 1, 5)
58+
grid.save("hf-logo_canny.png")

src/diffusers/loaders/peft.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
MIN_PEFT_VERSION,
2828
USE_PEFT_BACKEND,
2929
check_peft_version,
30+
convert_sai_sd_control_lora_state_dict_to_peft,
3031
convert_unet_state_dict_to_peft,
3132
delete_adapter_layers,
3233
get_adapter_name,
@@ -232,6 +233,13 @@ def load_lora_adapter(
232233
if "lora_A" not in first_key:
233234
state_dict = convert_unet_state_dict_to_peft(state_dict)
234235

236+
# Control LoRA from SAI is different from BFL Control LoRA
237+
# https://huggingface.co/stabilityai/control-lora
238+
# https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
239+
is_sai_sd_control_lora = "lora_controlnet" in state_dict
240+
if is_sai_sd_control_lora:
241+
state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
242+
235243
rank = {}
236244
for key, val in state_dict.items():
237245
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
@@ -263,6 +271,14 @@ def load_lora_adapter(
263271
adapter_name=adapter_name,
264272
)
265273

274+
# Adjust LoRA config for Control LoRA
275+
if is_sai_sd_control_lora:
276+
lora_config.lora_alpha = lora_config.r
277+
lora_config.alpha_pattern = lora_config.rank_pattern
278+
lora_config.bias = "all"
279+
lora_config.modules_to_save = lora_config.exclude_modules
280+
lora_config.exclude_modules = None
281+
266282
# <Unsafe code
267283
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
268284
# Now we remove any existing hooks to `_pipeline`.

src/diffusers/models/controlnets/controlnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torch.nn import functional as F
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
22+
from ...loaders import PeftAdapterMixin
2223
from ...loaders.single_file_model import FromOriginalModelMixin
2324
from ...utils import BaseOutput, logging
2425
from ..attention import AttentionMixin
@@ -106,7 +107,7 @@ def forward(self, conditioning):
106107
return embedding
107108

108109

109-
class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin):
110+
class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
110111
"""
111112
A ControlNet model.
112113

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import functools
1616
import math
17+
from math import prod
1718
from typing import Any, Dict, List, Optional, Tuple, Union
1819

1920
import numpy as np
@@ -363,7 +364,13 @@ def __call__(
363364
@maybe_allow_in_graph
364365
class QwenImageTransformerBlock(nn.Module):
365366
def __init__(
366-
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
367+
self,
368+
dim: int,
369+
num_attention_heads: int,
370+
attention_head_dim: int,
371+
qk_norm: str = "rms_norm",
372+
eps: float = 1e-6,
373+
zero_cond_t: bool = False,
367374
):
368375
super().__init__()
369376

@@ -403,10 +410,43 @@ def __init__(
403410
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
404411
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
405412

406-
def _modulate(self, x, mod_params):
413+
self.zero_cond_t = zero_cond_t
414+
415+
def _modulate(self, x, mod_params, index=None):
407416
"""Apply modulation to input tensor"""
417+
# x: b l d, shift: b d, scale: b d, gate: b d
408418
shift, scale, gate = mod_params.chunk(3, dim=-1)
409-
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
419+
420+
if index is not None:
421+
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
422+
# So shift, scale, gate have shape [2*actual_batch, d]
423+
actual_batch = shift.size(0) // 2
424+
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
425+
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
426+
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
427+
428+
# index: [b, l] where b is actual batch size
429+
# Expand to [b, l, 1] to match feature dimension
430+
index_expanded = index.unsqueeze(-1) # [b, l, 1]
431+
432+
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
433+
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
434+
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
435+
scale_0_exp = scale_0.unsqueeze(1)
436+
scale_1_exp = scale_1.unsqueeze(1)
437+
gate_0_exp = gate_0.unsqueeze(1)
438+
gate_1_exp = gate_1.unsqueeze(1)
439+
440+
# Use torch.where to select based on index
441+
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
442+
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
443+
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
444+
else:
445+
shift_result = shift.unsqueeze(1)
446+
scale_result = scale.unsqueeze(1)
447+
gate_result = gate.unsqueeze(1)
448+
449+
return x * (1 + scale_result) + shift_result, gate_result
410450

411451
def forward(
412452
self,
@@ -416,9 +456,13 @@ def forward(
416456
temb: torch.Tensor,
417457
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
418458
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
459+
modulate_index: Optional[List[int]] = None,
419460
) -> Tuple[torch.Tensor, torch.Tensor]:
420461
# Get modulation parameters for both streams
421462
img_mod_params = self.img_mod(temb) # [B, 6*dim]
463+
464+
if self.zero_cond_t:
465+
temb = torch.chunk(temb, 2, dim=0)[0]
422466
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
423467

424468
# Split modulation parameters for norm1 and norm2
@@ -427,7 +471,7 @@ def forward(
427471

428472
# Process image stream - norm1 + modulation
429473
img_normed = self.img_norm1(hidden_states)
430-
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
474+
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index)
431475

432476
# Process text stream - norm1 + modulation
433477
txt_normed = self.txt_norm1(encoder_hidden_states)
@@ -457,7 +501,7 @@ def forward(
457501

458502
# Process image stream - norm2 + MLP
459503
img_normed2 = self.img_norm2(hidden_states)
460-
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
504+
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index)
461505
img_mlp_output = self.img_mlp(img_modulated2)
462506
hidden_states = hidden_states + img_gate2 * img_mlp_output
463507

@@ -533,6 +577,7 @@ def __init__(
533577
joint_attention_dim: int = 3584,
534578
guidance_embeds: bool = False, # TODO: this should probably be removed
535579
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
580+
zero_cond_t: bool = False,
536581
):
537582
super().__init__()
538583
self.out_channels = out_channels or in_channels
@@ -553,6 +598,7 @@ def __init__(
553598
dim=self.inner_dim,
554599
num_attention_heads=num_attention_heads,
555600
attention_head_dim=attention_head_dim,
601+
zero_cond_t=zero_cond_t,
556602
)
557603
for _ in range(num_layers)
558604
]
@@ -562,6 +608,7 @@ def __init__(
562608
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
563609

564610
self.gradient_checkpointing = False
611+
self.zero_cond_t = zero_cond_t
565612

566613
def forward(
567614
self,
@@ -618,6 +665,17 @@ def forward(
618665
hidden_states = self.img_in(hidden_states)
619666

620667
timestep = timestep.to(hidden_states.dtype)
668+
669+
if self.zero_cond_t:
670+
timestep = torch.cat([timestep, timestep * 0], dim=0)
671+
modulate_index = torch.tensor(
672+
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
673+
device=timestep.device,
674+
dtype=torch.int,
675+
)
676+
else:
677+
modulate_index = None
678+
621679
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
622680
encoder_hidden_states = self.txt_in(encoder_hidden_states)
623681

@@ -641,6 +699,8 @@ def forward(
641699
encoder_hidden_states_mask,
642700
temb,
643701
image_rotary_emb,
702+
attention_kwargs,
703+
modulate_index,
644704
)
645705

646706
else:
@@ -651,6 +711,7 @@ def forward(
651711
temb=temb,
652712
image_rotary_emb=image_rotary_emb,
653713
joint_attention_kwargs=attention_kwargs,
714+
modulate_index=modulate_index,
654715
)
655716

656717
# controlnet residual
@@ -659,6 +720,8 @@ def forward(
659720
interval_control = int(np.ceil(interval_control))
660721
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
661722

723+
if self.zero_cond_t:
724+
temb = temb.chunk(2, dim=0)[0]
662725
# Use only the image part (hidden_states) from the dual-stream blocks
663726
hidden_states = self.norm_out(hidden_states, temb)
664727
output = self.proj_out(hidden_states)

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@
143143
from .remote_utils import remote_decode
144144
from .state_dict_utils import (
145145
convert_all_state_dict_to_peft,
146+
convert_sai_sd_control_lora_state_dict_to_peft,
146147
convert_state_dict_to_diffusers,
147148
convert_state_dict_to_kohya,
148149
convert_state_dict_to_peft,

0 commit comments

Comments
 (0)