1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import math
15- from typing import Optional
15+ from typing import Optional , Any
1616import flax .linen as nn
1717from flax import nnx
1818import jax .numpy as jnp
2222from ..models .attention_flax import NNXSimpleFeedForward
2323from ..models .normalization_flax import FP32LayerNorm
2424from maxdiffusion .tpu_utils import get_tpu_type , TpuType
25+ from maxdiffusion .max_utils import safe_getattr
2526
2627
2728def get_sinusoidal_embeddings (
@@ -85,7 +86,12 @@ def __init__(
8586 dtype : jnp .dtype = jnp .float32 ,
8687 weights_dtype : jnp .dtype = jnp .float32 ,
8788 precision : jax .lax .Precision = None ,
89+ sharding_specs : Optional [Any ] = None ,
8890 ):
91+ linear_1_kernel = safe_getattr (sharding_specs , "emb_linear_1_kernel" , ("embed" , "mlp" ))
92+ linear_1_bias = safe_getattr (sharding_specs , "emb_linear_1_bias" , ("mlp" ,))
93+ linear_2_kernel = safe_getattr (sharding_specs , "emb_linear_2_kernel" , ("mlp" , "embed" ))
94+ linear_2_bias = safe_getattr (sharding_specs , "emb_linear_2_bias" , ("embed" ,))
8995 self .linear_1 = nnx .Linear (
9096 rngs = rngs ,
9197 in_features = in_channels ,
@@ -96,12 +102,9 @@ def __init__(
96102 precision = precision ,
97103 kernel_init = nnx .with_partitioning (
98104 nnx .initializers .xavier_uniform (),
99- (
100- "embed" ,
101- "mlp" ,
102- ),
105+ linear_1_kernel ,
103106 ),
104- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ( "mlp" ,) ),
107+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , linear_1_bias ),
105108 )
106109
107110 if cond_proj_dim is not None :
@@ -128,12 +131,9 @@ def __init__(
128131 precision = precision ,
129132 kernel_init = nnx .with_partitioning (
130133 nnx .initializers .xavier_uniform (),
131- (
132- "mlp" ,
133- "embed" ,
134- ),
134+ linear_2_kernel ,
135135 ),
136- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ( "embed" ,) ),
136+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , linear_2_bias ),
137137 )
138138
139139 if post_act_fn is None :
@@ -341,7 +341,12 @@ def __init__(
341341 dtype : jnp .dtype = jnp .float32 ,
342342 weights_dtype : jnp .dtype = jnp .float32 ,
343343 precision : jax .lax .Precision = None ,
344+ sharding_specs : Optional [Any ] = None ,
344345 ):
346+ linear_1_kernel = safe_getattr (sharding_specs , "emb_linear_1_kernel" , ("embed" , "mlp" ))
347+ linear_1_bias = safe_getattr (sharding_specs , "emb_linear_1_bias" , ("mlp" ,))
348+ linear_2_kernel = safe_getattr (sharding_specs , "emb_linear_2_kernel" , ("mlp" , "embed" ))
349+ linear_2_bias = safe_getattr (sharding_specs , "emb_linear_2_bias" , ("embed" ,))
345350 if out_features is None :
346351 out_features = hidden_size
347352
@@ -355,12 +360,9 @@ def __init__(
355360 precision = precision ,
356361 kernel_init = nnx .with_partitioning (
357362 nnx .initializers .xavier_uniform (),
358- (
359- "embed" ,
360- "mlp" ,
361- ),
363+ linear_1_kernel ,
362364 ),
363- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ( "mlp" ,) ),
365+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , linear_1_bias ),
364366 )
365367 self .act_1 = get_activation (act_fn )
366368
@@ -374,12 +376,9 @@ def __init__(
374376 precision = precision ,
375377 kernel_init = nnx .with_partitioning (
376378 nnx .initializers .xavier_uniform (),
377- (
378- "mlp" ,
379- "embed" ,
380- ),
379+ linear_2_kernel ,
381380 ),
382- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ( "embed" ,) ),
381+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , linear_2_bias ),
383382 )
384383
385384 def __call__ (self , caption ):
@@ -535,22 +534,38 @@ def __init__(
535534 use_additional_conditions : bool = False ,
536535 dtype : jnp .dtype = jnp .float32 ,
537536 weights_dtype : jnp .dtype = jnp .float32 ,
537+ sharding_specs : Optional [Any ] = None ,
538538 ):
539539 self .outdim = size_emb_dim
540540 self .use_additional_conditions = use_additional_conditions
541541
542542 self .time_proj = NNXTimesteps (num_channels = 256 , flip_sin_to_cos = True , downscale_freq_shift = 0 )
543543 self .timestep_embedder = NNXTimestepEmbedding (
544- rngs = rngs , in_channels = 256 , time_embed_dim = embedding_dim , dtype = dtype , weights_dtype = weights_dtype
544+ rngs = rngs ,
545+ in_channels = 256 ,
546+ time_embed_dim = embedding_dim ,
547+ dtype = dtype ,
548+ weights_dtype = weights_dtype ,
549+ sharding_specs = sharding_specs ,
545550 )
546551
547552 if use_additional_conditions :
548553 self .additional_condition_proj = NNXTimesteps (num_channels = 256 , flip_sin_to_cos = True , downscale_freq_shift = 0 )
549554 self .resolution_embedder = NNXTimestepEmbedding (
550- rngs = rngs , in_channels = 256 , time_embed_dim = size_emb_dim , dtype = dtype , weights_dtype = weights_dtype
555+ rngs = rngs ,
556+ in_channels = 256 ,
557+ time_embed_dim = size_emb_dim ,
558+ dtype = dtype ,
559+ weights_dtype = weights_dtype ,
560+ sharding_specs = sharding_specs ,
551561 )
552562 self .aspect_ratio_embedder = NNXTimestepEmbedding (
553- rngs = rngs , in_channels = 256 , time_embed_dim = size_emb_dim , dtype = dtype , weights_dtype = weights_dtype
563+ rngs = rngs ,
564+ in_channels = 256 ,
565+ time_embed_dim = size_emb_dim ,
566+ dtype = dtype ,
567+ weights_dtype = weights_dtype ,
568+ sharding_specs = sharding_specs ,
554569 )
555570
556571 def __call__ (
0 commit comments