Skip to content

Commit cfff435

Browse files
Merge pull request #414 from AI-Hypercomputer:ltx2_sharding
PiperOrigin-RevId: 923847845
2 parents eefba00 + 2e0b568 commit cfff435

17 files changed

Lines changed: 665 additions & 143 deletions

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@ remat_policy: "NONE"
1414
jax_cache_dir: ''
1515
weights_dtype: 'bfloat16'
1616
activations_dtype: 'bfloat16'
17+
text_encoder_dtype: 'bfloat16'
18+
compile_text_encoder: False
19+
use_batched_text_encoder: False
1720

1821
run_name: 'ltx2_inference'
1922
output_dir: ''
23+
base_output_directory: ''
2024
config_path: ''
2125
save_config_to_gcs: False
2226

@@ -69,6 +73,11 @@ logical_axis_rules: [
6973
]
7074
data_sharding: ['data', 'fsdp', 'context', 'tensor']
7175

76+
sharding:
77+
transformer: 'default'
78+
vae: 'default'
79+
text_connector: 'default'
80+
7281
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
7382
dcn_fsdp_parallelism: -1
7483

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,13 @@ remat_policy: "NONE"
2020
jax_cache_dir: ''
2121
weights_dtype: 'bfloat16'
2222
activations_dtype: 'bfloat16'
23+
text_encoder_dtype: 'bfloat16'
24+
compile_text_encoder: False
25+
use_batched_text_encoder: False
2326

2427
run_name: 'ltx2_inference'
2528
output_dir: ''
29+
base_output_directory: ''
2630
config_path: ''
2731
save_config_to_gcs: False
2832

@@ -74,6 +78,11 @@ logical_axis_rules: [
7478
]
7579
data_sharding: ['data', 'fsdp', 'context', 'tensor']
7680

81+
sharding:
82+
transformer: 'default'
83+
vae: 'default'
84+
text_connector: 'default'
85+
7786
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
7887
dcn_fsdp_parallelism: -1
7988

src/maxdiffusion/max_utils.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,14 @@
2020
import functools
2121
from functools import partial, reduce
2222
from contextlib import nullcontext
23-
from typing import Dict, Callable
23+
from typing import (
24+
Any,
25+
Callable,
26+
Dict,
27+
Set,
28+
Tuple,
29+
Union,
30+
)
2431
import json
2532
import yaml
2633
import os
@@ -36,7 +43,6 @@
3643
import optax
3744
from maxdiffusion import max_logging
3845
from maxdiffusion.checkpointing import checkpointing_utils
39-
from maxdiffusion.models.attention_flax import AttentionOp
4046
import flax.linen as nn
4147
import flax.linen.module as module_lib
4248
from flax.linen.summary import _process_inputs
@@ -50,13 +56,6 @@
5056

5157
from transformers import FlaxCLIPTextModel, FlaxCLIPTextPreTrainedModel
5258
from flax import struct
53-
from typing import (
54-
Callable,
55-
Any,
56-
Tuple,
57-
Union,
58-
Set,
59-
)
6059
from flax import core
6160
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
6261

@@ -676,6 +675,8 @@ def get_live_arrays():
676675
# to retrieve layer parameters and calculate
677676
def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSequences], train, **kwargs):
678677
"""Calculates model tflops by passing a module."""
678+
from maxdiffusion.models.attention_flax import AttentionOp
679+
679680
with module_lib._tabulate_context():
680681
_ = jax.eval_shape(module.init, rngs, **kwargs)
681682
calls = module_lib._context.call_info_stack[-1].calls
@@ -769,3 +770,8 @@ def maybe_initialize_jax_distributed_system(raw_keys):
769770
max_logging.log("Jax distributed system initialized on GPU!")
770771
else:
771772
jax.distributed.initialize()
773+
774+
775+
def safe_getattr(obj: Any, name: str, default: Any) -> Any:
776+
"""Reads attribute `name` from `obj`, returning `default` if `obj` is None or the attribute is missing; an attribute that exists with value None is returned as None."""
777+
return getattr(obj, name, default) if obj is not None else default

src/maxdiffusion/models/attention_flax.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import contextlib
1616
import functools
1717
import math
18-
from typing import Optional, Callable, Tuple, Dict
18+
from typing import Optional, Callable, Tuple, Any, Dict
1919
import flax.linen as nn
2020
from flax import nnx
2121
import jax
@@ -31,6 +31,7 @@
3131
from einops import rearrange
3232
from .. import common_types, max_logging
3333
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
34+
from maxdiffusion.max_utils import safe_getattr
3435

3536

3637
from ..kernels import custom_splash_attention as custom_splash
@@ -1133,24 +1134,15 @@ def __init__(
11331134
dtype: jnp.dtype = jnp.float32,
11341135
weights_dtype: jnp.dtype = jnp.float32,
11351136
precision: Optional[jax.lax.Precision] = None,
1137+
sharding_specs: Optional[Any] = None,
11361138
):
11371139
inner_dim = int(dim * mult)
11381140
dim_out = dim_out if dim_out is not None else dim
11391141

