2222import jax
2323import jax .numpy as jnp
2424from jax .sharding import Mesh , NamedSharding , PartitionSpec as P
25+ from torchax import default_env
26+ from maxdiffusion .models .ltx2 .text_encoders .torchax_text_encoder import TorchaxGemma3TextEncoder
27+ from maxdiffusion .tpu_utils import get_tpu_type , TpuType
28+ from maxdiffusion .maxdiffusion_utils import get_dummy_ltx2_inputs
29+ import contextlib
2530import flax
2631import flax .linen as nn
2732import flax .traverse_util
2833from flax import nnx
2934from flax .linen import partitioning as nn_partitioning
3035from transformers import AutoTokenizer , GemmaTokenizer , GemmaTokenizerFast , Gemma3ForConditionalGeneration
31- from maxdiffusion .tpu_utils import get_tpu_type , TpuType
3236import qwix
3337from ...utils import logging
34- from ...schedulers import FlaxFlowMatchScheduler
38+ from ...schedulers import FlaxFlowMatchScheduler # pylint: disable=no-name-in-module
3539from ...models .ltx2 .autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
3640from ...models .ltx2 .autoencoder_kl_ltx2_audio import FlaxAutoencoderKLLTX2Audio
3741from ...models .ltx2 .vocoder_ltx2 import LTX2Vocoder
5357from ... import max_logging
5458from ... import max_utils
5559from ...max_utils import get_precision , device_put_replicated , get_flash_block_sizes
56- from maxdiffusion .maxdiffusion_utils import get_dummy_ltx2_inputs
5760
5861
5962@flax .struct .dataclass
@@ -65,7 +68,8 @@ class LTX2PipelineOutput:
6568def rescale_noise_cfg (noise_cfg , noise_pred_text , guidance_rescale = 0.0 ):
6669 """
6770 Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure.
68- Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).
71+ Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed]
72+ (https://huggingface.co/papers/2305.08891).
6973 """
7074 std_text = jnp .std (noise_pred_text , axis = list (range (1 , noise_pred_text .ndim )), keepdims = True )
7175 std_cfg = jnp .std (noise_cfg , axis = list (range (1 , noise_cfg .ndim )), keepdims = True )
@@ -110,6 +114,9 @@ def create_sharded_logical_transformer(
110114 restored_checkpoint = None ,
111115 subfolder : str = "" ,
112116):
117+ """Creates a sharded logical transformer."""
118+
119+ # pylint: disable=too-many-positional-arguments,unused-argument
113120 def create_model (rngs : nnx .Rngs , ltx2_config : dict ):
114121 transformer = LTX2VideoTransformer3DModel (** ltx2_config , rngs = rngs )
115122 return transformer
@@ -186,6 +193,8 @@ def calculate_shift(
186193 base_shift : float = 0.5 ,
187194 max_shift : float = 1.15 ,
188195):
196+ """Calculates the shift for the timestep schedule."""
197+ # pylint: disable=too-many-positional-arguments
189198 m = (max_shift - base_shift ) / (max_seq_len - base_seq_len )
190199 b = base_shift - m * base_seq_len
191200 mu = image_seq_len * m + b
@@ -200,6 +209,8 @@ def retrieve_timesteps(
200209 sigmas : Optional [List [float ]] = None ,
201210 ** kwargs ,
202211):
212+ """Retrieves timesteps for the scheduler."""
213+ # pylint: disable=too-many-positional-arguments
203214 if timesteps is not None and sigmas is not None :
204215 raise ValueError ("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" )
205216
@@ -222,6 +233,8 @@ class LTX2Pipeline:
222233 Pipeline for LTX-2.
223234 """
224235
236+ # pylint: disable=missing-function-docstring,too-many-positional-arguments,unused-argument
237+
225238 def __init__ (
226239 self ,
227240 scheduler : FlaxFlowMatchScheduler ,
@@ -245,6 +258,8 @@ def __init__(
245258 self .transformer = transformer
246259 self .latent_upsampler = latent_upsampler
247260 self .latent_upsampler_params = latent_upsampler_params
261+ self .mesh = None
262+ self .config = None
248263
249264 # VAE compression ratios
250265 self .vae_spatial_compression_ratio = getattr (self .vae , "spatial_compression_ratio" , 32 )
@@ -316,6 +331,11 @@ def load_text_encoder(cls, config: HyperParameters):
316331 torch_dtype = torch .bfloat16 ,
317332 )
318333 text_encoder .eval ()
334+
335+ with default_env ():
336+ text_encoder = text_encoder .to ("jax" )
337+ text_encoder = TorchaxGemma3TextEncoder (text_encoder )
338+
319339 return text_encoder
320340
321341 @classmethod
@@ -396,7 +416,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
396416 sharding = sharding .get_value ()
397417 try :
398418 replicate_vae = config .replicate_vae
399- except ValueError :
419+ except Exception : # pylint: disable=broad-exception-caught
400420 replicate_vae = False
401421 if replicate_vae :
402422 sharding = NamedSharding (mesh , P ())
@@ -444,7 +464,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
444464 sharding = sharding .get_value ()
445465 try :
446466 replicate_vae = config .replicate_vae
447- except ValueError :
467+ except Exception : # pylint: disable=broad-exception-caught
448468 replicate_vae = False
449469 if replicate_vae :
450470 sharding = NamedSharding (mesh , P ())
@@ -750,39 +770,48 @@ def _get_gemma_prompt_embeds(
750770 prompt = [p .strip () for p in prompt ]
751771
752772 if self .text_encoder is not None :
753- # PyTorch Text Encoder
773+ # Torchax Text Encoder
754774 text_inputs = self .tokenizer (
755775 prompt ,
756776 padding = "max_length" ,
757777 max_length = max_sequence_length ,
758778 truncation = True ,
759779 add_special_tokens = True ,
760- return_tensors = "pt" ,
780+ return_tensors = "np" ,
781+ )
782+ text_input_ids = jnp .array (text_inputs .input_ids )
783+ prompt_attention_mask = jnp .array (text_inputs .attention_mask )
784+
785+ # Distribute the batch dimension across available TPUs to prevent Softmax OOM
786+ # (reduces 512MB allocation down to 64MB per TPU for batch size 16)
787+ devices = np .array (jax .devices ())
788+ num_shards = 1
789+ for i in range (len (devices ), 0 , - 1 ):
790+ if text_input_ids .shape [0 ] % i == 0 :
791+ num_shards = i
792+ break
793+
794+ if num_shards > 1 :
795+ mesh = Mesh (devices [:num_shards ], axis_names = ("batch" ,))
796+ sharding = NamedSharding (mesh , P ("batch" ))
797+ text_input_ids = jax .device_put (text_input_ids , sharding )
798+ prompt_attention_mask = jax .device_put (prompt_attention_mask , sharding )
799+
800+ # Torchax wrapper returns tuple of hidden states natively
801+ text_encoder_hidden_states = self .text_encoder (
802+ input_ids = text_input_ids , attention_mask = prompt_attention_mask , output_hidden_states = True
761803 )
762- text_input_ids = text_inputs .input_ids
763- prompt_attention_mask = text_inputs .attention_mask
764-
765- text_input_ids = text_input_ids .to (self .text_encoder .device )
766- prompt_attention_mask = prompt_attention_mask .to (self .text_encoder .device )
767-
768- with torch .no_grad ():
769- text_encoder_outputs = self .text_encoder (
770- input_ids = text_input_ids , attention_mask = prompt_attention_mask , output_hidden_states = True
771- )
772-
773- text_encoder_hidden_states = text_encoder_outputs .hidden_states
774- del text_encoder_outputs # Free memory
775804
776805 prompt_embeds_list = []
777806 # Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT
778807 for state in text_encoder_hidden_states :
779- state_np = state . cpu (). to ( torch . float32 ). numpy ( )
780- prompt_embeds_list .append (jnp . array ( state_np , dtype = jnp .bfloat16 ))
808+ state = jax . device_put ( state , jax . devices ()[ 0 ] )
809+ prompt_embeds_list .append (state . astype ( jnp .bfloat16 ))
781810
782811 prompt_embeds = prompt_embeds_list
783- del text_encoder_hidden_states # Free PyTorch tensor memory
812+ del text_encoder_hidden_states # Free memory
784813
785- prompt_attention_mask = jnp . array ( prompt_attention_mask .cpu (). to ( torch . float32 ). numpy (), dtype = jnp .bool_ )
814+ prompt_attention_mask = prompt_attention_mask .astype ( jnp .bool_ )
786815 else :
787816 raise ValueError ("`text_encoder` is required to encode prompts." )
788817
@@ -939,7 +968,7 @@ def check_inputs(
939968
940969 @staticmethod
941970 def _pack_latents (latents : jax .Array , patch_size : int = 1 , patch_size_t : int = 1 ) -> jax .Array :
942- batch_size , num_channels , num_frames , height , width = latents .shape
971+ batch_size , _ , num_frames , height , width = latents .shape
943972 post_patch_num_frames = num_frames // patch_size_t
944973 post_patch_height = height // patch_size
945974 post_patch_width = width // patch_size
@@ -1028,7 +1057,7 @@ def _pack_audio_latents(
10281057 latents : jax .Array , patch_size : Optional [int ] = None , patch_size_t : Optional [int ] = None
10291058 ) -> jax .Array :
10301059 if patch_size is not None and patch_size_t is not None :
1031- batch_size , num_channels , latent_length , latent_mel_bins = latents .shape
1060+ batch_size , _ , latent_length , latent_mel_bins = latents .shape
10321061 post_patch_latent_length = latent_length // patch_size_t
10331062 post_patch_mel_bins = latent_mel_bins // patch_size
10341063 latents = latents .reshape (batch_size , - 1 , post_patch_latent_length , patch_size_t , post_patch_mel_bins , patch_size )
@@ -1327,7 +1356,6 @@ def __call__(
13271356 graphdef , state = nnx .split (self .transformer )
13281357
13291358 # 7. Denoising Loop
1330- import contextlib
13311359
13321360 context_manager = self .mesh if hasattr (self , "mesh" ) and self .mesh is not None else contextlib .nullcontext ()
13331361 axis_rules_context = (
@@ -1380,9 +1408,7 @@ def __call__(
13801408 )
13811409 else :
13821410 # Old Python loop path
1383- for i in range (len (timesteps_jax )):
1384- t = timesteps_jax [i ]
1385-
1411+ for _ , t in enumerate (timesteps_jax ):
13861412 # Isolate input sharding to scan_layers=False to avoid affecting the standard path
13871413 latents_jax_sharded = latents_jax
13881414 audio_latents_jax_sharded = audio_latents_jax
@@ -1503,20 +1529,45 @@ def __call__(
15031529 if output_type == "latent" :
15041530 return LTX2PipelineOutput (frames = latents , audio = audio_latents )
15051531
1506- # Force latents and VAE weights to be fully replicated using with_sharding_constraint, this speeds up single video latency ~3x
1507- try :
1508- mesh = latents .sharding .mesh
1509- replicated_sharding = NamedSharding (mesh , P ())
1510- latents = jax .lax .with_sharding_constraint (latents , replicated_sharding )
1511-
1512- # Replicate VAE weights
1513- graphdef , state = nnx .split (self .vae )
1514- state = jax .tree_util .tree_map (
1515- lambda x : jax .lax .with_sharding_constraint (x , replicated_sharding ) if isinstance (x , jax .Array ) else x , state
1532+ # Force latents and VAE weights to be fully replicated using with_sharding_constraint,
1533+ # this speeds up single video latency ~3x
1534+ if batch_size <= 2 :
1535+ try :
1536+ mesh = latents .sharding .mesh
1537+ replicated_sharding = NamedSharding (mesh , P ())
1538+ latents = jax .lax .with_sharding_constraint (latents , replicated_sharding )
1539+
1540+ # Replicate VAE weights
1541+ graphdef , state = nnx .split (self .vae )
1542+ state = jax .tree_util .tree_map (
1543+ lambda x : jax .lax .with_sharding_constraint (x , replicated_sharding ) if isinstance (x , jax .Array ) else x , state
1544+ )
1545+ self .vae = nnx .merge (graphdef , state )
1546+ except Exception : # pylint: disable=broad-exception-caught
1547+ max_logging .log ("[Tuning] Failed to apply sharding constraint" )
1548+ else :
1549+ max_logging .log (
1550+ f"[Tuning] Skipping VAE replication and disabling slicing to prevent HBM OOM for batch_size { batch_size } > 2"
15161551 )
1517- self .vae = nnx .merge (graphdef , state )
1518- except Exception as e :
1519- max_logging .log (f"[Tuning] Failed to apply sharding constraint: { e } " )
1552+ try :
1553+ # Disable sequential slicing to avoid JAX concatenating 17GB arrays on the TPU
1554+ self .vae .use_slicing = False
1555+
1556+ # Distribute the batch dimension across the existing mesh to ensure topological compatibility
1557+ mesh = latents .sharding .mesh
1558+ active_axes = []
1559+ current_shards = 1
1560+
1561+ for axis_name , size in mesh .shape .items ():
1562+ if size > 1 and batch_size % (current_shards * size ) == 0 :
1563+ active_axes .append (axis_name )
1564+ current_shards *= size
1565+
1566+ if active_axes :
1567+ batch_sharding = jax .sharding .NamedSharding (mesh , jax .sharding .PartitionSpec (tuple (active_axes )))
1568+ latents = jax .lax .with_sharding_constraint (latents , batch_sharding )
1569+ except Exception : # pylint: disable=broad-exception-caught
1570+ max_logging .log ("[Tuning] Failed to apply batch sharding constraint to VAE" )
15201571
15211572 if getattr (self .vae .config , "timestep_conditioning" , False ):
15221573 noise = jax .random .normal (generator , latents .shape , dtype = latents .dtype )
@@ -1587,6 +1638,8 @@ def transformer_forward_pass(
15871638 audio_num_frames ,
15881639 fps ,
15891640):
1641+ """Forward pass for the transformer."""
1642+ # pylint: disable=too-many-positional-arguments,unused-argument
15901643 transformer = nnx .merge (graphdef , state )
15911644
15921645 # Expand timestep to batch size
@@ -1647,6 +1700,8 @@ def run_diffusion_loop(
16471700 scheduler_step ,
16481701 logical_axis_rules ,
16491702):
1703+ """Runs the diffusion loop."""
1704+ # pylint: disable=too-many-positional-arguments
16501705 latents_jax = latents_jax .astype (jnp .float32 )
16511706 audio_latents_jax = audio_latents_jax .astype (jnp .float32 )
16521707 transformer = nnx .merge (graphdef , state )
0 commit comments