Skip to content

Commit e25d9bc

Browse files
Fixed rebase errors
1 parent b03fe07 commit e25d9bc

2 files changed

Lines changed: 2 additions & 52 deletions

File tree

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from flax import linen as nn
3030
from flax import nnx
3131

32-
<<<<<<< HEAD:src/maxtext/models/qwen3.py
3332
from maxtext.common.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN
3433
from maxtext.layers import attentions
3534
from maxtext.layers import initializers as max_initializers
@@ -44,25 +43,8 @@
4443
from maxtext.layers.moe import RoutedMoE
4544
from maxtext.layers.initializers import nd_dense_init, variable_to_logically_partitioned
4645

47-
=======
4846
from jax.sharding import PartitionSpec as P
4947
from jax.experimental.shard_map import shard_map
50-
51-
from MaxText.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN
52-
from MaxText.layers import attentions
53-
from MaxText.layers import initializers as max_initializers
54-
from MaxText.layers import moe
55-
from MaxText.layers import nnx_wrappers
56-
from MaxText.layers import quantizations
57-
from MaxText.layers.embeddings import Qwen3OmniMoeVisionPosEmbedInterpolate, PositionalEmbedding
58-
from MaxText.layers.normalizations import RMSNorm, l2norm, Qwen3NextRMSNorm, Qwen3NextRMSNormGated
59-
from MaxText.layers.quantizations import AqtQuantization as Quant
60-
from MaxText.layers.attentions import Attention
61-
from MaxText.layers.linears import DenseGeneral, MlpBlock
62-
from MaxText.layers.moe import RoutedMoE
63-
from MaxText.layers.initializers import nd_dense_init, variable_to_logically_partitioned
64-
from maxtext.inference import page_manager
65-
>>>>>>> 7461955dc (add shardmap to kernel):src/MaxText/layers/qwen3.py
6648
from maxtext.utils import max_utils
6749
from maxtext.inference import page_manager, kvcache
6850

@@ -218,7 +200,7 @@ def pallas_chunk_gated_delta_rule(
218200
# =========================================================================
219201
initial_dtype = query.dtype
220202
if use_qk_norm_in_gdn:
221-
from MaxText.layers.normalizations import l2norm
203+
from maxtext.layers.normalizations import l2norm
222204
query = l2norm(query, dim=-1, eps=1e-6)
223205
key = l2norm(key, dim=-1, eps=1e-6)
224206

@@ -546,11 +528,7 @@ class Qwen3NextGatedDeltaNet(nnx.Module):
546528
2. output = Linear_out(y)
547529
"""
548530

549-
<<<<<<< HEAD:src/maxtext/models/qwen3.py
550-
def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs):
551-
=======
552-
def __init__(self, config: Config, *, rngs: nnx.Rngs, mesh: Mesh=None):
553-
>>>>>>> 7461955dc (add shardmap to kernel):src/MaxText/layers/qwen3.py
531+
def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs, mesh: Mesh=None):
554532
"""
555533
Args:
556534
config: MaxText configuration object.
@@ -1148,11 +1126,7 @@ def __init__(
11481126
rngs=rngs,
11491127
)
11501128
else:
1151-
<<<<<<< HEAD:src/maxtext/models/qwen3.py
11521129
self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs)
1153-
=======
1154-
self.attention = Qwen3NextGatedDeltaNet(config=cfg, rngs=rngs, mesh=self.mesh)
1155-
>>>>>>> 7461955dc (add shardmap to kernel):src/MaxText/layers/qwen3.py
11561130

11571131
# Second LayerNorm, applied before the MoE block.
11581132
self.post_attention_layernorm = Qwen3NextRMSNorm(

src/maxtext/utils/maxtext_utils.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -536,37 +536,13 @@ def calculate_gated_delta_net_flops_per_device(config):
536536
# We multiply by 2 for FMA
537537
flops_conv = 2 * B * S * K_conv * (2 * K_dim + V_dim)
538538

539-
<<<<<<< HEAD
540539
# 3. Core Gated Delta Net
541540
# This counts 4 distinct O(D^2) operations in the recurrent update:
542541
# KK^T, VK^T, S(a(I-bKK^T)), and SQ.
543542
# We multiply by 2 for FMA.
544543
# Total Core FLOPs = 2 (FMA) * 4 (Ops) * H * D^2 = 8 * H * D^2 per token.
545544
# We use D_k * D_v to generalize D^2 for potentially differing head dimensions.
546545
flops_core_per_token = H_v * (D_k * D_v) * 8
547-
=======
548-
# 3. Core Gated Delta Net (Optimized WY Representation)
549-
# The implementation broadcasts K heads to V heads if H_v > H_k
550-
H_eff = max(H_k, H_v)
551-
552-
# Per-token costs derived from jax_chunk_gated_delta_rule:
553-
# Intra-chunk Pre-computation:
554-
# S = K @ K.T: 2 * C * D_k
555-
# A = (I+S)^-1: ~ C^2 (Triangular solve approximation)
556-
# U = A @ V: 2 * C * D_v
557-
# W = A @ K: 2 * C * D_k
558-
# Scan / Output:
559-
# Out_Inter (Q @ h): 2 * D_k * D_v
560-
# Out_Intra_QK (Q @ K.T): 2 * C * D_k
561-
# Out_Intra_AV (Attn @ V): 2 * C * D_v
562-
# State_Update (W.T @ U): 2 * D_k * D_v
563-
564-
# Summing per-token factors:
565-
# (2*C*D_k) + C^2 + (2*C*D_v) + (2*C*D_k) + (2*D_k*D_v) + (2*C*D_k) + (2*C*D_v) + (2*D_k*D_v)
566-
# = 6*C*D_k + 4*C*D_v + 4*D_k*D_v + C^2
567-
568-
flops_core_per_token = H_eff * (6 * C * D_k + 4 * C * D_v + 4 * D_k * D_v + C**2)
569-
>>>>>>> 09f85a04f (Update tflops calc to align with WY-optimized GDN)
570546
flops_core = B * S * flops_core_per_token
571547

572548
# Weights part: Projections + Conv

0 commit comments

Comments
 (0)