2424from maxdiffusion .models .embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings , NNXPixArtAlphaTextProjection
2525from maxdiffusion .models .gradient_checkpoint import GradientCheckpointType
2626from maxdiffusion .configuration_utils import ConfigMixin , register_to_config
27+ from maxdiffusion .common_types import BlockSizes
2728
2829
2930class LTX2AdaLayerNormSingle (nnx .Module ):
@@ -105,6 +106,7 @@ def __init__(
105106 names_which_can_be_saved : list = [],
106107 names_which_can_be_offloaded : list = [],
107108 attention_kernel : str = "flash" ,
109+ flash_block_sizes : BlockSizes = None ,
108110 ):
109111 self .dim = dim
110112 self .norm_eps = norm_eps
@@ -134,6 +136,7 @@ def __init__(
134136 mesh = mesh ,
135137 attention_kernel = self .attention_kernel ,
136138 rope_type = rope_type ,
139+ flash_block_sizes = flash_block_sizes ,
137140 )
138141
139142 self .audio_norm1 = nnx .RMSNorm (
@@ -158,6 +161,7 @@ def __init__(
158161 mesh = mesh ,
159162 attention_kernel = self .attention_kernel ,
160163 rope_type = rope_type ,
164+ flash_block_sizes = flash_block_sizes ,
161165 )
162166
163167 # 2. Prompt Cross-Attention
@@ -184,6 +188,7 @@ def __init__(
184188 mesh = mesh ,
185189 attention_kernel = self .attention_kernel ,
186190 rope_type = rope_type ,
191+ flash_block_sizes = flash_block_sizes ,
187192 )
188193
189194 self .audio_norm2 = nnx .RMSNorm (
@@ -209,6 +214,7 @@ def __init__(
209214 mesh = mesh ,
210215 attention_kernel = self .attention_kernel ,
211216 rope_type = rope_type ,
217+ flash_block_sizes = flash_block_sizes ,
212218 )
213219
214220 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -235,6 +241,7 @@ def __init__(
235241 mesh = mesh ,
236242 attention_kernel = self .attention_kernel ,
237243 rope_type = rope_type ,
244+ flash_block_sizes = flash_block_sizes ,
238245 )
239246
240247 self .video_to_audio_norm = nnx .RMSNorm (
@@ -260,6 +267,7 @@ def __init__(
260267 mesh = mesh ,
261268 attention_kernel = self .attention_kernel ,
262269 rope_type = rope_type ,
270+ flash_block_sizes = flash_block_sizes ,
263271 )
264272
265273 # 4. Feed Forward
@@ -553,6 +561,7 @@ def __init__(
553561 scan_layers : bool = True ,
554562 attention_kernel : str = "flash" ,
555563 qk_norm : str = "rms_norm_across_heads" ,
564+ flash_block_sizes : BlockSizes = None ,
556565 ** kwargs ,
557566 ):
558567 self .in_channels = in_channels
@@ -791,6 +800,7 @@ def init_block(rngs):
791800 names_which_can_be_saved = self .names_which_can_be_saved ,
792801 names_which_can_be_offloaded = self .names_which_can_be_offloaded ,
793802 attention_kernel = self .attention_kernel ,
803+ flash_block_sizes = flash_block_sizes ,
794804 )
795805
796806 if self .scan_layers :
@@ -822,6 +832,7 @@ def init_block(rngs):
822832 names_which_can_be_saved = self .names_which_can_be_saved ,
823833 names_which_can_be_offloaded = self .names_which_can_be_offloaded ,
824834 attention_kernel = self .attention_kernel ,
835+ flash_block_sizes = flash_block_sizes ,
825836 )
826837 blocks .append (block )
827838 self .transformer_blocks = nnx .List (blocks )
0 commit comments