|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | """ |
15 | | -An example script to perform decoding using vLLM with a MaxText model. |
| 15 | +An example script to perform decoding with a MaxText model using vLLM, either through Tunix or through the MaxText vLLM integration.
16 | 16 |
|
17 | | -Example command: |
| 17 | +Example usage with Tunix: |
18 | 18 | python3 -m MaxText.vllm_decode MaxText/configs/base.yml \ |
19 | 19 | model_name=llama3.1-8b tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ |
20 | 20 | tokenizer_type=huggingface hf_access_token=<your_hf_token> \ |
21 | 21 | load_parameters_path=<your_checkpoint_path> \ |
22 | 22 | per_device_batch_size=1 run_name=vllm_decode_test max_target_length=64 \ |
23 | 23 | use_chat_template=False prompt="Suggest some famous landmarks in London." \ |
24 | | - decode_sampling_temperature=0.0 decode_sampling_nucleus_p=1.0 decode_sampling_top_k=0.0 |
| 24 | + decode_sampling_temperature=0.0 decode_sampling_nucleus_p=1.0 decode_sampling_top_k=0.0 \ |
| 25 | +    --use_tunix
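| | +  (The key=value arguments are MaxText pyconfig settings; --use_tunix itself is an absl flag.)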
| 26 | + |
| 27 | +Or without Tunix, using the MaxText vLLM integration:
| 28 | +  python3 -m MaxText.vllm_decode \
| 29 | +    --model_name qwen3-30b-a3b \
| 30 | +    --hf_model_name Qwen/Qwen3-30B-A3B \
| 31 | +    --hf_config_path src/MaxText/integration/vllm/maxtext_vllm_adapter \
| 32 | +    --load_parameters_path <your_checkpoint_path> \
| 33 | +    --ici_data_parallelism 1 \
| 34 | +    --ici_tensor_parallelism 4 \
| 35 | +    --ici_expert_parallelism 1 \
| 36 | +    --max_target_length 4096 \
| 37 | +    --max_prefill_length 512 \
| 38 | +    --gpu_memory_utilization 0.5 \
| 39 | +    --prompt "Suggest some famous landmarks in London."
25 | 40 | """ |
26 | 41 |
|
27 | 42 | import os |
28 | 43 | from typing import Any, Sequence |
29 | 44 |
|
30 | 45 | from absl import app |
| 46 | +from absl import flags |
31 | 47 | import jax |
32 | 48 | import transformers |
33 | 49 |
|
|
37 | 53 | from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter |
38 | 54 | from tunix.rl.rollout import base_rollout |
39 | 55 | from tunix.rl.rollout.vllm_rollout import VllmRollout |
| 56 | +from vllm import LLM |
| 57 | +from vllm.sampling_params import SamplingParams |
40 | 58 |
|
41 | 59 | os.environ["SKIP_JAX_PRECOMPILE"] = "1" |
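| | +# NEW_MODEL_DESIGN is assumed here to opt the vLLM TPU backend into its newer
| | +# JAX model-design path, which the MaxText integration relies on.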
| 60 | +os.environ["NEW_MODEL_DESIGN"] = "1" |
42 | 61 |
|
43 | 62 |
|
44 | | -def decode( |
| 63 | +# --- DEFINE FLAGS GLOBALLY --- |
| 64 | +FLAGS = flags.FLAGS |
| 65 | + |
| 66 | +# Parallelism |
| 67 | +flags.DEFINE_integer("ici_data_parallelism", 1, "Size of the data parallelism dimension.") |
| 68 | +flags.DEFINE_integer("ici_tensor_parallelism", 1, "Size of the non-expert tensor parallelism dimension.") |
| 69 | +flags.DEFINE_integer("ici_expert_parallelism", 1, "Size of the MoE expert parallelism dimension.") |
| 70 | + |
| 71 | +# Model |
| 72 | +flags.DEFINE_string("model_name", "qwen3-30b-a3b", "Model name for MaxText.") |
| 73 | +flags.DEFINE_string("hf_model_name", "Qwen/Qwen3-30B-A3B", "Path to the Hugging Face model.") |
| 74 | +flags.DEFINE_string("hf_config_path", None, "Path to the local Hugging Face model config.") |
| 75 | +flags.DEFINE_string("load_parameters_path", None, "Path to load model parameters from.") |
| 76 | +flags.DEFINE_bool("enable_expert_parallel", False, "Whether to enable expert parallelism.") |
| 77 | + |
| 78 | +# Length/Throughput |
| 79 | +flags.DEFINE_integer("max_target_length", 1024, "Maximum total context length (MCL).") |
| 80 | +flags.DEFINE_integer("max_prefill_length", 512, "Maximum prefill length.") |
| 81 | +flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.") |
| 82 | + |
| 83 | +# Decoding |
| 84 | +flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.") |
| 85 | +flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.") |
| 86 | +flags.DEFINE_integer("decode_sampling_temperature", 0, "Temperature for sampling.") |
| 87 | +flags.DEFINE_integer("decode_sampling_nucleus_p", 1, "Nucleus sampling probability.") |
| 88 | +flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.") |
| 89 | + |
| 90 | +# hf_config_path is required only for the vLLM path, so it is validated in
| 91 | +# main() rather than marked required here; the Tunix path does not need it.
| 92 | + |
| 93 | + |
| 94 | +def decode_with_vllm( |
| 95 | + model_name: str, |
| 96 | + hf_model_name: str, |
| 97 | + hf_config_path: str, |
| 98 | + load_parameters_path: str, |
| 99 | + ici_data_parallelism: int, |
| 100 | + ici_tensor_parallelism: int, |
| 101 | + ici_expert_parallelism: int, |
| 102 | + max_prefill_length: int, |
| 103 | + max_target_length: int, |
| 104 | + gpu_memory_utilization: float, |
| 105 | + enable_expert_parallel: bool, |
| 106 | + prompt: str, |
| 107 | + decode_sampling_temperature: float, |
| 108 | + decode_sampling_nucleus_p: float, |
| 109 | +    decode_sampling_top_k: int,
| 110 | +) -> None: |
| 111 | + """Decode using vLLM with a MaxText model implementation. |
| 112 | +
|
| 113 | + Args: |
| 114 | + model_name: Name of the model for MaxText. |
| 115 | +    hf_model_name: Hugging Face model name or local path.
| 116 | + hf_config_path: Path to the local Hugging Face model config. |
| 117 | + load_parameters_path: Path to load model parameters from. |
| 118 | + ici_data_parallelism: Size of the data parallelism dimension. |
| 119 | + ici_tensor_parallelism: Size of the non-expert tensor parallelism dimension. |
| 120 | + ici_expert_parallelism: Size of the MoE expert parallelism dimension. |
| 121 | + max_prefill_length: Maximum prefill length. |
| 122 | + max_target_length: Maximum total context length (MCL). |
| 123 | + gpu_memory_utilization: Fraction of GPU memory to be used for the model executor. |
| 124 | + enable_expert_parallel: Whether to enable expert parallelism. |
| 125 | + prompt: The prompt to decode. |
| 126 | + decode_sampling_temperature: Temperature for sampling. |
| 127 | +    decode_sampling_nucleus_p: Nucleus (top-p) sampling probability mass.
| 128 | +    decode_sampling_top_k: Number of highest-probability tokens considered in top-k sampling.
| 129 | + """ |
| 130 | + |
| 131 | + # Prepare vLLM Arguments |
| 132 | + vllm_args = {} |
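| | +  # additional_config is vLLM's pass-through for backend-specific settings;
| | +  # the MaxText backend reads its model and sharding configuration from it.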
| 133 | + vllm_args["additional_config"] = {} |
| 134 | + |
| 135 | + # Core vLLM Arguments |
| 136 | + vllm_args["model"] = hf_model_name |
| 137 | + vllm_args["max_model_len"] = max_target_length |
| 138 | + vllm_args["tensor_parallel_size"] = ici_tensor_parallelism |
| 139 | + vllm_args["data_parallel_size"] = ici_data_parallelism |
| 140 | + vllm_args["enable_expert_parallel"] = enable_expert_parallel |
| 141 | + vllm_args["hf_config_path"] = hf_config_path |
| 142 | + vllm_args["gpu_memory_utilization"] = gpu_memory_utilization |
| 143 | + |
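| | +  # With no checkpoint path, fall back to vLLM's "dummy" load format, which
| | +  # initializes random weights (useful for smoke-testing without real weights).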
| 144 | + if load_parameters_path is None: |
| 145 | + vllm_args["load_format"] = "dummy" |
| 146 | + |
| 147 | +  # Prepare MaxText and sharding configs (parallelism degrees are supplied per call)
| 148 | + vllm_args["additional_config"]["maxtext_config"] = { |
| 149 | + "model_name": model_name, |
| 150 | + "max_target_length": max_target_length, |
| 151 | + "weight_dtype": "bfloat16", |
| 152 | + "allow_split_physical_axes": True, |
| 153 | + "load_parameters_path": load_parameters_path, |
| 154 | + } |
| 155 | + |
| 156 | +  vllm_args["additional_config"]["sharding"] = {
| 157 | +      "sharding_strategy": {
| 158 | +          "tensor_parallelism": ici_tensor_parallelism,
| 159 | +          "data_parallelism": ici_data_parallelism,
| 160 | +      },
| 161 | +  }
| 162 | +
| 163 | +  # Shard experts only when expert parallelism is enabled.
| 164 | +  if enable_expert_parallel:
| 165 | +    vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = ici_expert_parallelism
| 166 | + |
| 167 | + # Initialize and Run LLM |
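| | +  # Reserve max_prefill_length tokens of the context window for the prompt;
| | +  # the remainder is the generation budget.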
| 168 | + max_tokens = max_target_length - max_prefill_length |
| 169 | + sampling_params = SamplingParams( |
| 170 | + temperature=decode_sampling_temperature, |
| 171 | + max_tokens=max_tokens, |
| 172 | + top_k=decode_sampling_top_k, |
| 173 | + top_p=decode_sampling_nucleus_p, |
| 174 | + ) |
| 175 | + |
| 176 | + print( |
| 177 | + f"Initializing LLM with DP={vllm_args['data_parallel_size']}, TP={vllm_args['tensor_parallel_size']} " |
| 178 | + f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..." |
| 179 | + ) |
| 180 | + llm = LLM(**vllm_args) |
| 181 | + |
| 182 | + print("Generating output...") |
| 183 | + outputs = llm.generate([prompt], sampling_params) |
| 184 | + |
| 185 | + # Print Outputs |
| 186 | +  for output in outputs:
| 187 | +    prompt_text = output.prompt
| 188 | +    generated_text = output.outputs[0].text
| 189 | +    print(f"Prompt: {prompt_text!r}, Generated text: {generated_text!r}")
| 190 | + |
| 191 | + |
| 192 | +def decode_with_tunix( |
45 | 193 | config: Config, |
46 | 194 | model: Any, |
47 | 195 | mesh: jax.sharding.Mesh, |
@@ -113,9 +261,28 @@ def main(argv: Sequence[str]) -> None: |
113 | 261 | os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" |
114 | 262 | ) |
115 | 263 |
|
116 | | - config = pyconfig.initialize(argv) |
117 | | - maxtext_model, mesh = model_creation_utils.create_nnx_model(config) |
118 | | - decode(config, model=maxtext_model, mesh=mesh) |
| 264 | + if FLAGS.use_tunix: |
| 265 | + config = pyconfig.initialize(argv) |
| 266 | + maxtext_model, mesh = model_creation_utils.create_nnx_model(config) |
| 267 | + decode_with_tunix(config, model=maxtext_model, mesh=mesh) |
| 268 | +  else:
| | +    if FLAGS.hf_config_path is None:
| | +      raise app.UsageError("--hf_config_path is required unless --use_tunix is set.")
| 269 | +    decode_with_vllm(
| 270 | + model_name=FLAGS.model_name, |
| 271 | + hf_model_name=FLAGS.hf_model_name, |
| 272 | + hf_config_path=FLAGS.hf_config_path, |
| 273 | + load_parameters_path=FLAGS.load_parameters_path, |
| 274 | + ici_data_parallelism=FLAGS.ici_data_parallelism, |
| 275 | + ici_tensor_parallelism=FLAGS.ici_tensor_parallelism, |
| 276 | + ici_expert_parallelism=FLAGS.ici_expert_parallelism, |
| 277 | + max_target_length=FLAGS.max_target_length, |
| 278 | + max_prefill_length=FLAGS.max_prefill_length, |
| 279 | + gpu_memory_utilization=FLAGS.gpu_memory_utilization, |
| 280 | + enable_expert_parallel=FLAGS.enable_expert_parallel, |
| 281 | + prompt=FLAGS.prompt, |
| 282 | + decode_sampling_temperature=FLAGS.decode_sampling_temperature, |
| 283 | + decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p, |
| 284 | + decode_sampling_top_k=FLAGS.decode_sampling_top_k, |
| 285 | + ) |
119 | 286 |
|
120 | 287 |
|
121 | 288 | if __name__ == "__main__": |
|