Skip to content

Commit d822acb

Browse files
committed
Merge branch 'main' of https://github.com/AI-Hypercomputer/maxdiffusion into fixbiassharding
2 parents 039ceb3 + ceca471 commit d822acb

15 files changed

Lines changed: 1262 additions & 201 deletions

src/maxdiffusion/checkpointing/ltx2_checkpointer.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -79,19 +79,19 @@ def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[di
7979
return restored_checkpoint, step
8080

8181
def load_checkpoint(
82-
self, step=None, vae_only=False, load_transformer=True
82+
self, step=None, vae_only=False, load_transformer=True, load_upsampler=False
8383
) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
8484
restored_checkpoint, step = self.load_ltx2_configs_from_orbax(step)
8585
opt_state = None
8686

8787
if restored_checkpoint:
8888
max_logging.log("Loading LTX2 pipeline from checkpoint")
89-
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer)
89+
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer, load_upsampler)
9090
if "opt_state" in restored_checkpoint.ltx2_state.keys():
9191
opt_state = restored_checkpoint.ltx2_state["opt_state"]
9292
else:
9393
max_logging.log("No checkpoint found, loading pipeline from pretrained hub")
94-
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer)
94+
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer, load_upsampler)
9595

9696
return pipeline, opt_state, step
9797

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
hardware: 'tpu'
33
skip_jax_distributed_system: False
44
attention: 'flash'
5+
a2v_attention_kernel: 'flash'
6+
v2a_attention_kernel: 'dot_product'
57
attention_sharding_uniform: True
68
precision: 'bf16'
79
scan_layers: True
@@ -68,6 +70,7 @@ flash_block_sizes: {
6870
block_kv_dkv_compute: 2048,
6971
use_fused_bwd_kernel: True,
7072
}
73+
flash_min_seq_length: 4096
7174
dcn_context_parallelism: 1
7275
dcn_tensor_parallelism: 1
7376
ici_data_parallelism: 1
@@ -102,3 +105,13 @@ jit_initializers: True
102105
enable_single_replica_ckpt_restoring: False
103106
seed: 0
104107
audio_format: "s16"
108+
109+
# LTX-2 Latent Upsampler
110+
run_latent_upsampler: False
111+
upsampler_model_path: "Lightricks/LTX-2"
112+
upsampler_spatial_patch_size: 1
113+
upsampler_temporal_patch_size: 1
114+
upsampler_adain_factor: 0.0
115+
upsampler_tone_map_compression_ratio: 0.0
116+
upsampler_rational_spatial_scale: 2.0
117+
upsampler_output_type: "pil"

src/maxdiffusion/generate_ltx2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def get_git_commit_hash():
8181

8282

8383
def call_pipeline(config, pipeline, prompt, negative_prompt):
84-
# Set default generation arguments
8584
generator = jax.random.key(config.seed) if hasattr(config, "seed") else jax.random.key(0)
8685
guidance_scale = config.guidance_scale if hasattr(config, "guidance_scale") else 3.0
8786

@@ -99,6 +98,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
9998
decode_noise_scale=getattr(config, "decode_noise_scale", None),
10099
max_sequence_length=getattr(config, "max_sequence_length", 1024),
101100
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
101+
output_type=getattr(config, "upsampler_output_type", "pil"),
102102
)
103103
return out
104104

@@ -114,9 +114,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
114114
else:
115115
max_logging.log("Could not retrieve Git commit hash.")
116116

117+
checkpoint_loader = LTX2Checkpointer(config=config)
117118
if pipeline is None:
118-
checkpoint_loader = LTX2Checkpointer(config=config)
119-
pipeline, _, _ = checkpoint_loader.load_checkpoint()
119+
# Use the config flag to determine if the upsampler should be loaded
120+
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
121+
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
120122

121123
pipeline.enable_vae_slicing()
122124
pipeline.enable_vae_tiling()
@@ -135,6 +137,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
135137
)
136138