1140-
tpu_type = get_tpu_type()
1141-
is_ironwood = tpu_type == TpuType.TPU_7X
1142-
1143-
# Hardware-aware sharding specs: Ironwood (v7x) keeps the embedding dimension (embed)
1144-
# replicated (None) to minimize cross-device communication, while other hardware (default)
1145-
# shards it to prevent OOM issues.
1146-
if is_ironwood:
1147-
net0_kernel_spec = (None, "mlp")
1148-
net2_kernel_spec = ("mlp", None)
1149-
net2_bias_spec = (None,)
1150-
else:
1151-
net0_kernel_spec = ("embed", "mlp")
1152-
net2_kernel_spec = ("mlp", "embed")
1153-
net2_bias_spec = ("embed",)
1142+
net_0_kernel = safe_getattr(sharding_specs, "net_0_kernel", ("embed", "mlp"))
1143+
net_0_bias = safe_getattr(sharding_specs, "net_0_bias", ("mlp",))
1144+
net_2_kernel = safe_getattr(sharding_specs, "net_2_kernel", ("mlp", "embed"))
1145+
net_2_bias = safe_getattr(sharding_specs, "net_2_bias", ("embed",))
11541146

11551147
self.net_0 = nnx.Linear(
11561148
dim,
@@ -1160,8 +1152,8 @@ def __init__(
11601152
dtype=dtype,
11611153
param_dtype=weights_dtype,
11621154
precision=precision,
1163-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net0_kernel_spec),
1164-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
1155+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net_0_kernel),
1156+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, net_0_bias),
11651157
)
11661158
self.act = get_activation(activation_fn)
11671159
self.net_2 = nnx.Linear(
@@ -1172,8 +1164,8 @@ def __init__(
11721164
dtype=dtype,
11731165
param_dtype=weights_dtype,
11741166
precision=precision,
1175-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net2_kernel_spec),
1176-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, net2_bias_spec),
1167+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net_2_kernel),
1168+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, net_2_bias),
11771169
)
11781170

11791171
def __call__(self, hidden_states: Array) -> Array:

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import math
15-
from typing import Optional
15+
from typing import Optional, Any
1616
import flax.linen as nn
1717
from flax import nnx
1818
import jax.numpy as jnp
@@ -22,6 +22,7 @@
2222
from ..models.attention_flax import NNXSimpleFeedForward
2323
from ..models.normalization_flax import FP32LayerNorm
2424
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
25+
from maxdiffusion.max_utils import safe_getattr
2526

2627

