diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..5b2149bcfa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,402 @@ +absl-py>=2.3.1 +aiofiles>=25.1.0 +aiohappyeyeballs>=2.6.1 +aiohttp>=3.13.3 +aiohttp-cors>=0.8.1 +aiosignal>=1.4.0 +annotated-doc>=0.0.4 +annotated-types>=0.7.0 +anthropic>=0.84.0 +antlr4-python3-runtime>=4.9.3 +anyio>=4.11.0 +aqtp>=0.9.0 +array_record>=0.8.3 +asttokens>=3.0.1 +astor>=0.8.1 +astroid>=4.0.2 +astunparse>=1.6.3 +attrs>=25.4.0 +auditwheel>=6.5.0 +black>=24.10.0 +blake3>=1.0.8 +blobfile>=3.1.0 +boto3>=1.42.56 +botocore>=1.42.56 +build>=1.3.0 +cachetools>=6.2.2 +cbor2>=5.8.0 +certifi>=2026.2.25 +cffi>=2.0.0 +cfgv>=3.5.0 +charset-normalizer>=3.4.4 +cheroot>=11.1.2 +chex>=0.1.91 +click>=8.3.1 +cloud-accelerator-diagnostics>=0.1.1 +cloud-tpu-diagnostics>=0.1.5 +cloudpickle>=3.1.2 +clu>=0.0.12 +cmake>=4.2.1 +colorama>=0.4.6 +colorful>=0.5.8 +comm>=0.2.3 +compressed-tensors>=0.13.0 +contourpy>=1.3.3 +coverage>=7.12.0 +cryptography>=46.0.5 +cycler>=0.12.1 +dacite>=1.9.2 +dataclasses-json>=0.6.7 +datasets>=4.6.0 +debugpy>=1.8.20 +decorator>=5.2.1 +depyf>=0.20.0 +dill>=0.4.0 +diskcache>=5.6.3 +distlib>=0.4.0 +distro>=1.9.0 +dm-tree>=0.1.9 +dnspython>=2.8.0 +docstring_parser>=0.17.0 +drjax>=0.1.4 +editdistance>=0.8.1 +einops>=0.8.1 +einshape>=1.0 +email-validator>=2.3.0 +entrypoints>=0.4 +etils>=1.13.0 +evaluate>=0.4.6 +execnet>=2.1.2 +executing>=2.2.1 +fastapi>=0.122.0 +fastapi-cli>=0.0.24 +fastapi-cloud-cli>=0.13.0 +fastar>=0.8.0 +fastjsonschema>=2.21.2 +filelock>=3.20.0 +flatbuffers>=25.9.23 +flax>=0.12.4 +fonttools>=4.60.1 +frozenlist>=1.8.0 +fsspec>=2026.1.0 +gast>=0.6.0 +gcsfs>=2026.1.0 +gguf>=0.17.1 +google-api-core>=2.28.1 +google-api-python-client>=2.187.0 +google-auth>=2.43.0 +google-auth-httplib2>=0.2.1 +google-auth-oauthlib>=1.2.2 +google-cloud-aiplatform>=1.128.0 +google-cloud-appengine-logging>=1.7.0 +google-cloud-audit-log>=0.4.0 +google-cloud-bigquery>=3.38.0 +google-cloud-core>=2.5.0 +google-cloud-logging>=3.12.1 +google-cloud-mldiagnostics>=0.5.10 +google-cloud-monitoring>=2.28.0 +google-cloud-resource-manager>=1.15.0 +google-cloud-storage>=3.9.0 +google-cloud-storage-control>=1.10.0 +google-crc32c>=1.7.1 +google-genai>=1.52.0 +google-pasta>=0.2.0 +google-resumable-media>=2.8.0 +google_metrax>=0.2.4 +googleapis-common-protos>=1.72.0 +grain>=0.2.15 +grpc-google-iam-v1>=0.14.3 +grpcio>=1.78.0 +grpcio-reflection>=1.71.0 +grpcio-status>=1.71.2 +gspread>=6.2.1 +gviz-api>=1.10.0 +h11>=0.16.0 +h5py>=3.15.1 +hf-xet>=1.2.0 +hf_transfer>=0.1.9 +httpcore>=1.0.9 +httplib2>=0.31.0 +httptools>=0.7.1 +httpx>=0.28.1 +httpx-sse>=0.4.3 +huggingface-hub>=0.36.0 +humanize>=4.14.0 +hypothesis>=6.142.1 +identify>=2.6.15 +idna>=3.11 +ijson>=3.5.0 +immutabledict>=4.2.2 +importlab>=0.8.1 +importlib_metadata>=8.7.0 +importlib_resources>=6.5.2 +iniconfig>=2.3.0 +interegular>=0.3.3 +ipykernel>=7.2.0 +ipython>=9.10.0 +ipython_pygments_lexers>=1.1.1 +ipywidgets>=8.1.8 +isort>=7.0.0 +jaraco.classes>=3.4.0 +jaraco.context>=6.1.0 +jaraco.functools>=4.3.0 +jax>=0.9.1 +jaxlib>=0.9.1 +jaxtyping>=0.3.3 +jedi>=0.19.2 +jeepney>=0.9.0 +Jinja2>=3.1.6 +jiter>=0.13.0 +jmespath>=1.1.0 +joblib>=1.5.2 +jsonlines>=4.0.0 +jsonschema>=4.26.0 +jsonschema-specifications>=2025.9.1 +jupyter_client>=8.8.0 +jupyter_core>=5.9.1 +jupyterlab_widgets>=3.0.16 +kagglehub>=0.3.13 +keras>=3.12.0 +keyring>=25.7.0 +keyrings.google-artifactregistry-auth>=1.1.2 +kiwisolver>=1.4.9 +lark>=1.2.2 +latex2sympy2_extended>=1.11.0 +libclang>=18.1.1 +libcst>=1.8.6 +libtpu>=0.0.38 
+llguidance>=1.3.0 +llvmlite>=0.45.1 +lm-format-enforcer>=0.11.3 +loguru>=0.7.3 +lxml>=6.0.2 +Markdown>=3.10 +markdown-it-py>=4.0.0 +MarkupSafe>=3.0.3 +marshmallow>=3.26.2 +math-verify>=0.9.0 +matplotlib>=3.10.7 +matplotlib-inline>=0.2.1 +mccabe>=0.7.0 +mcp>=1.26.0 +mdurl>=0.1.2 +mistral_common>=1.9.1 +ml_collections>=1.1.0 +ml_dtypes>=0.5.4 +ml_goodput_measurement>=0.0.15 +model-hosting-container-standards>=0.1.13 +more-itertools>=10.8.0 +mpmath>=1.3.0 +msgpack>=1.1.2 +msgspec>=0.20.0 +multidict>=6.7.0 +multiprocess>=0.70.18 +mypy_extensions>=1.1.0 +namex>=0.1.0 +nbclient>=0.10.4 +nbformat>=5.10.4 +nest-asyncio>=1.6.0 +networkx>=3.6 +ninja>=1.13.0 +nixl>=0.3.0 +nltk>=3.9.2 +nodeenv>=1.9.1 +numba>=0.62.1 +numpy>=2.2.6 +numpy-typing-compat>=20250818.2.2 +nvidia-cublas-cu12>=12.8.4.1 +nvidia-cuda-cupti-cu12>=12.8.90 +nvidia-cuda-nvrtc-cu12>=12.8.93 +nvidia-cuda-runtime-cu12>=12.8.90 +nvidia-cudnn-cu12>=9.10.2.21 +nvidia-cufft-cu12>=11.3.3.83 +nvidia-cufile-cu12>=1.13.1.3 +nvidia-curand-cu12>=10.3.9.90 +nvidia-cusolver-cu12>=11.7.3.90 +nvidia-cusparse-cu12>=12.5.8.93 +nvidia-cusparselt-cu12>=0.7.1 +nvidia-nccl-cu12>=2.27.5 +nvidia-nvjitlink-cu12>=12.8.93 +nvidia-nvshmem-cu12>=3.3.20 +nvidia-nvtx-cu12>=12.8.90 +oauthlib>=3.3.1 +omegaconf>=2.3.0 +openai>=2.24.0 +openai-harmony>=0.0.8 +opencensus>=0.11.4 +opencensus-context>=0.1.3 +opencv-python-headless>=4.13.0.92 +opentelemetry-api>=1.39.1 +opentelemetry-exporter-otlp>=1.39.1 +opentelemetry-exporter-otlp-proto-common>=1.39.1 +opentelemetry-exporter-otlp-proto-grpc>=1.39.1 +opentelemetry-exporter-otlp-proto-http>=1.39.1 +opentelemetry-exporter-prometheus>=0.60b1 +opentelemetry-proto>=1.39.1 +opentelemetry-sdk>=1.39.1 +opentelemetry-semantic-conventions>=0.60b1 +opentelemetry-semantic-conventions-ai>=0.4.15 +opt_einsum>=3.4.0 +optax>=0.2.6 +optree>=0.18.0 +optype>=0.14.0 +orbax-checkpoint>=0.11.28 +orbax-export>=0.0.8 +outlines_core>=0.2.11 +packaging>=26.0 +pandas>=2.3.3 +papermill>=2.7.0 +parameterized>=0.9.0 +parso>=0.8.6 +partial-json-parser>=0.2.1.1.post7 +pathspec>=0.12.1 +pathwaysutils>=0.1.6 +perfetto>=0.16.0 +pexpect>=4.9.0 +pillow>=12.0.0 +platformdirs>=4.9.2 +pluggy>=1.6.0 +portpicker>=1.6.0 +pre_commit>=4.5.0 +prometheus-fastapi-instrumentator>=7.1.0 +prometheus_client>=0.23.1 +promise>=2.3 +prompt_toolkit>=3.0.52 +propcache>=0.4.1 +proto-plus>=1.26.1 +protobuf>=5.29.6 +psutil>=7.2.2 +ptyprocess>=0.7.0 +pure_eval>=0.2.3 +py-cpuinfo>=9.0.0 +py-spy>=0.4.1 +pyarrow>=22.0.0 +pyasn1>=0.6.1 +pyasn1_modules>=0.4.2 +pybase64>=1.4.3 +pycnite>=2024.7.31 +pycountry>=26.2.16 +pycparser>=3.0 +pycryptodomex>=3.23.0 +pydantic>=2.12.5 +pydantic-extra-types>=2.11.0 +pydantic-settings>=2.13.1 +pydantic_core>=2.41.5 +pydot>=4.0.1 +pyelftools>=0.32 +pyglove>=0.4.5 +Pygments>=2.19.2 +pyink>=24.10.1 +PyJWT>=2.11.0 +pylatexenc>=2.10 +pylint>=4.0.3 +pyparsing>=3.2.5 +pyproject_hooks>=1.2.0 +pytest>=8.4.2 +pytest-mock>=3.15.1 +pytest-xdist>=3.8.0 +python-dateutil>=2.9.0.post0 +python-dotenv>=1.2.1 +python-json-logger>=4.0.0 +python-multipart>=0.0.22 +pytype>=2024.10.11 +pytz>=2025.2 +PyYAML>=6.0.3 +pyzmq>=27.1.0 +qwix>=0.1.2 +ray>=2.54.0 +referencing>=0.37.0 +regex>=2025.11.3 +requests>=2.32.5 +requests-oauthlib>=2.0.0 +rich>=14.2.0 +rich-toolkit>=0.19.7 +rignore>=0.7.6 +rpds-py>=0.30.0 +rsa>=4.9.1 +runai-model-streamer>=0.15.4 +runai-model-streamer-gcs>=0.15.4 +runai-model-streamer-s3>=0.15.4 +s3transfer>=0.16.0 +safetensors>=0.7.0 +scipy>=1.16.3 +scipy-stubs>=1.16.3.0 +SecretStorage>=3.5.0 +sentencepiece>=0.2.1 +sentry-sdk>=2.53.0 +seqio>=0.0.20 
+setproctitle>=1.3.7 +setuptools>=78.1.0 +setuptools-scm>=9.2.2 +shapely>=2.1.2 +shellingham>=1.5.4 +shortuuid>=1.0.13 +simple-parsing>=0.1.7 +simplejson>=3.20.2 +six>=1.17.0 +smart_open>=7.5.1 +sniffio>=1.3.1 +sortedcontainers>=2.4.0 +sse-starlette>=3.2.0 +starlette>=0.50.0 +stack-data>=0.6.3 +supervisor>=4.3.0 +sympy>=1.14.0 +tabulate>=0.9.0 +tenacity>=9.1.4 +tensorboard>=2.20.0 +tensorboard-data-server>=0.7.2 +tensorboard-plugin-profile>=2.13.0 +tensorboardX>=2.6.4 +tensorflow>=2.20.0 +tensorflow-datasets>=4.9.9 +tensorflow-metadata>=1.17.2 +tensorflow-text>=2.20.0 +tensorstore>=0.1.79 +termcolor>=3.2.0 +tiktoken>=0.12.0 +tokamax>=0.0.8 +tokenizers>=0.22.1 +toml>=0.10.2 +tomlkit>=0.13.3 +toolz>=1.1.0 +torch>=2.10.0 +torchax>=0.0.11 +torchvision>=0.24.0 +tornado>=6.5.4 +tpu-info>=0.7.1 +tqdm>=4.67.3 +traitlets>=5.14.3 +transformers>=4.57.1 +treescope>=0.1.10 +triton>=3.5.0 +typeguard>=2.13.3 +typer>=0.24.1 +typing-inspect>=0.9.0 +typing-inspection>=0.4.2 +typing_extensions>=4.15.0 +tzdata>=2025.2 +uritemplate>=4.2.0 +urllib3>=2.5.0 +uv>=0.10.6 +uvicorn>=0.38.0 +uvloop>=0.22.1 +virtualenv>=20.35.4 +wadler_lindig>=0.1.7 +watchfiles>=1.1.1 +wcwidth>=0.6.0 +websockets>=15.0.1 +Werkzeug>=3.1.3 +wheel>=0.46.3 +widgetsnbextension>=4.0.15 +wrapt>=2.0.1 +xgrammar>=0.1.29 +xprof>=2.21.1 +xxhash>=3.6.0 +yapf>=0.43.0 +yarl>=1.22.0 +zipp>=3.23.0 +zstandard>=0.25.0 + +# tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/10c269b0cf8a18628875c7b27a3e3d76f25b8920.zip +# vllm @ git+https://github.com/vllm-project/vllm@ab79863e6c4f4df652328af6901be2ee208dacec diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index ec2e96333f..20966fa62b 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -32,6 +32,7 @@ AxisIdxes = tuple[int, ...] BATCH = "activation_batch" +BATCH_ATTN = "activation_batch_attn" ATTN_LENGTH = "activation_attn_length" diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index af5e28e7a8..28622b4dc8 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -592,6 +592,7 @@ ici_tensor_sequence_parallelism: 1 ici_autoregressive_parallelism: 1 ici_pipeline_parallelism: 1 ici_expert_parallelism: 1 +ici_attn_dp_expert_parallelism: 1 # Enable ZeRO-1 optimizer sharding over data axis shard_optimizer_over_data: False @@ -985,7 +986,7 @@ xprof_e2e_enable_fw_power_level_event: False xprof_e2e_enable_fw_thermal_event: False profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing. 
-log_config: True # Prints the config (after defaults have been set by pyconfig logic) +log_config: False # Prints the config (after defaults have been set by pyconfig logic) debug_sharding: False # Prints model weights sharding info # Checkpoint Structured logging diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml index 139012406a..e7863e9961 100644 --- a/src/maxtext/configs/inference/vllm.yml +++ b/src/maxtext/configs/inference/vllm.yml @@ -29,55 +29,78 @@ weight_dtype: bfloat16 # -------------- Logical Axis Rules -------------- mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert'] logical_axis_rules: [ - ['activation_batch', ['data']], - ['activation_batch_moe', []], - ['activation_embed_and_logits_batch', ['data', 'expert']], - ['activation_embed_and_logits_batch_sequence', ['data', 'expert']], + # ========================================== + # Vocabulary Embedding + # ========================================== + # Vocab Activations + ['activation_embed_and_logits_batch', ['data']], + ['activation_embed_and_logits_batch_sequence', ['data']], + ['activation_vocab', ['model', 'expert', 'attn_dp', 'attn_dp_expert']], + # Vocab Weights + ['vocab', ['model', 'expert', 'attn_dp', 'attn_dp_expert']], + ['embed_vocab', []], + # ========================================== + # Attention + # ========================================== + # Attention Activations + ['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']], ['activation_heads', ['model', 'expert']], ['activation_kv_heads', ['model', 'expert']], - ['activation_attn_length', []], - ['activation_length', ['data']], - ['activation_length_moe', ['data', 'expert']], - ['activation_length_moe', 'data'], - ['activation_q_length', ['expert', 'attn_dp_expert']], - ['activation_attn_embed', 'model'], - ['activation_embed', ['model', 'attn_dp']], - ['activation_embed_moe', ['model', 'attn_dp']], - ['activation_mlp', ['model', 'attn_dp']], - ['activation_mlp_moe', ['model', 'attn_dp']], - ['activation_kv', ['model']], - ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']], - ['activation_kv_batch', ['data']], - ['activation_kv_head_dim', ['model']], - ['activation_vocab', ['model', 'attn_dp']], - ['activation_norm_length', []], - ['activation_norm_length_moe', []], - ['activation_exp', ['expert', 'attn_dp_expert']], - ['decode_batch', ['expert', 'attn_dp_expert']], - ['decode_batch_moe', []], - ['decode_length', []], - ['mlp', ['model', 'attn_dp']], - ['mlp_moe', ['model', 'attn_dp']], - ['mlp_no_fsdp', ['model', 'attn_dp']], - ['vocab', ['model', 'attn_dp']], - ['heads', ['model']], + ['activation_attn_embed', []], + ['activation_kv', ['model', 'expert']], + ['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']], + ['activation_kv_head_dim', []], + # Attention Weights + ['heads', ['model', 'expert']], ['q_heads', ['model', 'expert']], ['kv_heads', ['model', 'expert']], - ['kv_head_dim', []], + ['qkv', []], ['kv', []], - ['embed', ['expert', 'attn_dp_expert']], - ['embed', ['attn_dp_expert']], - ['embed_vocab', ['expert', 'attn_dp_expert']], - ['embed_vocab', ['attn_dp_expert']], - ['embed_moe', []], + ['kv_head_dim', []], + ['q_lora', []], + ["q_lora_up_proj", []], + ['kv_lora', []], + ["kv_lora_up_proj", []], + # ========================================== + # Mixture of Experts (MoE) + # ========================================== + # MoE Activations + ['activation_batch_moe', ['data']], + ['activation_embed_moe', ['model']], + ['activation_mlp_moe', []], + 
['activation_exp', ['expert', 'attn_dp', 'attn_dp_expert']], + # MoE Weights + ['exp', ['expert', 'attn_dp', 'attn_dp_expert']], + ['mlp_moe', []], ['embed_moe', []], - ['embed_tensor_transpose', ['attn_dp', 'model']], - ['q_lora', ['expert', 'attn_dp_expert']], - ['kv_lora', ['expert', 'attn_dp_expert']], + # ========================================== + # Standard MLP / Dense Layers / Model Structure + # ========================================== + # Dense Activations + ['activation_mlp', ['model', 'expert', 'attn_dp', 'attn_dp_expert']], + # Note activation batch and length also get used in attention and vocab + ['activation_batch', ['data']], + ['activation_embed', ['model', 'expert', 'attn_dp', 'attn_dp_expert']], + # General Weights + ['mlp', ['model', 'expert', 'attn_dp', 'attn_dp_expert']], + ['embed', []], ['norm', []], + # ========================================== + # Inference(Prefill, Decode, Cache) + # ========================================== + ['activation_prefill_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']], + ['decode_batch', ['data', 'attn_dp', 'attn_dp_expert']], + ['cache_heads', ['model', 'expert']], ['cache_heads', ['model']], - ['exp', ['expert', 'attn_dp_expert']], - ['paged_kv_heads', ['model']], - ] + ['paged_kv_heads', ['model', 'expert']], + ['cache_batch_prefill', []], + ['cache_batch', []], + ['cache_heads_none', []], + ['cache_kv', []], + ['cache_sequence', []], + ['num_pages', []], + ['tokens_per_page', []], + ['paged_kv_head_dim_size', []], + ] data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']] input_data_sharding_logical_axes: ['activation_embed_and_logits_batch'] diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 84232f8633..305ce0647f 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -686,6 +686,11 @@ class MoEGeneral(BaseModel): False, description="Whether to cast inputs to fp32 to compute MoE gate logits for numerical stability.", ) + prefuse_moe_weights: bool = Field( + False, + description="Whether to pre-fuse MoE weights (w0 and w1) during initialization. 
" + "This is useful for inference performance in vllm_rpa mode.", + ) class MoEKernels(BaseModel): @@ -881,6 +886,7 @@ class IciParallelism(BaseModel): ici_autoregressive_parallelism: int = Field(1, description="ICI axis for autoregressive parallelism.") ici_pipeline_parallelism: int = Field(1, description="ICI axis for pipeline parallelism.") ici_expert_parallelism: int = Field(1, description="ICI axis for expert parallelism.") + ici_attn_dp_expert_parallelism: int = Field(1, description="ICI axis for attn dp expert parallelism.") class PipelineParallelism(BaseModel): @@ -2741,7 +2747,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de "expert": self.ici_expert_parallelism, "autoregressive": self.ici_autoregressive_parallelism, "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads - "attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP + "attn_dp_expert": self.ici_attn_dp_expert_parallelism, } self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes] diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py index 66af92e209..bd35ee284e 100644 --- a/src/maxtext/inference/vllm_decode.py +++ b/src/maxtext/inference/vllm_decode.py @@ -82,6 +82,7 @@ def decode_with_vllm(config: Config) -> None: "weight_dtype": "bfloat16", "allow_split_physical_axes": True, "debug_sharding": config.debug_sharding, + "prefuse_moe_weights": config.prefuse_moe_weights, }, "sharding": { "sharding_strategy": { @@ -99,6 +100,9 @@ def decode_with_vllm(config: Config) -> None: enable_expert_parallel = config.ici_expert_parallelism > 1 if enable_expert_parallel: vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = config.ici_expert_parallelism + vllm_args["additional_config"]["sharding"]["sharding_strategy"][ + "attention_data_expert_parallelism" + ] = config.ici_attn_dp_expert_parallelism vllm_args["enable_expert_parallel"] = enable_expert_parallel max_logging.log( diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index a4cd924672..f06e7ed927 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -124,6 +124,9 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh): # Model creation self.model: nnx.Module | None = None + # Indicates that the model handles its own sharding logic + self._self_manages_sharding = True + # Handle dummy weight loading during initialization if vllm_config.load_config.load_format == "dummy": self.load_weights(rng_key) @@ -161,8 +164,8 @@ def __call__( raise ValueError("Model must be an instance of type nnx.Module.") # Ensure inputs are at least 2D with a batch dimension - input_ids = jnp.atleast_2d(input_ids) - input_positions = jnp.atleast_2d(attention_metadata.input_positions) + input_ids = jnp.expand_dims(input_ids, axis=1) + input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1) with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules): aux_hidden_states = [] @@ -233,7 +236,7 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array: with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules): # Reshape to (num_tokens, 1, hidden_dim) for decoder output head - y = hidden_states[:, jnp.newaxis, :] + y = jnp.expand_dims(hidden_states, axis=1) # 
Compute logits using the MaxText decoder's output head logits = self.model.decoder.apply_output_head(self.model.token_embedder, y, True, self.model_mode) @@ -250,7 +253,7 @@ def load_weights(self, rng_key: jax.Array) -> None: if self.model is not None: return - with self.mesh, nn.logical_axis_rules(""): + with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules): model, _ = model_creation_utils.create_nnx_model( self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index f0d7791a0d..3f600ef4de 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -36,7 +36,7 @@ Array, AxisIdxes, AxisNames, - BATCH, + BATCH_ATTN as BATCH, CACHE_BATCH, CACHE_BATCH_PREFILL, CACHE_SEQUENCE, diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index b796ea9e6e..ef51a5f6fc 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -38,7 +38,7 @@ AttentionType, AxisIdxes, AxisNames, - BATCH, + BATCH_ATTN as BATCH, CACHE_BATCH, CACHE_BATCH_PREFILL, CACHE_HEADS, diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index d9ef0f8848..3c7cad9682 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -27,7 +27,7 @@ from maxtext.common.common_types import ( DecoderBlockType, - BATCH, + BATCH_ATTN as BATCH, HEAD, PREFILL_LENGTH, D_KV, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index d622d30912..697caa03b0 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -384,7 +384,10 @@ def __init__( kernel_init=self.kernel_init, kernel_axes=self.kernel_axes, use_bias=self.config.routed_bias, - score_func=self.config.routed_score_func, + # tpu-inference applies the score function in the fused_moe_gmm kernel, + # so we don't apply it here to avoid redundant computation. + # See https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/common/fused_moe_gmm.py#L58. + score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func, matmul_precision=self.config.matmul_precision, shard_mode=config.shard_mode, rngs=self.rngs, @@ -403,6 +406,27 @@ def __init__( self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim)) self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim)) self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim)) + elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa": + self.wi = nnx.Param( + self.kernel_init( + self.rngs.params(), + (num_experts, self.config.emb_dim, intermediate_dim * 2), + weight_dtype, + kernel_in_axis, + kernel_out_axis, + ), + sharding=self.wi_kernel_axes, + ) + self.wo = nnx.Param( + self.kernel_init( + self.rngs.params(), + (self.num_experts, self.intermediate_dim, self.config.emb_dim), + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ), + sharding=self.wo_kernel_axes, + ) else: self.wi_0 = nnx.Param( self.kernel_init( @@ -1970,6 +1994,72 @@ def dense_matmul( ).astype(self.dtype) return output, lb_loss, bias_updates + def fused_moe_matmul( + self, + inputs, + gate_logits, + wo_kernel, + w0_kernel=None, + w1_kernel=None, + fused_kernel=None, + ) -> tuple[jax.Array, None, None]: + """Fused MoE via tpu_inference fused_moe_func (vllm_rpa path only). + + fused_moe_func handles routing, GMM, and weighted combination internally. 
+ It does not compute lb_loss or bias_updates (inference-only). + """ + try: + # pylint: disable=import-outside-toplevel + # pytype: disable=import-error + from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func + except ImportError as e: + raise ImportError("fused_moe_matmul requires the tpu-inference package.") from e + + # Reshape 3D [B, S, D] -> 2D [T, D] (fused_moe_func expects 2D input) + batch_size, seq_len, emb_dim = inputs.shape + hidden_states = jnp.reshape(inputs, (batch_size * seq_len, emb_dim)) + gating_output = jnp.reshape(gate_logits, (batch_size * seq_len, self.num_experts)) + + # Concatenate gate and up projections: [E, D, H] + [E, D, H] -> [E, D, 2H] + # fused_moe_func splits this internally: gate=w1[..., :H], up=w1[..., H:] + if fused_kernel is None: + fused_kernel = jnp.concatenate([w0_kernel, w1_kernel], axis=-1) + + # Use expert parallelism if the expert axis has size > 1 + use_ep = self.get_expert_parallelism_size() > 1 + + # Map MaxText config fields to fused_moe_func args + activation = self.config.mlp_activations[0] # e.g. "silu" + scoring_fn = self.config.routed_score_func if self.config.routed_score_func else "softmax" + + # Check if the model architecture intrinsically renormalizes weights + renormalize = self.config.norm_topk_prob or ( + self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4) + ) + + output_2d = fused_moe_func( + hidden_states=hidden_states, + w1=fused_kernel, + w2=wo_kernel, + w1_scale=None, + w2_scale=None, + w1_bias=None, + w2_bias=None, + gating_output=gating_output, + topk=self.num_experts_per_tok, + renormalize=renormalize, + mesh=self.mesh, + use_ep=use_ep, + activation=activation, + scoring_fn=scoring_fn, + sc_kernel_threshold=16777216, + sc_kernel_col_chunk_size=1024, + ) + + # Reshape output 2D [T, D] -> 3D [B, S, D] + output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim)) + return output, None, None + def retrieve_quantized_weight( self, inputs, @@ -2008,10 +2098,17 @@ def __call__( routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype) gate_logits, pre_bias_logits = self.gate(routing_inputs) - w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) - w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) wo_kernel = jnp.asarray(self.wo[...], self.dtype) + fused_kernel = None + w0_kernel = None + w1_kernel = None + if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa": + fused_kernel = jnp.asarray(self.wi[...], self.dtype) + else: + w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) + w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) + if self.per_expert_scale is not None: wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] @@ -2022,7 +2119,12 @@ def __call__( else: w0_bias, w1_bias, wo_bias = None, None, None - if cfg.sparse_matmul: + # vllm_rpa codepath uses fused_moe_func from tpu_inference for optimized inference. + if cfg.attention == "vllm_rpa": + output, lb_loss, bias_updates = self.fused_moe_matmul( + inputs, gate_logits, wo_kernel, w0_kernel=w0_kernel, w1_kernel=w1_kernel, fused_kernel=fused_kernel + ) + elif cfg.sparse_matmul: if quantizations.in_serve_mode(self.quant): w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight( inputs, diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 8b9a72a64f..e03696fd1e 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -13,13 +13,13 @@ # limitations under the License. 
 # pylint: disable=line-too-long, disable=bare-except, consider-using-generator
-""" Utils that are only interesting to MaxText. """
+"""Utils that are only interesting to MaxText."""
 
 import functools
 import pickle
 import os
 
-from flax import linen as nn
+from flax import nnx, linen as nn
 from flax.linen import partitioning as nn_partitioning
 from flax.training import train_state
 
@@ -1625,7 +1625,7 @@ def schedule(step):
   return optax.join_schedules(pieces, boundaries)
 
 
-def print_shardings_params(params, params_sharding, mesh, logical_annotations=None):
+def print_shardings_params(params, params_sharding, mesh, logical_annotations=None, target_layer=0):
   """
   Print state shardings comparing Logical Definition vs Physical Result.