137139
out = call_pipeline(config, pipeline, prompt, negative_prompt)
140+
138141
# out should have .frames and .audio
139142
videos = out.frames if hasattr(out, "frames") else out[0]
140143
audios = out.audio if hasattr(out, "audio") else None
@@ -143,6 +146,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
143146
max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
144147
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
145148
max_logging.log(f"model type: {getattr(config, 'model_type', 'T2V')}")
149+
if getattr(config, "run_latent_upsampler", False):
150+
max_logging.log(f"upsampler model path: {config.upsampler_model_path}")
146151
max_logging.log(f"hardware: {jax.devices()[0].platform}")
147152
max_logging.log(f"number of devices: {jax.device_count()}")
148153
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Optional, Tuple
1818
from flax import nnx
19+
import jax
1920
import jax.numpy as jnp
2021
from ... import common_types
2122
from ..attention_flax import NNXAttentionOp
@@ -347,6 +348,7 @@ def __init__(
347348
attention_kernel: str = "flash",
348349
rope_type: str = "interleaved",
349350
flash_block_sizes: BlockSizes = None,
351+
flash_min_seq_length: int = 4096,
350352
):
351353
self.heads = heads
352354
self.rope_type = rope_type
@@ -434,6 +436,7 @@ def __init__(
434436
axis_names_q=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV),
435437
axis_names_kv=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV),
436438
flash_block_sizes=flash_block_sizes,
439+
flash_min_seq_length=flash_min_seq_length,
437440
)
438441

439442
def __call__(
@@ -447,46 +450,49 @@ def __call__(
447450
# Determine context (Self or Cross)
448451
context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
449452

450-
# 1. Project
451-
query = self.to_q(hidden_states)
452-
key = self.to_k(context)
453-
value = self.to_v(context)
453+
# 1. Project and Norm
454+
with jax.named_scope("QKV Projection"):
455+
query = self.to_q(hidden_states)
456+
key = self.to_k(context)
457+
value = self.to_v(context)
454458

455-
# 2. Norm (Full Inner Dimension)
456-
query = self.norm_q(query)
457-
key = self.norm_k(key)
459+
with jax.named_scope("QKV Norm"):
460+
query = self.norm_q(query)
461+
key = self.norm_k(key)
458462

459463
# 3. Apply RoPE to tensors of shape [B, S, InnerDim]
460464
# Frequencies are shape [B, S, InnerDim]
461465
# 3. Apply RoPE
462-
if rotary_emb is not None:
463-
if hasattr(self, "rope_type") and self.rope_type == "split":
464-
# Split RoPE: passing full freqs [B, H, S, D//2]
465-
# apply_split_rotary_emb handles reshaping query/key
466-
467-
query = apply_split_rotary_emb(query, rotary_emb)
468-
469-
if k_rotary_emb is not None:
470-
key = apply_split_rotary_emb(key, k_rotary_emb)
471-
elif encoder_hidden_states is None:
472-
key = apply_split_rotary_emb(key, rotary_emb)
473-
474-
else:
475-
# Interleaved (Default)
476-
query = apply_rotary_emb(query, rotary_emb)
477-
if k_rotary_emb is not None:
478-
key = apply_rotary_emb(key, k_rotary_emb)
479-
elif encoder_hidden_states is None:
480-
key = apply_rotary_emb(key, rotary_emb)
481-
482-
# 4. Attention
483-
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
484-
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
485-
486-
# 7. Output Projection
487-
hidden_states = self.to_out(attn_output)
488-
489-
if self.dropout_layer is not None:
490-
hidden_states = self.dropout_layer(hidden_states)
466+
with jax.named_scope("Apply RoPE"):
467+
if rotary_emb is not None:
468+
if hasattr(self, "rope_type") and self.rope_type == "split":
469+
# Split RoPE: passing full freqs [B, H, S, D//2]
470+
# apply_split_rotary_emb handles reshaping query/key
471+
472+
query = apply_split_rotary_emb(query, rotary_emb)
473+
474+
if k_rotary_emb is not None:
475+
key = apply_split_rotary_emb(key, k_rotary_emb)
476+
elif encoder_hidden_states is None:
477+
key = apply_split_rotary_emb(key, rotary_emb)
478+
479+
else:
480+
# Interleaved (Default)
481+
query = apply_rotary_emb(query, rotary_emb)
482+
if k_rotary_emb is not None:
483+
key = apply_rotary_emb(key, k_rotary_emb)
484+
elif encoder_hidden_states is None:
485+
key = apply_rotary_emb(key, rotary_emb)
486+
487+
with jax.named_scope("Attention and Output Project"):
488+
# 4. Attention
489+
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
490+
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
491+
492+
# 7. Output Projection
493+
hidden_states = self.to_out(attn_output)
494+
495+
if self.dropout_layer is not None:
496+
hidden_states = self.dropout_layer(hidden_states)
491497

492498
return hidden_states

0 commit comments

Comments
 (0)