|
29 | 29 | from flax import linen as nn |
30 | 30 | from flax import nnx |
31 | 31 |
|
32 | | -<<<<<<< HEAD:src/maxtext/models/qwen3.py |
33 | 32 | from maxtext.common.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN |
34 | 33 | from maxtext.layers import attentions |
35 | 34 | from maxtext.layers import initializers as max_initializers |
|
44 | 43 | from maxtext.layers.moe import RoutedMoE |
45 | 44 | from maxtext.layers.initializers import nd_dense_init, variable_to_logically_partitioned |
46 | 45 |
|
47 | | -======= |
48 | 46 | from jax.sharding import PartitionSpec as P |
49 | 47 | from jax.experimental.shard_map import shard_map |
50 | | - |
51 | | -from MaxText.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN |
52 | | -from MaxText.layers import attentions |
53 | | -from MaxText.layers import initializers as max_initializers |
54 | | -from MaxText.layers import moe |
55 | | -from MaxText.layers import nnx_wrappers |
56 | | -from MaxText.layers import quantizations |
57 | | -from MaxText.layers.embeddings import Qwen3OmniMoeVisionPosEmbedInterpolate, PositionalEmbedding |
58 | | -from MaxText.layers.normalizations import RMSNorm, l2norm, Qwen3NextRMSNorm, Qwen3NextRMSNormGated |
59 | | -from MaxText.layers.quantizations import AqtQuantization as Quant |
60 | | -from MaxText.layers.attentions import Attention |
61 | | -from MaxText.layers.linears import DenseGeneral, MlpBlock |
62 | | -from MaxText.layers.moe import RoutedMoE |
63 | | -from MaxText.layers.initializers import nd_dense_init, variable_to_logically_partitioned |
64 | | -from maxtext.inference import page_manager |
65 | | ->>>>>>> 7461955dc (add shardmap to kernel):src/MaxText/layers/qwen3.py |
66 | 48 | from maxtext.utils import max_utils |
67 | 49 | from maxtext.inference import page_manager, kvcache |
68 | 50 |
|
@@ -218,7 +200,7 @@ def pallas_chunk_gated_delta_rule( |
218 | 200 | # ========================================================================= |
219 | 201 | initial_dtype = query.dtype |
220 | 202 | if use_qk_norm_in_gdn: |
221 | | - from MaxText.layers.normalizations import l2norm |
| 203 | + from maxtext.layers.normalizations import l2norm |
222 | 204 | query = l2norm(query, dim=-1, eps=1e-6) |
223 | 205 | key = l2norm(key, dim=-1, eps=1e-6) |
224 | 206 |
|
@@ -546,11 +528,7 @@ class Qwen3NextGatedDeltaNet(nnx.Module): |
546 | 528 | 2. output = Linear_out(y) |
547 | 529 | """ |
548 | 530 |
|
549 | | -<<<<<<< HEAD:src/maxtext/models/qwen3.py |
550 | | - def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs): |
551 | | -======= |
552 | | - def __init__(self, config: Config, *, rngs: nnx.Rngs, mesh: Mesh=None): |
553 | | ->>>>>>> 7461955dc (add shardmap to kernel):src/MaxText/layers/qwen3.py |
| 531 | + def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs, mesh: Mesh=None): |
554 | 532 | """ |
555 | 533 | Args: |
556 | 534 | config: MaxText configuration object. |
@@ -1148,11 +1126,7 @@ def __init__( |
1148 | 1126 | rngs=rngs, |
1149 | 1127 | ) |
1150 | 1128 | else: |
1151 | | -<<<<<<< HEAD:src/maxtext/models/qwen3.py |
1152 | 1129 | self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs) |
1153 | | -======= |
1154 | | - self.attention = Qwen3NextGatedDeltaNet(config=cfg, rngs=rngs, mesh=self.mesh) |
1155 | | ->>>>>>> 7461955dc (add shardmap to kernel):src/MaxText/layers/qwen3.py |
1156 | 1130 |
|
1157 | 1131 | # Second LayerNorm, applied before the MoE block. |
1158 | 1132 | self.post_attention_layernorm = Qwen3NextRMSNorm( |
|
0 commit comments