""" @@ -1636,16 +1664,33 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No if logical_annotations and not hasattr(logical_annotations, "params"): logical_annotations = {"params": logical_annotations} - leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) - leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) - leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) + leaves_params, _ = jax.tree_util.tree_flatten_with_path(params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + leaves_sharding, _ = jax.tree_util.tree_flatten_with_path( + params_sharding, is_leaf=lambda x: isinstance(x, nnx.Variable) + ) + leaves_logical, _ = jax.tree_util.tree_flatten_with_path( + logical_annotations, is_leaf=lambda x: isinstance(x, nnx.Variable) + ) for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical): - path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) - shape = jax.typeof(leaf_val) + # Extract path keys to accurately check for layer names + path_keys = [str(p.key if hasattr(p, "key") else p.name) for p in path] + path_str = "/".join(path_keys) + + # Check if param is inside a layer block, and if it's the target layer + is_layer_param = any(k.startswith("layers_") for k in path_keys) + is_target_layer = any(k == f"layers_{target_layer}" for k in path_keys) + # Skip logging if it belongs to a layer that isn't our target + if is_layer_param and not is_target_layer: + continue + + if "to_nnx__rngs" in path_str: + continue + + shape = jax.typeof(getattr(leaf_val, "value")) pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) pspec_str = str(tuple(pspec)) - logical_str = str(leaf_logical_val) + logical_str = str(getattr(leaf_logical_val, "out_sharding", None)) message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" max_logging.info(message) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 49fb9d3490..1f2879fb45 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -25,6 +25,7 @@ import flax.linen as nn import jax import jax.numpy as jnp +import numpy as np from jax.sharding import AxisType, Mesh from maxtext.configs import pyconfig from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode @@ -330,6 +331,39 @@ def create_sharded_state(): # Get the structure of checkpoint in `config.load_parameters_path` metadata = ckptr.metadata(config.load_parameters_path) + def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx): + if not hasattr(target, "items") or not hasattr(meta_tree, "items"): + return target + new_target = {} + for k, v in target.items(): + if k == "wi" and "wi" not in meta_tree and "wi_0" in meta_tree and "wi_1" in meta_tree: + if not is_nnx: + arr = v + half_dim = arr.shape[-1] // 2 + new_target["wi_0"] = jax.ShapeDtypeStruct( + shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding + ) + new_target["wi_1"] = jax.ShapeDtypeStruct( + shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding + ) + else: + arr = v["value"] + half_dim = arr.shape[-1] // 2 + new_target["wi_0"] = { + "value": jax.ShapeDtypeStruct( + shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding + ) + } + new_target["wi_1"] = { + "value": jax.ShapeDtypeStruct( + shape=arr.shape[:-1] + (half_dim,), 
dtype=arr.dtype, sharding=arr.sharding + ) + } + else: + new_target[k] = _adjust_target_for_moe_fusion(v, meta_tree.get(k, {}), is_nnx) + + return new_target + is_nnx_checkpoint = True if ( "params" in metadata.item_metadata.tree.keys() @@ -343,6 +377,10 @@ def create_sharded_state(): is_leaf=lambda n: hasattr(n, "value"), ) + target_for_restore = _adjust_target_for_moe_fusion( + target_for_restore, metadata.item_metadata.tree["params"]["params"], False + ) + item_to_restore = {"params": {"params": target_for_restore}} base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore) restore_args = { @@ -361,6 +399,7 @@ def create_sharded_state(): sharded_state, is_leaf=lambda n: isinstance(n, nnx.Variable), ) + target_for_restore = _adjust_target_for_moe_fusion(target_for_restore, metadata.item_metadata.tree, True) item_to_restore = target_for_restore base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore) restore_args = _fix_restore_args_for_shape_mismatch( @@ -400,6 +439,36 @@ def create_sharded_state(): sharded_state, is_leaf=lambda n: isinstance(n, nnx.Variable), ) + + def to_dict(tree): + if hasattr(tree, "items"): + return {k: to_dict(v) for k, v in tree.items()} + return tree + + model_arrays = to_dict(model_arrays) + checkpoint = to_dict(checkpoint) + + def _fuse_moe_weights(ckpt_tree, model_arrays_tree): + if not hasattr(ckpt_tree, "items") or not hasattr(model_arrays_tree, "items"): + return ckpt_tree + new_ckpt = {} + for k, v in ckpt_tree.items(): + if k in ("wi_0", "wi_1") and "wi" in model_arrays_tree: + continue + new_ckpt[k] = _fuse_moe_weights(v, model_arrays_tree.get(k, {})) + + if "wi" in model_arrays_tree and "wi_0" in ckpt_tree and "wi_1" in ckpt_tree: + wi_0 = ckpt_tree["wi_0"] + wi_1 = ckpt_tree["wi_1"] + new_ckpt["wi"] = np.concatenate([wi_0, wi_1], axis=-1) + + return new_ckpt + + checkpoint = _fuse_moe_weights(checkpoint, model_arrays) + # Release the raw restored buffers now that wi_0/wi_1 have been fused (if needed). + # This prevents the replicated intermediate copies from persisting until function return. + del restored + checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays) nnx.update(model, checkpoint) diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 416c08a2f7..386273ba10 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -20,6 +20,7 @@ from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp +import numpy as np from jax.sharding import Mesh from maxtext.configs import pyconfig from maxtext.common.common_types import Config, DType @@ -1064,5 +1065,288 @@ def test_get_all_to_all_params_unsharded_batch(self): ) +def make_moe(cfg, mesh): + return moe.RoutedMoE( + config=cfg, + num_experts=cfg.num_experts, + num_experts_per_tok=cfg.num_experts_per_tok, + mesh=mesh, + kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", "mlp"), + dtype=cfg.dtype, + rngs=nnx.Rngs(params=0), + ) + + +def copy_weights(src_model, dst_model): + """Copy wi_0, wi_1, wo, and gate weights from src to dst.""" + dst_model.wi_0 = src_model.wi_0 + dst_model.wi_1 = src_model.wi_1 + dst_model.wo = src_model.wo + dst_model.gate = src_model.gate + + +def copy_weights_prefused(src_model, dst_model): + """Copy weights from a split-weight model into a prefuse_moe_weights=True model. + + Concatenates src wi_0 and wi_1 along the last axis to produce the fused wi. 
+ """ + wi_fused = jnp.concatenate([src_model.wi_0[...], src_model.wi_1[...]], axis=-1) + dst_model.wi = nnx.Param(wi_fused) + dst_model.wo = src_model.wo + dst_model.gate = src_model.gate + + +# fused_moe_func requires num_tokens * topk % 16 == 0. +# B=1, S=16, topk=2 -> T*topk = 32, divisible by 16. +_B = 1 +_S = 16 + + +@pytest.mark.tpu_only +class FusedMoeTPUTest(unittest.TestCase): + """Tests for fused_moe_matmul (vllm_rpa path) in RoutedMoE.""" + + def setUp(self): + super().setUp() + self.rng = jax.random.PRNGKey(42) + + # Dense reference config (no vllm, einsum-based) + extra_args = get_decoupled_parallelism_overrides() + self.dense_cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name="fused_moe_dense_ref", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + sparse_matmul=False, + megablox=False, + ici_expert_parallelism=jax.device_count(), + log_config=False, + max_target_length=_S, + per_device_batch_size=_B, + **extra_args, + ) + dense_devices = maxtext_utils.create_device_mesh(self.dense_cfg) + self.dense_mesh = Mesh(dense_devices, self.dense_cfg.mesh_axes) + self.dense_model = make_moe(self.dense_cfg, self.dense_mesh) + + # vllm_rpa fused config + self.fused_cfg = pyconfig.initialize( + [None, get_test_config_path("inference/vllm.yml")], + run_name="fused_moe_vllm", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + ici_expert_parallelism=jax.device_count(), + log_config=False, + max_target_length=_S, + per_device_batch_size=_B, + ) + fused_devices = maxtext_utils.create_device_mesh(self.fused_cfg) + self.fused_mesh = Mesh(fused_devices, self.fused_cfg.mesh_axes) + self.fused_model = make_moe(self.fused_cfg, self.fused_mesh) + copy_weights(self.dense_model, self.fused_model) + + def _inputs(self): + return jax.random.normal(self.rng, (_B, _S, self.dense_cfg.base_emb_dim), dtype=jnp.bfloat16) + + def test_fused_vs_dense_softmax(self): + """fused_moe_matmul agrees with dense_matmul under softmax routing.""" + inputs = self._inputs() + + dense_out, _, _ = self.dense_model(inputs) + fused_out, lb_loss, bias_updates = self.fused_model(inputs) + + np.testing.assert_allclose( + np.array(dense_out, dtype=np.float32), + np.array(fused_out, dtype=np.float32), + rtol=1e-2, + atol=1e-2, + ) + self.assertIsNone(lb_loss) + self.assertIsNone(bias_updates) + + def test_fused_vs_sparse_softmax(self): + """fused_moe_matmul agrees with sparse_matmul (Megablox) under softmax routing.""" + extra_args = get_decoupled_parallelism_overrides() + sparse_cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name="fused_moe_sparse_ref", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + sparse_matmul=True, + megablox=True, + ici_expert_parallelism=jax.device_count(), + log_config=False, + max_target_length=_S, + per_device_batch_size=_B, + **extra_args, + ) + sparse_devices = maxtext_utils.create_device_mesh(sparse_cfg) + sparse_mesh = Mesh(sparse_devices, sparse_cfg.mesh_axes) + sparse_model = make_moe(sparse_cfg, sparse_mesh) + copy_weights(self.dense_model, sparse_model) + + inputs = self._inputs() + sparse_out, _, _ = sparse_model(inputs) + fused_out, lb_loss, bias_updates = self.fused_model(inputs) + + np.testing.assert_allclose( + np.array(sparse_out, dtype=np.float32), + np.array(fused_out, dtype=np.float32), + rtol=1e-2, + atol=1e-2, + ) + self.assertIsNone(lb_loss) + self.assertIsNone(bias_updates) + + def test_fused_output_shape_and_dtype(self): + """Output shape is (B, S, 
D), dtype matches cfg.dtype, and losses are None.""" + inputs = self._inputs() + fused_out, lb_loss, bias_updates = self.fused_model(inputs) + + expected_shape = (_B, _S, self.fused_cfg.base_emb_dim) + self.assertEqual(fused_out.shape, expected_shape) + self.assertEqual(fused_out.dtype, self.fused_cfg.dtype) + self.assertIsNone(lb_loss) + self.assertIsNone(bias_updates) + + def test_fused_vs_dense_renormalize(self): + """fused_moe_matmul agrees with dense_matmul when norm_topk_prob=True.""" + extra_args = get_decoupled_parallelism_overrides() + dense_renorm_cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name="fused_moe_dense_renorm", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + sparse_matmul=False, + megablox=False, + ici_expert_parallelism=jax.device_count(), + log_config=False, + norm_topk_prob=True, + max_target_length=_S, + per_device_batch_size=_B, + **extra_args, + ) + dense_renorm_devices = maxtext_utils.create_device_mesh(dense_renorm_cfg) + dense_renorm_mesh = Mesh(dense_renorm_devices, dense_renorm_cfg.mesh_axes) + dense_renorm_model = make_moe(dense_renorm_cfg, dense_renorm_mesh) + + fused_renorm_cfg = pyconfig.initialize( + [None, get_test_config_path("inference/vllm.yml")], + run_name="fused_moe_vllm_renorm", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + norm_topk_prob=True, + ici_expert_parallelism=jax.device_count(), + log_config=False, + max_target_length=_S, + per_device_batch_size=_B, + ) + fused_renorm_devices = maxtext_utils.create_device_mesh(fused_renorm_cfg) + fused_renorm_mesh = Mesh(fused_renorm_devices, fused_renorm_cfg.mesh_axes) + fused_renorm_model = make_moe(fused_renorm_cfg, fused_renorm_mesh) + copy_weights(dense_renorm_model, fused_renorm_model) + + inputs = self._inputs() + dense_out, _, _ = dense_renorm_model(inputs) + fused_out, lb_loss, bias_updates = fused_renorm_model(inputs) + + np.testing.assert_allclose( + np.array(dense_out, dtype=np.float32), + np.array(fused_out, dtype=np.float32), + rtol=1e-2, + atol=1e-2, + ) + self.assertIsNone(lb_loss) + self.assertIsNone(bias_updates) + + def test_prefused_vs_dense_softmax(self): + """prefuse_moe_weights=True agrees with dense_matmul under softmax routing.""" + prefused_cfg = pyconfig.initialize( + [None, get_test_config_path("inference/vllm.yml")], + run_name="fused_moe_vllm_prefused", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + ici_expert_parallelism=jax.device_count(), + log_config=False, + prefuse_moe_weights=True, + max_target_length=_S, + per_device_batch_size=_B, + ) + prefused_devices = maxtext_utils.create_device_mesh(prefused_cfg) + prefused_mesh = Mesh(prefused_devices, prefused_cfg.mesh_axes) + prefused_model = make_moe(prefused_cfg, prefused_mesh) + copy_weights_prefused(self.dense_model, prefused_model) + + inputs = self._inputs() + dense_out, _, _ = self.dense_model(inputs) + prefused_out, lb_loss, bias_updates = prefused_model(inputs) + + np.testing.assert_allclose( + np.array(dense_out, dtype=np.float32), + np.array(prefused_out, dtype=np.float32), + rtol=1e-2, + atol=1e-2, + ) + self.assertIsNone(lb_loss) + self.assertIsNone(bias_updates) + + def test_prefused_vs_sparse_softmax(self): + """prefuse_moe_weights=True agrees with sparse_matmul (Megablox) under softmax routing.""" + extra_args = get_decoupled_parallelism_overrides() + sparse_cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name="fused_moe_sparse_ref2", + enable_checkpointing=False, + 
model_name="mixtral-8x7b", + dtype="bfloat16", + sparse_matmul=True, + megablox=True, + ici_expert_parallelism=jax.device_count(), + max_target_length=_S, + per_device_batch_size=_B, + **extra_args, + ) + sparse_devices = maxtext_utils.create_device_mesh(sparse_cfg) + sparse_mesh = Mesh(sparse_devices, sparse_cfg.mesh_axes) + sparse_model = make_moe(sparse_cfg, sparse_mesh) + copy_weights(self.dense_model, sparse_model) + + prefused_cfg = pyconfig.initialize( + [None, get_test_config_path("inference/vllm.yml")], + run_name="fused_moe_vllm_prefused2", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + ici_expert_parallelism=jax.device_count(), + prefuse_moe_weights=True, + max_target_length=_S, + per_device_batch_size=_B, + ) + prefused_devices = maxtext_utils.create_device_mesh(prefused_cfg) + prefused_mesh = Mesh(prefused_devices, prefused_cfg.mesh_axes) + prefused_model = make_moe(prefused_cfg, prefused_mesh) + copy_weights_prefused(self.dense_model, prefused_model) + + inputs = self._inputs() + sparse_out, _, _ = sparse_model(inputs) + prefused_out, lb_loss, bias_updates = prefused_model(inputs) + + np.testing.assert_allclose( + np.array(sparse_out, dtype=np.float32), + np.array(prefused_out, dtype=np.float32), + rtol=1e-2, + atol=1e-2, + ) + self.assertIsNone(lb_loss) + self.assertIsNone(bias_updates) + + if __name__ == "__main__": unittest.main()