
Commit b767f44

Adding MaxText on vLLM decoding.
Removing EP from the TP value; adding DP to the sharding strategy.
1 parent 7991534 commit b767f44

1 file changed

Lines changed: 174 additions & 7 deletions

src/MaxText/vllm_decode.py

@@ -12,22 +12,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-An example script to perform decoding using vLLM with a MaxText model.
+An example script to perform decoding using vLLM, either via Tunix or via the MaxText-on-vLLM integration.
 
-Example command:
+Example usage with Tunix:
 python3 -m MaxText.vllm_decode MaxText/configs/base.yml \
 model_name=llama3.1-8b tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
 tokenizer_type=huggingface hf_access_token=<your_hf_token> \
 load_parameters_path=<your_checkpoint_path> \
 per_device_batch_size=1 run_name=vllm_decode_test max_target_length=64 \
 use_chat_template=False prompt="Suggest some famous landmarks in London." \
-decode_sampling_temperature=0.0 decode_sampling_nucleus_p=1.0 decode_sampling_top_k=0.0
+decode_sampling_temperature=0.0 decode_sampling_nucleus_p=1.0 decode_sampling_top_k=0.0 \
+--use_tunix
+
+Or, without Tunix, using the MaxText vLLM integration:
+python3 -m MaxText.vllm_decode \
+--model_name qwen3-30b-a3b \
+--hf_model_name Qwen/Qwen3-30B-A3B \
+--hf_config_path src/MaxText/integration/vllm/maxtext_vllm_adapter \
+--load_parameters_path <your_checkpoint_path> \
+--ici_data_parallelism 1 \
+--ici_tensor_parallelism 4 \
+--ici_expert_parallelism 1 \
+--max_model_len 4096 \
+--max_num_batched_tokens 262144 \
+--gpu_memory_utilization 0.5 \
+--prompt "Suggest some famous landmarks in London."
 """
 
 import os
 from typing import Any, Sequence
 
 from absl import app
+from absl import flags
 import jax
 import transformers
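Note on the mixed command line in Tunix mode: absl parses only the flags defined in this script (such as --use_tunix), while the remaining positional arguments (the config file path and key=value overrides) are handed to main(argv), where pyconfig.initialize(argv) consumes them. A minimal sketch of that split, illustrative only and not part of this commit:

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")


def main(argv):
  # argv[0] is the program name; everything absl did not recognize as a flag
  # (e.g. "MaxText/configs/base.yml" and "key=value" overrides) stays in argv
  # and can be forwarded to pyconfig.initialize(argv).
  print("use_tunix:", FLAGS.use_tunix)
  print("passthrough argv:", argv[1:])


if __name__ == "__main__":
  app.run(main)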

@@ -37,11 +53,143 @@
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from tunix.rl.rollout import base_rollout
 from tunix.rl.rollout.vllm_rollout import VllmRollout
+from vllm import LLM
+from vllm.sampling_params import SamplingParams
 
 os.environ["SKIP_JAX_PRECOMPILE"] = "1"
+os.environ["NEW_MODEL_DESIGN"] = "1"
 
 
-def decode(
+# --- DEFINE FLAGS GLOBALLY ---
+FLAGS = flags.FLAGS
+
+# Parallelism
+flags.DEFINE_integer("ici_data_parallelism", 1, "Size of the data parallelism dimension.")
+flags.DEFINE_integer("ici_tensor_parallelism", 1, "Size of the non-expert tensor parallelism dimension.")
+flags.DEFINE_integer("ici_expert_parallelism", 1, "Size of the MoE expert parallelism dimension.")
+
+# Model
+flags.DEFINE_string("model_name", "qwen3-30b-a3b", "Model name for MaxText.")
+flags.DEFINE_string("hf_model_name", "Qwen/Qwen3-30B-A3B", "Hugging Face model name or path.")
+flags.DEFINE_string("hf_config_path", None, "Path to the local Hugging Face model config.")
+flags.DEFINE_string("load_parameters_path", None, "Path to load model parameters from.")
+flags.DEFINE_bool("enable_expert_parallel", False, "Whether to enable expert parallelism.")
+
+# Length/Throughput
+flags.DEFINE_integer("max_target_length", 1024, "Maximum total context length (MCL).")
+flags.DEFINE_integer("max_prefill_length", 512, "Maximum prefill length.")
+flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.")
+
+# Decoding
+flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
+flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
+flags.DEFINE_float("decode_sampling_temperature", 0.0, "Temperature for sampling.")
+flags.DEFINE_float("decode_sampling_nucleus_p", 1.0, "Nucleus (top-p) sampling probability.")
+flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling cutoff.")
+
+# Mark required flags
+flags.mark_flag_as_required("hf_config_path")
+
+
+def decode_with_vllm(
+    model_name: str,
+    hf_model_name: str,
+    hf_config_path: str,
+    load_parameters_path: str,
+    ici_data_parallelism: int,
+    ici_tensor_parallelism: int,
+    ici_expert_parallelism: int,
+    max_prefill_length: int,
+    max_target_length: int,
+    gpu_memory_utilization: float,
+    enable_expert_parallel: bool,
+    prompt: str,
+    decode_sampling_temperature: float,
+    decode_sampling_nucleus_p: float,
+    decode_sampling_top_k: int,
+) -> None:
+  """Decode using vLLM with a MaxText model implementation.
+
+  Args:
+    model_name: Name of the model for MaxText.
+    hf_model_name: Hugging Face model name or path.
+    hf_config_path: Path to the local Hugging Face model config.
+    load_parameters_path: Path to load model parameters from.
+    ici_data_parallelism: Size of the data parallelism dimension.
+    ici_tensor_parallelism: Size of the non-expert tensor parallelism dimension.
+    ici_expert_parallelism: Size of the MoE expert parallelism dimension.
+    max_prefill_length: Maximum prefill length.
+    max_target_length: Maximum total context length (MCL).
+    gpu_memory_utilization: Fraction of GPU memory to be used for the model executor.
+    enable_expert_parallel: Whether to enable expert parallelism.
+    prompt: The prompt to decode.
+    decode_sampling_temperature: Temperature for sampling.
+    decode_sampling_nucleus_p: Nucleus (top-p) sampling probability.
+    decode_sampling_top_k: Top-k sampling cutoff.
+  """
+
+  # Prepare vLLM Arguments
+  vllm_args = {}
+  vllm_args["additional_config"] = {}
+
+  # Core vLLM Arguments
+  vllm_args["model"] = hf_model_name
+  vllm_args["max_model_len"] = max_target_length
+  vllm_args["tensor_parallel_size"] = ici_tensor_parallelism
+  vllm_args["data_parallel_size"] = ici_data_parallelism
+  vllm_args["enable_expert_parallel"] = enable_expert_parallel
+  vllm_args["hf_config_path"] = hf_config_path
+  vllm_args["gpu_memory_utilization"] = gpu_memory_utilization
+
+  if load_parameters_path is None:
+    vllm_args["load_format"] = "dummy"
+
+  # Prepare MaxText and sharding configs (Parallelism is dynamic)
+  vllm_args["additional_config"]["maxtext_config"] = {
+      "model_name": model_name,
+      "max_target_length": max_target_length,
+      "weight_dtype": "bfloat16",
+      "allow_split_physical_axes": True,
+      "load_parameters_path": load_parameters_path,
+  }
+
+  vllm_args["additional_config"]["sharding"] = {
+      "sharding_strategy": {
+          "tensor_parallelism": ici_tensor_parallelism,
+          "expert_parallelism": ici_expert_parallelism,
+          "data_parallelism": ici_data_parallelism,
+      },
+  }
+
+  if enable_expert_parallel:
+    vllm_args["additional_config"]["sharding"]["sharding_strategy"].update({"expert_parallelism": ici_expert_parallelism})
+
+  # Initialize and Run LLM
+  max_tokens = max_target_length - max_prefill_length
+  sampling_params = SamplingParams(
+      temperature=decode_sampling_temperature,
+      max_tokens=max_tokens,
+      top_k=decode_sampling_top_k,
+      top_p=decode_sampling_nucleus_p,
+  )
+
+  print(
+      f"Initializing LLM with DP={vllm_args['data_parallel_size']}, TP={vllm_args['tensor_parallel_size']} "
+      f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
+  )
+  llm = LLM(**vllm_args)
+
+  print("Generating output...")
+  outputs = llm.generate([prompt], sampling_params)
+
+  # Print Outputs
+  for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+def decode_with_tunix(
     config: Config,
     model: Any,
     mesh: jax.sharding.Mesh,
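For concreteness, here is roughly what decode_with_vllm assembles and passes to LLM(**vllm_args) for an invocation like the docstring example above; the concrete values (TP=4, a 4096-token context, a supplied checkpoint path) are illustrative assumptions, not output captured from the commit:

# Illustrative expansion of vllm_args (values assumed for a TP=4, DP=1, EP=1 run).
vllm_args = {
    "model": "Qwen/Qwen3-30B-A3B",
    "max_model_len": 4096,
    "tensor_parallel_size": 4,
    "data_parallel_size": 1,
    "enable_expert_parallel": False,
    "hf_config_path": "src/MaxText/integration/vllm/maxtext_vllm_adapter",
    "gpu_memory_utilization": 0.5,
    "additional_config": {
        "maxtext_config": {
            "model_name": "qwen3-30b-a3b",
            "max_target_length": 4096,
            "weight_dtype": "bfloat16",
            "allow_split_physical_axes": True,
            "load_parameters_path": "<your_checkpoint_path>",
        },
        "sharding": {
            "sharding_strategy": {
                "tensor_parallelism": 4,
                "expert_parallelism": 1,
                "data_parallelism": 1,
            },
        },
    },
}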
@@ -113,9 +261,28 @@ def main(argv: Sequence[str]) -> None:
       os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
   )
 
-  config = pyconfig.initialize(argv)
-  maxtext_model, mesh = model_creation_utils.create_nnx_model(config)
-  decode(config, model=maxtext_model, mesh=mesh)
+  if FLAGS.use_tunix:
+    config = pyconfig.initialize(argv)
+    maxtext_model, mesh = model_creation_utils.create_nnx_model(config)
+    decode_with_tunix(config, model=maxtext_model, mesh=mesh)
+  else:
+    decode_with_vllm(
+        model_name=FLAGS.model_name,
+        hf_model_name=FLAGS.hf_model_name,
+        hf_config_path=FLAGS.hf_config_path,
+        load_parameters_path=FLAGS.load_parameters_path,
+        ici_data_parallelism=FLAGS.ici_data_parallelism,
+        ici_tensor_parallelism=FLAGS.ici_tensor_parallelism,
+        ici_expert_parallelism=FLAGS.ici_expert_parallelism,
+        max_target_length=FLAGS.max_target_length,
+        max_prefill_length=FLAGS.max_prefill_length,
+        gpu_memory_utilization=FLAGS.gpu_memory_utilization,
+        enable_expert_parallel=FLAGS.enable_expert_parallel,
+        prompt=FLAGS.prompt,
+        decode_sampling_temperature=FLAGS.decode_sampling_temperature,
+        decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
+        decode_sampling_top_k=FLAGS.decode_sampling_top_k,
+    )
 
 
 if __name__ == "__main__":
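A note on trying the non-Tunix path: because load_parameters_path defaults to None, decode_with_vllm falls back to load_format="dummy", so the pipeline can be smoke-tested without a real checkpoint (vLLM's dummy load format fills in random weights, so the generated text is meaningless). A hypothetical direct call, with every value assumed rather than taken from the commit:

from MaxText.vllm_decode import decode_with_vllm  # assumed import path

decode_with_vllm(
    model_name="qwen3-30b-a3b",
    hf_model_name="Qwen/Qwen3-30B-A3B",
    hf_config_path="src/MaxText/integration/vllm/maxtext_vllm_adapter",
    load_parameters_path=None,  # triggers load_format="dummy" (random weights)
    ici_data_parallelism=1,
    ici_tensor_parallelism=4,
    ici_expert_parallelism=1,
    max_prefill_length=512,
    max_target_length=1024,  # max_tokens = 1024 - 512 = 512
    gpu_memory_utilization=0.5,
    enable_expert_parallel=False,
    prompt="Suggest some famous landmarks in London.",
    decode_sampling_temperature=0.0,
    decode_sampling_nucleus_p=1.0,
    decode_sampling_top_k=1,
)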
