@@ -54,7 +54,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
5454 encoder_hidden_states = hidden_states
5555
5656 if attn .fused_projections :
57- if attn .cross_attention_dim_head is None :
57+ if not attn .is_cross_attention :
5858 # In self-attention layers, we can fuse the entire QKV projection into a single linear
5959 query , key , value = attn .to_qkv (hidden_states ).chunk (3 , dim = - 1 )
6060 else :
@@ -502,24 +502,27 @@ def __init__(
502502 dim_head : int = 64 ,
503503 eps : float = 1e-6 ,
504504 cross_attention_dim_head : Optional [int ] = None ,
505+ bias : bool = True ,
505506 processor = None ,
506507 ):
507508 super ().__init__ ()
508509 self .inner_dim = dim_head * heads
509510 self .heads = heads
510- self .cross_attention_head_dim = cross_attention_dim_head
511+ self .cross_attention_dim_head = cross_attention_dim_head
511512 self .kv_inner_dim = self .inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
513+ self .use_bias = bias
514+ self .is_cross_attention = cross_attention_dim_head is not None
512515
513516 # 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
514517 # NOTE: this is not used in "vanilla" WanAttention
515518 self .pre_norm_q = nn .LayerNorm (dim , eps , elementwise_affine = False )
516519 self .pre_norm_kv = nn .LayerNorm (dim , eps , elementwise_affine = False )
517520
518521 # 2. QKV and Output Projections
519- self .to_q = torch .nn .Linear (dim , self .inner_dim , bias = True )
520- self .to_k = torch .nn .Linear (dim , self .kv_inner_dim , bias = True )
521- self .to_v = torch .nn .Linear (dim , self .kv_inner_dim , bias = True )
522- self .to_out = torch .nn .Linear (self .inner_dim , dim , bias = True )
522+ self .to_q = torch .nn .Linear (dim , self .inner_dim , bias = bias )
523+ self .to_k = torch .nn .Linear (dim , self .kv_inner_dim , bias = bias )
524+ self .to_v = torch .nn .Linear (dim , self .kv_inner_dim , bias = bias )
525+ self .to_out = torch .nn .Linear (self .inner_dim , dim , bias = bias )
523526
524527 # 3. QK Norm
525528 # NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
@@ -682,15 +685,18 @@ def __init__(
682685 self .add_v_proj = torch .nn .Linear (added_kv_proj_dim , self .inner_dim , bias = True )
683686 self .norm_added_k = torch .nn .RMSNorm (dim_head * heads , eps = eps )
684687
685- self .is_cross_attention = cross_attention_dim_head is not None
688+ if is_cross_attention is not None :
689+ self .is_cross_attention = is_cross_attention
690+ else :
691+ self .is_cross_attention = cross_attention_dim_head is not None
686692
687693 self .set_processor (processor )
688694
689695 def fuse_projections (self ):
690696 if getattr (self , "fused_projections" , False ):
691697 return
692698
693- if self .cross_attention_dim_head is None :
699+ if not self .is_cross_attention :
694700 concatenated_weights = torch .cat ([self .to_q .weight .data , self .to_k .weight .data , self .to_v .weight .data ])
695701 concatenated_bias = torch .cat ([self .to_q .bias .data , self .to_k .bias .data , self .to_v .bias .data ])
696702 out_features , in_features = concatenated_weights .shape
0 commit comments