2020
2121from ...configuration_utils import ConfigMixin , register_to_config
2222from ...loaders import PeftAdapterMixin
23- from ...utils import USE_PEFT_BACKEND , logging , scale_lora_layers , unscale_lora_layers
23+ from ...utils import apply_lora_scale , logging
2424from ...utils .torch_utils import maybe_allow_in_graph
2525from ..attention import Attention , AttentionMixin , FeedForward
2626from ..attention_processor import CogVideoXAttnProcessor2_0 , FusedCogVideoXAttnProcessor2_0
@@ -363,6 +363,7 @@ def unfuse_qkv_projections(self):
363363 if self .original_attn_processors is not None :
364364 self .set_attn_processor (self .original_attn_processors )
365365
366+ @apply_lora_scale ("attention_kwargs" )
366367 def forward (
367368 self ,
368369 hidden_states : torch .Tensor ,
@@ -374,21 +375,6 @@ def forward(
374375 attention_kwargs : dict [str , Any ] | None = None ,
375376 return_dict : bool = True ,
376377 ) -> tuple [torch .Tensor ] | Transformer2DModelOutput :
377- if attention_kwargs is not None :
378- attention_kwargs = attention_kwargs .copy ()
379- lora_scale = attention_kwargs .pop ("scale" , 1.0 )
380- else :
381- lora_scale = 1.0
382-
383- if USE_PEFT_BACKEND :
384- # weight the lora layers by setting `lora_scale` for each PEFT layer
385- scale_lora_layers (self , lora_scale )
386- else :
387- if attention_kwargs is not None and attention_kwargs .get ("scale" , None ) is not None :
388- logger .warning (
389- "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
390- )
391-
392378 batch_size , num_frames , channels , height , width = hidden_states .shape
393379
394380 # 1. Time embedding
@@ -454,10 +440,6 @@ def forward(
454440 )
455441 output = output .permute (0 , 1 , 5 , 4 , 2 , 6 , 3 , 7 ).flatten (6 , 7 ).flatten (4 , 5 ).flatten (1 , 2 )
456442
457- if USE_PEFT_BACKEND :
458- # remove `lora_scale` from each PEFT layer
459- unscale_lora_layers (self , lora_scale )
460-
461443 if not return_dict :
462444 return (output ,)
463445 return Transformer2DModelOutput (sample = output )
0 commit comments