Skip to content

Commit 8abcf35

Browse files
authored
feat: implement apply_lora_scale to remove boilerplate. (#12994)
* feat: implement apply_lora_scale to remove boilerplate.
* apply to the rest.
* up
* remove more.
* remove.
* fix
* apply feedback.
1 parent 2843b3d commit 8abcf35

37 files changed

+137
-640
lines changed

src/diffusers/models/controlnets/controlnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import PeftAdapterMixin
2323
from ...loaders.single_file_model import FromOriginalModelMixin
24-
from ...utils import BaseOutput, logging
24+
from ...utils import BaseOutput, apply_lora_scale, logging
2525
from ..attention import AttentionMixin
2626
from ..attention_processor import (
2727
ADDED_KV_ATTENTION_PROCESSORS,
@@ -598,6 +598,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[i
598598
for module in self.children():
599599
fn_recursive_set_attention_slice(module, reversed_slice_size)
600600

601+
@apply_lora_scale("cross_attention_kwargs")
601602
def forward(
602603
self,
603604
sample: torch.Tensor,

src/diffusers/models/controlnets/controlnet_flux.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import PeftAdapterMixin
23-
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
23+
from ...utils import (
24+
BaseOutput,
25+
apply_lora_scale,
26+
logging,
27+
)
2428
from ..attention import AttentionMixin
2529
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
2630
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
@@ -150,6 +154,7 @@ def from_transformer(
150154

151155
return controlnet
152156

157+
@apply_lora_scale("joint_attention_kwargs")
153158
def forward(
154159
self,
155160
hidden_states: torch.Tensor,
@@ -197,20 +202,6 @@ def forward(
197202
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
198203
`tuple` where the first element is the sample tensor.
199204
"""
200-
if joint_attention_kwargs is not None:
201-
joint_attention_kwargs = joint_attention_kwargs.copy()
202-
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
203-
else:
204-
lora_scale = 1.0
205-
206-
if USE_PEFT_BACKEND:
207-
# weight the lora layers by setting `lora_scale` for each PEFT layer
208-
scale_lora_layers(self, lora_scale)
209-
else:
210-
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
211-
logger.warning(
212-
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
213-
)
214205
hidden_states = self.x_embedder(hidden_states)
215206

216207
if self.input_hint_block is not None:
@@ -323,10 +314,6 @@ def forward(
323314
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
324315
)
325316

326-
if USE_PEFT_BACKEND:
327-
# remove `lora_scale` from each PEFT layer
328-
unscale_lora_layers(self, lora_scale)
329-
330317
if not return_dict:
331318
return (controlnet_block_samples, controlnet_single_block_samples)
332319

src/diffusers/models/controlnets/controlnet_qwenimage.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
23-
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
23+
from ...utils import (
24+
BaseOutput,
25+
apply_lora_scale,
26+
deprecate,
27+
logging,
28+
)
2429
from ..attention import AttentionMixin
2530
from ..cache_utils import CacheMixin
2631
from ..controlnets.controlnet import zero_module
@@ -123,6 +128,7 @@ def from_transformer(
123128

124129
return controlnet
125130

131+
@apply_lora_scale("joint_attention_kwargs")
126132
def forward(
127133
self,
128134
hidden_states: torch.Tensor,
@@ -181,20 +187,6 @@ def forward(
181187
standard_warn=False,
182188
)
183189

184-
if joint_attention_kwargs is not None:
185-
joint_attention_kwargs = joint_attention_kwargs.copy()
186-
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
187-
else:
188-
lora_scale = 1.0
189-
190-
if USE_PEFT_BACKEND:
191-
# weight the lora layers by setting `lora_scale` for each PEFT layer
192-
scale_lora_layers(self, lora_scale)
193-
else:
194-
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
195-
logger.warning(
196-
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
197-
)
198190
hidden_states = self.img_in(hidden_states)
199191

200192
# add
@@ -256,10 +248,6 @@ def forward(
256248
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
257249
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
258250

259-
if USE_PEFT_BACKEND:
260-
# remove `lora_scale` from each PEFT layer
261-
unscale_lora_layers(self, lora_scale)
262-
263251
if not return_dict:
264252
return controlnet_block_samples
265253

src/diffusers/models/controlnets/controlnet_sana.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import PeftAdapterMixin
23-
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
23+
from ...utils import BaseOutput, apply_lora_scale, logging
2424
from ..attention import AttentionMixin
2525
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
2626
from ..modeling_outputs import Transformer2DModelOutput
@@ -117,6 +117,7 @@ def __init__(
117117

118118
self.gradient_checkpointing = False
119119

120+
@apply_lora_scale("attention_kwargs")
120121
def forward(
121122
self,
122123
hidden_states: torch.Tensor,
@@ -129,21 +130,6 @@ def forward(
129130
attention_kwargs: dict[str, Any] | None = None,
130131
return_dict: bool = True,
131132
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
132-
if attention_kwargs is not None:
133-
attention_kwargs = attention_kwargs.copy()
134-
lora_scale = attention_kwargs.pop("scale", 1.0)
135-
else:
136-
lora_scale = 1.0
137-
138-
if USE_PEFT_BACKEND:
139-
# weight the lora layers by setting `lora_scale` for each PEFT layer
140-
scale_lora_layers(self, lora_scale)
141-
else:
142-
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
143-
logger.warning(
144-
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
145-
)
146-
147133
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
148134
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
149135
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -218,10 +204,6 @@ def forward(
218204
block_res_sample = controlnet_block(block_res_sample)
219205
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
220206

221-
if USE_PEFT_BACKEND:
222-
# remove `lora_scale` from each PEFT layer
223-
unscale_lora_layers(self, lora_scale)
224-
225207
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
226208

227209
if not return_dict:

src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ...configuration_utils import ConfigMixin, register_to_config
2323
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
24-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
24+
from ...utils import apply_lora_scale, logging
2525
from ..attention import AttentionMixin, JointTransformerBlock
2626
from ..attention_processor import Attention, FusedJointAttnProcessor2_0
2727
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
@@ -269,6 +269,7 @@ def from_transformer(
269269

270270
return controlnet
271271

272+
@apply_lora_scale("joint_attention_kwargs")
272273
def forward(
273274
self,
274275
hidden_states: torch.Tensor,
@@ -308,21 +309,6 @@ def forward(
308309
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
309310
`tuple` where the first element is the sample tensor.
310311
"""
311-
if joint_attention_kwargs is not None:
312-
joint_attention_kwargs = joint_attention_kwargs.copy()
313-
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
314-
else:
315-
lora_scale = 1.0
316-
317-
if USE_PEFT_BACKEND:
318-
# weight the lora layers by setting `lora_scale` for each PEFT layer
319-
scale_lora_layers(self, lora_scale)
320-
else:
321-
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
322-
logger.warning(
323-
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
324-
)
325-
326312
if self.pos_embed is not None and hidden_states.ndim != 4:
327313
raise ValueError("hidden_states must be 4D when pos_embed is used")
328314

@@ -382,10 +368,6 @@ def forward(
382368
# 6. scaling
383369
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
384370

385-
if USE_PEFT_BACKEND:
386-
# remove `lora_scale` from each PEFT layer
387-
unscale_lora_layers(self, lora_scale)
388-
389371
if not return_dict:
390372
return (controlnet_block_res_samples,)
391373

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ...configuration_utils import ConfigMixin, register_to_config
2323
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
24-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
24+
from ...utils import apply_lora_scale, logging
2525
from ...utils.torch_utils import maybe_allow_in_graph
2626
from ..attention import AttentionMixin
2727
from ..attention_processor import (
@@ -397,6 +397,7 @@ def unfuse_qkv_projections(self):
397397
if self.original_attn_processors is not None:
398398
self.set_attn_processor(self.original_attn_processors)
399399

400+
@apply_lora_scale("attention_kwargs")
400401
def forward(
401402
self,
402403
hidden_states: torch.FloatTensor,
@@ -405,21 +406,6 @@ def forward(
405406
attention_kwargs: dict[str, Any] | None = None,
406407
return_dict: bool = True,
407408
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
408-
if attention_kwargs is not None:
409-
attention_kwargs = attention_kwargs.copy()
410-
lora_scale = attention_kwargs.pop("scale", 1.0)
411-
else:
412-
lora_scale = 1.0
413-
414-
if USE_PEFT_BACKEND:
415-
# weight the lora layers by setting `lora_scale` for each PEFT layer
416-
scale_lora_layers(self, lora_scale)
417-
else:
418-
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
419-
logger.warning(
420-
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
421-
)
422-
423409
height, width = hidden_states.shape[-2:]
424410

425411
# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -486,10 +472,6 @@ def forward(
486472
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
487473
)
488474

489-
if USE_PEFT_BACKEND:
490-
# remove `lora_scale` from each PEFT layer
491-
unscale_lora_layers(self, lora_scale)
492-
493475
if not return_dict:
494476
return (output,)
495477

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import PeftAdapterMixin
23-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
23+
from ...utils import apply_lora_scale, logging
2424
from ...utils.torch_utils import maybe_allow_in_graph
2525
from ..attention import Attention, AttentionMixin, FeedForward
2626
from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -363,6 +363,7 @@ def unfuse_qkv_projections(self):
363363
if self.original_attn_processors is not None:
364364
self.set_attn_processor(self.original_attn_processors)
365365

366+
@apply_lora_scale("attention_kwargs")
366367
def forward(
367368
self,
368369
hidden_states: torch.Tensor,
@@ -374,21 +375,6 @@ def forward(
374375
attention_kwargs: dict[str, Any] | None = None,
375376
return_dict: bool = True,
376377
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
377-
if attention_kwargs is not None:
378-
attention_kwargs = attention_kwargs.copy()
379-
lora_scale = attention_kwargs.pop("scale", 1.0)
380-
else:
381-
lora_scale = 1.0
382-
383-
if USE_PEFT_BACKEND:
384-
# weight the lora layers by setting `lora_scale` for each PEFT layer
385-
scale_lora_layers(self, lora_scale)
386-
else:
387-
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
388-
logger.warning(
389-
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
390-
)
391-
392378
batch_size, num_frames, channels, height, width = hidden_states.shape
393379

394380
# 1. Time embedding
@@ -454,10 +440,6 @@ def forward(
454440
)
455441
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
456442

457-
if USE_PEFT_BACKEND:
458-
# remove `lora_scale` from each PEFT layer
459-
unscale_lora_layers(self, lora_scale)
460-
461443
if not return_dict:
462444
return (output,)
463445
return Transformer2DModelOutput(sample=output)

src/diffusers/models/transformers/consisid_transformer_3d.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import PeftAdapterMixin
23-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
23+
from ...utils import apply_lora_scale, logging
2424
from ...utils.torch_utils import maybe_allow_in_graph
2525
from ..attention import Attention, AttentionMixin, FeedForward
2626
from ..attention_processor import CogVideoXAttnProcessor2_0
@@ -620,6 +620,7 @@ def _init_face_inputs(self):
620620
]
621621
)
622622

623+
@apply_lora_scale("attention_kwargs")
623624
def forward(
624625
self,
625626
hidden_states: torch.Tensor,
@@ -632,21 +633,6 @@ def forward(
632633
id_vit_hidden: torch.Tensor | None = None,
633634
return_dict: bool = True,
634635
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
635-
if attention_kwargs is not None:
636-
attention_kwargs = attention_kwargs.copy()
637-
lora_scale = attention_kwargs.pop("scale", 1.0)
638-
else:
639-
lora_scale = 1.0
640-
641-
if USE_PEFT_BACKEND:
642-
# weight the lora layers by setting `lora_scale` for each PEFT layer
643-
scale_lora_layers(self, lora_scale)
644-
else:
645-
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
646-
logger.warning(
647-
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
648-
)
649-
650636
# fuse clip and insightface
651637
valid_face_emb = None
652638
if self.is_train_face:
@@ -720,10 +706,6 @@ def forward(
720706
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
721707
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
722708

723-
if USE_PEFT_BACKEND:
724-
# remove `lora_scale` from each PEFT layer
725-
unscale_lora_layers(self, lora_scale)
726-
727709
if not return_dict:
728710
return (output,)
729711
return Transformer2DModelOutput(sample=output)

0 commit comments

Comments (0)