|
20 | 20 | import jax.numpy as jnp |
21 | 21 | from ... import common_types |
22 | 22 | from ..attention_flax import NNXAttentionOp |
| 23 | +from maxdiffusion.tpu_utils import get_tpu_type, TpuType |
23 | 24 |
|
24 | 25 | Array = common_types.Array |
25 | 26 | Mesh = common_types.Mesh |
@@ -349,23 +350,40 @@ def __init__( |
349 | 350 | rope_type: str = "interleaved", |
350 | 351 | flash_block_sizes: BlockSizes = None, |
351 | 352 | flash_min_seq_length: int = 4096, |
| 353 | + qkv_sharding_spec: Optional[tuple] = None, |
| 354 | + out_sharding_spec: Optional[tuple] = None, |
| 355 | + out_bias_sharding_spec: Optional[tuple] = None, |
352 | 356 | ): |
353 | 357 | self.heads = heads |
354 | 358 | self.rope_type = rope_type |
355 | 359 | self.dim_head = dim_head |
356 | 360 | self.inner_dim = dim_head * heads |
357 | 361 | self.dropout_rate = dropout |
358 | 362 |
|
| 363 | + # Auto-detect hardware for sharding specs if not overridden |
| 364 | + tpu_type = get_tpu_type() |
| 365 | + is_ironwood = tpu_type == TpuType.TPU_7X |
| 366 | + |
| 367 | + # Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated) |
| 368 | + # to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions. |
| 369 | + # This has currently only been tested on Trillium (v6e) and Ironwood (v7x). |
| 370 | + if qkv_sharding_spec is None: |
| 371 | + qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads") |
| 372 | + if out_sharding_spec is None: |
| 373 | + out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed") |
| 374 | + if out_bias_sharding_spec is None: |
| 375 | + out_bias_sharding_spec = (None,) if is_ironwood else ("embed",) |
| 376 | + |
359 | 377 | # 1. Define Partitioned Initializers (Logical Axes) |
360 | 378 | # Q, K, V kernels: [in_features (embed), out_features (heads)] |
361 | | - qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")) |
| 379 | + qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec) |
362 | 380 | # Q, K, V biases: [out_features (heads)] |
363 | 381 | qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)) |
364 | 382 |
|
365 | 383 | # Out kernel: [in_features (heads), out_features (embed)] |
366 | | - out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")) |
| 384 | + out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec) |
367 | 385 | # Out bias: [out_features (embed)] |
368 | | - out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",)) |
| 386 | + out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec) |
369 | 387 |
|
370 | 388 | # Norm scales |
371 | 389 | norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)) |
|
0 commit comments