@@ -40,7 +40,7 @@ def forward(self, ids):
         n_axes = ids.shape[-1]
         emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
         return emb.unsqueeze(1)
-
+


 class FluxJointAttention(torch.nn.Module):
@@ -70,7 +70,7 @@ def apply_rope(self, xq, xk, freqs_cis):
         xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
         return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

-    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
+    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
         batch_size = hidden_states_a.shape[0]

         # Part A
@@ -91,7 +91,7 @@ def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_

         q, k = self.apply_rope(q, k, image_rotary_emb)

-        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
         hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
@@ -103,7 +103,7 @@ def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_
         else:
             hidden_states_b = self.b_to_out(hidden_states_b)
         return hidden_states_a, hidden_states_b
-
+


 class FluxJointTransformerBlock(torch.nn.Module):
@@ -129,12 +129,12 @@ def __init__(self, dim, num_attention_heads):
         )


-    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
         norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
         norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

         # Attention
-        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
+        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)

         # Part A
         hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
@@ -147,7 +147,7 @@ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipad
         hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

         return hidden_states_a, hidden_states_b
-
+


 class FluxSingleAttention(torch.nn.Module):
@@ -184,7 +184,7 @@ def forward(self, hidden_states, image_rotary_emb):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
         return hidden_states
-
+


 class AdaLayerNormSingle(torch.nn.Module):
@@ -200,7 +200,7 @@ def forward(self, x, emb):
         shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa
-
+


 class FluxSingleTransformerBlock(torch.nn.Module):
@@ -225,8 +225,8 @@ def apply_rope(self, xq, xk, freqs_cis):
         xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
         return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

-
-    def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
+
+    def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
         batch_size = hidden_states.shape[0]

         qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
@@ -235,29 +235,29 @@ def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_li

         q, k = self.apply_rope(q, k, image_rotary_emb)

-        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
         if ipadapter_kwargs_list is not None:
             hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
         return hidden_states


-    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
         residual = hidden_states_a
         norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
         hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
         attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]

-        attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
+        attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
         mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")

         hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
         hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
         hidden_states_a = residual + hidden_states_a
-
+
         return hidden_states_a, hidden_states_b
-
+


 class AdaLayerNormContinuous(torch.nn.Module):
@@ -300,7 +300,7 @@ def patchify(self, hidden_states):
     def unpatchify(self, hidden_states, height, width):
         hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
         return hidden_states
-
+

     def prepare_image_ids(self, latents):
         batch_size, _, height, width = latents.shape
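Aside (not part of the patch): `patchify` is the inverse of the `unpatchify` rearrange shown in this hunk, flattening 2x2 latent patches into tokens. A minimal round-trip sketch with toy tensors, assuming only that the forward pattern mirrors the inverse shown above:

import torch
from einops import rearrange

# Toy latent of shape (B, C, H, W); sizes are arbitrary, not the real model's.
latents = torch.randn(1, 16, 8, 8)

# Flatten 2x2 patches into tokens (assumed inverse of the unpatchify pattern).
tokens = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
print(tokens.shape)  # torch.Size([1, 16, 64]): (H//2 * W//2) tokens of dim C*2*2

# Unpatchify exactly as in the diff restores the original layout.
restored = rearrange(tokens, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=8 // 2, W=8 // 2)
assert torch.equal(latents, restored)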
@@ -317,7 +317,7 @@ def prepare_image_ids(self, latents):
         latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

         return latent_image_ids
-
+

     def tiled_forward(
         self,
@@ -338,11 +338,75 @@ def tiled_forward(
         return hidden_states


+    def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
+        N = len(entity_masks)
+        batch_size = entity_masks[0].shape[0]
+        total_seq_len = N * prompt_seq_len + image_seq_len
+        patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
+        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
+
+        image_start = N * prompt_seq_len
+        image_end = N * prompt_seq_len + image_seq_len
+        # prompt-image mask
+        for i in range(N):
+            prompt_start = i * prompt_seq_len
+            prompt_end = (i + 1) * prompt_seq_len
+            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
+            image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
+            # prompt update with image
+            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
+            # image update with prompt
+            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
+        # prompt-prompt mask
+        for i in range(N):
+            for j in range(N):
+                if i != j:
+                    prompt_start_i = i * prompt_seq_len
+                    prompt_end_i = (i + 1) * prompt_seq_len
+                    prompt_start_j = j * prompt_seq_len
+                    prompt_end_j = (j + 1) * prompt_seq_len
+                    attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
+
+        attention_mask = attention_mask.float()
+        attention_mask[attention_mask == 0] = float('-inf')
+        attention_mask[attention_mask == 1] = 0
+        return attention_mask
+
+
+    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
+        repeat_dim = hidden_states.shape[1]
+        max_masks = 0
+        attention_mask = None
+        prompt_embs = [prompt_emb]
+        if entity_masks is not None:
+            # entity_masks
+            batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
+            entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
+            entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
+            # global mask
+            global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
+            entity_masks = entity_masks + [global_mask]  # append global to last
+            # attention mask
+            attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
+            attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
+            attention_mask = attention_mask.unsqueeze(1)
+            # embds: n_masks * b * seq * d
+            local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
+            prompt_embs = local_embs + prompt_embs  # append global to last
+        prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
+        prompt_emb = torch.cat(prompt_embs, dim=1)
+
+        # positional embedding
+        text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
+        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+        return prompt_emb, image_rotary_emb, attention_mask
+
+
     def forward(
         self,
         hidden_states,
         timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
-        tiled=False, tile_size=128, tile_stride=64,
+        tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
         use_gradient_checkpointing=False,
         **kwargs
     ):
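Aside (not part of the patch): `construct_mask` above produces an additive attention bias, 0.0 where attention is allowed and -inf where it is blocked, which is broadcast over heads via `unsqueeze(1)` and handed to `scaled_dot_product_attention` as `attn_mask`. A minimal, self-contained sketch of that pattern; the sizes and the single blocked region are illustrative assumptions, not the model's real token layout:

import torch
import torch.nn.functional as F

# Toy sizes; purely illustrative, not the real FLUX dimensions.
batch, heads, head_dim = 1, 2, 8
prompt_len, image_len = 4, 6
seq = prompt_len + image_len

# Boolean visibility matrix: True = may attend, False = blocked.
allowed = torch.ones(batch, seq, seq, dtype=torch.bool)
allowed[:, :prompt_len, prompt_len:] = False  # e.g. hide image tokens from prompt tokens

# Convert to the additive form expected by scaled_dot_product_attention:
# 0.0 where allowed, -inf where blocked, then broadcast over the head dimension.
bias = torch.zeros(batch, seq, seq)
bias[~allowed] = float("-inf")
bias = bias.unsqueeze(1)  # (B, 1, S, S)

q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([1, 2, 10, 8])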
@@ -353,54 +417,59 @@ def forward(
                 tile_size=tile_size, tile_stride=tile_stride,
                 **kwargs
             )
-
+
         if image_ids is None:
             image_ids = self.prepare_image_ids(hidden_states)
-
+
         conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
         if self.guidance_embedder is not None:
             guidance = guidance * 1000
             conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
-        prompt_emb = self.context_embedder(prompt_emb)
-        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

         height, width = hidden_states.shape[-2:]
         hidden_states = self.patchify(hidden_states)
         hidden_states = self.x_embedder(hidden_states)
-
+
+        if entity_prompt_emb is not None and entity_masks is not None:
+            prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+        else:
+            prompt_emb = self.context_embedder(prompt_emb)
+            image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+            attention_mask = None
+
         def create_custom_forward(module):
             def custom_forward(*inputs):
                 return module(*inputs)
             return custom_forward
-
+
         for block in self.blocks:
             if self.training and use_gradient_checkpointing:
                 hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
-                    hidden_states, prompt_emb, conditioning, image_rotary_emb,
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                     use_reentrant=False,
                 )
             else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)

         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         for block in self.single_blocks:
             if self.training and use_gradient_checkpointing:
                 hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
-                    hidden_states, prompt_emb, conditioning, image_rotary_emb,
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                     use_reentrant=False,
                 )
             else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
         hidden_states = hidden_states[:, prompt_emb.shape[1]:]

         hidden_states = self.final_norm_out(hidden_states, conditioning)
         hidden_states = self.final_proj_out(hidden_states)
         hidden_states = self.unpatchify(hidden_states, height, width)

         return hidden_states
-
+


     def quantize(self):
@@ -440,24 +509,24 @@ class quantized_layer:
             class Linear(torch.nn.Linear):
                 def __init__(self, *args, **kwargs):
                     super().__init__(*args, **kwargs)
-
+
                 def forward(self, input, **kwargs):
                     weight, bias = cast_bias_weight(self, input)
                     return torch.nn.functional.linear(input, weight, bias)
-
+
             class RMSNorm(torch.nn.Module):
                 def __init__(self, module):
                     super().__init__()
                     self.module = module
-
+
                 def forward(self, hidden_states, **kwargs):
                     weight = cast_weight(self.module, hidden_states)
                     input_dtype = hidden_states.dtype
                     variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                     hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                     hidden_states = hidden_states.to(input_dtype) * weight
                     return hidden_states
-
+
         def replace_layer(model):
             for name, module in model.named_children():
                 if isinstance(module, torch.nn.Linear):
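Aside (not part of the patch): `quantized_layer.Linear` above relies on `cast_bias_weight`, which is defined outside this hunk; the gist is that weights stay in their storage dtype and are cast to the input's dtype/device only at call time. A standalone sketch of that cast-on-forward idea (the `CastLinear` class below is an illustrative assumption, shown with float16 storage for simplicity):

import torch

class CastLinear(torch.nn.Linear):
    """Store weights in a compact dtype; cast to the input's dtype/device on each call."""
    def forward(self, input):
        weight = self.weight.to(device=input.device, dtype=input.dtype)
        bias = None if self.bias is None else self.bias.to(device=input.device, dtype=input.dtype)
        return torch.nn.functional.linear(input, weight, bias)

layer = CastLinear(4, 4)
layer.half()           # storage dtype: float16
x = torch.randn(2, 4)  # compute dtype: float32
print(layer(x).dtype)  # torch.float32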
@@ -483,7 +552,6 @@ def replace_layer(model):
     @staticmethod
     def state_dict_converter():
         return FluxDiTStateDictConverter()
-


 class FluxDiTStateDictConverter:
@@ -587,7 +655,7 @@ def from_diffusers(self, state_dict):
                 state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                 state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
         return state_dict_
-
+
     def from_civitai(self, state_dict):
         rename_dict = {
             "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",