2728
def get_sinusoidal_embeddings(
@@ -85,7 +86,12 @@ def __init__(
8586
dtype: jnp.dtype = jnp.float32,
8687
weights_dtype: jnp.dtype = jnp.float32,
8788
precision: jax.lax.Precision = None,
89+
sharding_specs: Optional[Any] = None,
8890
):
91+
linear_1_kernel = safe_getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
92+
linear_1_bias = safe_getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
93+
linear_2_kernel = safe_getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
94+
linear_2_bias = safe_getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
8995
self.linear_1 = nnx.Linear(
9096
rngs=rngs,
9197
in_features=in_channels,
@@ -96,12 +102,9 @@ def __init__(
96102
precision=precision,
97103
kernel_init=nnx.with_partitioning(
98104
nnx.initializers.xavier_uniform(),
99-
(
100-
"embed",
101-
"mlp",
102-
),
105+
linear_1_kernel,
103106
),
104-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
107+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
105108
)
106109

107110
if cond_proj_dim is not None:
@@ -128,12 +131,9 @@ def __init__(
128131
precision=precision,
129132
kernel_init=nnx.with_partitioning(
130133
nnx.initializers.xavier_uniform(),
131-
(
132-
"mlp",
133-
"embed",
134-
),
134+
linear_2_kernel,
135135
),
136-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
136+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
137137
)
138138

139139
if post_act_fn is None:
@@ -341,7 +341,12 @@ def __init__(
341341
dtype: jnp.dtype = jnp.float32,
342342
weights_dtype: jnp.dtype = jnp.float32,
343343
precision: jax.lax.Precision = None,
344+
sharding_specs: Optional[Any] = None,
344345
):
346+
linear_1_kernel = safe_getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
347+
linear_1_bias = safe_getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
348+
linear_2_kernel = safe_getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
349+
linear_2_bias = safe_getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
345350
if out_features is None:
346351
out_features = hidden_size
347352

@@ -355,12 +360,9 @@ def __init__(
355360
precision=precision,
356361
kernel_init=nnx.with_partitioning(
357362
nnx.initializers.xavier_uniform(),
358-
(
359-
"embed",
360-
"mlp",
361-
),
363+
linear_1_kernel,
362364
),
363-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
365+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
364366
)
365367
self.act_1 = get_activation(act_fn)
366368

@@ -374,12 +376,9 @@ def __init__(
374376
precision=precision,
375377
kernel_init=nnx.with_partitioning(
376378
nnx.initializers.xavier_uniform(),
377-
(
378-
"mlp",
379-
"embed",
380-
),
379+
linear_2_kernel,
381380
),
382-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
381+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
383382
)
384383

385384
def __call__(self, caption):
@@ -535,22 +534,38 @@ def __init__(
535534
use_additional_conditions: bool = False,
536535
dtype: jnp.dtype = jnp.float32,
537536
weights_dtype: jnp.dtype = jnp.float32,
537+
sharding_specs: Optional[Any] = None,
538538
):
539539
self.outdim = size_emb_dim
540540
self.use_additional_conditions = use_additional_conditions
541541

542542
self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
543543
self.timestep_embedder = NNXTimestepEmbedding(
544-
rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
544+
rngs=rngs,
545+
in_channels=256,
546+
time_embed_dim=embedding_dim,
547+
dtype=dtype,
548+
weights_dtype=weights_dtype,
549+
sharding_specs=sharding_specs,
545550
)
546551

547552
if use_additional_conditions:
548553
self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
549554
self.resolution_embedder = NNXTimestepEmbedding(
550-
rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
555+
rngs=rngs,
556+
in_channels=256,
557+
time_embed_dim=size_emb_dim,
558+
dtype=dtype,
559+
weights_dtype=weights_dtype,
560+
sharding_specs=sharding_specs,
551561
)
552562
self.aspect_ratio_embedder = NNXTimestepEmbedding(
553-
rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
563+
rngs=rngs,
564+
in_channels=256,
565+
time_embed_dim=size_emb_dim,
566+
dtype=dtype,
567+
weights_dtype=weights_dtype,
568+
sharding_specs=sharding_specs,
554569
)
555570

556571
def __call__(

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import jax.numpy as jnp
2121
from ... import common_types
2222
from ..attention_flax import NNXAttentionOp
23-
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
23+
from .logical_sharding_ltx2 import get_sharding_specs, LTX2DiTShardingSpecs
2424

2525
Array = common_types.Array
2626
Mesh = common_types.Mesh
@@ -350,9 +350,7 @@ def __init__(
350350
rope_type: str = "interleaved",
351351
flash_block_sizes: BlockSizes = None,
352352
flash_min_seq_length: int = 4096,
353-
qkv_sharding_spec: Optional[tuple] = None,
354-
out_sharding_spec: Optional[tuple] = None,
355-
out_bias_sharding_spec: Optional[tuple] = None,
353+
sharding_specs: Optional[LTX2DiTShardingSpecs] = None,
356354
gated_attn: bool = False,
357355
):
358356
self.heads = heads
@@ -361,33 +359,24 @@ def __init__(
361359
self.inner_dim = dim_head * heads
362360
self.dropout_rate = dropout
363361

364-
# Auto-detect hardware for sharding specs if not overridden
365-
tpu_type = get_tpu_type()
366-
is_ironwood = tpu_type == TpuType.TPU_7X
367-
368-
# Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated)
369-
# to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions.
370-
# This has currently only been tested on Trillium (v6e) and Ironwood (v7x).
371-
if qkv_sharding_spec is None:
372-
qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
373-
if out_sharding_spec is None:
374-
out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed")
375-
if out_bias_sharding_spec is None:
376-
out_bias_sharding_spec = (None,) if is_ironwood else ("embed",)
362+
if sharding_specs is None:
363+
specs = get_sharding_specs("default", "ltx2_dit")
364+
else:
365+
specs = sharding_specs
377366

378367
# 1. Define Partitioned Initializers (Logical Axes)
379368
# Q, K, V kernels: [in_features (embed), out_features (heads)]
380-
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec)
369+
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
381370
# Q, K, V biases: [out_features (heads)]
382-
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
371+
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.qkv_bias)
383372

384373
# Out kernel: [in_features (heads), out_features (embed)]
385-
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec)
374+
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.out_kernel)
386375
# Out bias: [out_features (embed)]
387-
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec)
376+
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.out_bias)
388377

389378
# Norm scales
390-
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
379+
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), specs.norm_scale)
391380

392381
# 2. Projections
393382
self.to_q = nnx.Linear(
@@ -450,8 +439,8 @@ def __init__(
450439
query_dim,
451440
heads,
452441
use_bias=True,
453-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
454-
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)),
442+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.gate_logits_kernel),
443+
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), specs.gate_logits_bias),
455444
rngs=rngs,
456445
dtype=dtype,
457446
)

0 commit comments

Comments
 (0)