|
| 1 | +""" |
| 2 | +Optional bridge: run vLLM ``fused_experts`` on ATen views of InfiniCore tensors. |
| 3 | +
|
| 4 | +Requires InfiniCore built with ``--aten=y`` and a Python environment where vLLM |
| 5 | +(and Triton, for the default GPU path) are installed. |
| 6 | +""" |
| 7 | + |
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +import json |
| 11 | +import os |
| 12 | +from typing import TYPE_CHECKING |
| 13 | + |
| 14 | +from infinicore.lib import _infinicore |
| 15 | +from infinicore.tensor import from_torch, to_torch |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from infinicore.tensor import Tensor |
| 19 | + |
| 20 | + |
def _require_aten_bridge() -> None:
    if getattr(_infinicore, "_tensor_as_torch", None) is None:
        raise RuntimeError(
            "vllm_fused_moe_bridge requires InfiniCore with ATen enabled "
            "(rebuild with --aten=y)."
        )


def fused_experts_ic(
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    *,
    inplace: bool = False,
    apply_router_weight_on_input: bool = False,
):
    """
    Run vLLM fused MoE experts on InfiniCore tensors via ``to_torch`` / ``from_torch``.

    Weight layout matches vLLM ``FusedMoE`` / ``fused_experts``:
    ``w1`` shape ``[num_experts, 2 * intermediate_size, hidden_size]``,
    ``w2`` shape ``[num_experts, hidden_size, intermediate_size]``,
    last dimension contiguous (stride 1). ``hidden_states`` has shape
    ``[num_tokens, hidden_size]`` and must be contiguous.

    Returns a new InfiniCore tensor (aliases the vLLM output torch tensor via ``from_torch``).
    """
    _require_aten_bridge()

    try:
        import torch
        from vllm.model_executor.layers.fused_moe import MoEActivation, fused_experts
    except ImportError as e:
        raise RuntimeError(
            "vllm_fused_moe_bridge requires vLLM to be installed in this interpreter."
        ) from e

    h = to_torch(hidden_states)
    t_w1 = to_torch(w1)
    t_w2 = to_torch(w2)
    t_tw = to_torch(topk_weights)
    t_ids = to_torch(topk_ids)

    if not h.is_contiguous():
        h = h.contiguous()
    if t_w1.stride(-1) != 1:
        t_w1 = t_w1.contiguous()
    if t_w2.stride(-1) != 1:
        t_w2 = t_w2.contiguous()

    if torch.cuda.is_available():
        torch.cuda.current_stream().synchronize()

    out_t = fused_experts(
        h,
        t_w1,
        t_w2,
        t_tw,
        t_ids,
        inplace=inplace,
        activation=MoEActivation.SILU,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
    return from_torch(out_t)


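# Illustrative usage sketch (not part of the bridge API): build torch tensors in the
# layout documented above, wrap them with ``from_torch``, and call ``fused_experts_ic``.
# The dimensions are arbitrary example values; this assumes CUDA, vLLM, and an
# ATen-enabled InfiniCore build are available.
def _example_fused_experts_ic_call():  # pragma: no cover - illustration only
    import torch

    num_tokens, hidden, inter, num_experts, topk = 16, 256, 128, 8, 2
    device = torch.device("cuda")
    dtype = torch.bfloat16

    hidden_states = from_torch(torch.randn(num_tokens, hidden, device=device, dtype=dtype))
    w1 = from_torch(torch.randn(num_experts, 2 * inter, hidden, device=device, dtype=dtype))
    w2 = from_torch(torch.randn(num_experts, hidden, inter, device=device, dtype=dtype))
    topk_weights = from_torch(torch.randn(num_tokens, topk, device=device, dtype=dtype))
    topk_ids = from_torch(
        torch.randint(0, num_experts, (num_tokens, topk), device=device, dtype=torch.int32)
    )
    return fused_experts_ic(hidden_states, w1, w2, topk_weights, topk_ids)

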
def _load_hf_config_json(model_path: str) -> dict:
    path = os.path.join(os.path.expanduser(model_path), "config.json")
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _torch_dtype_from_hf(cfg: dict):
    import torch

    td = cfg.get("torch_dtype")
    if td is None:
        return torch.bfloat16
    if isinstance(td, str):
        name = td.replace("torch.", "", 1)
        return getattr(torch, name, torch.bfloat16)
    return torch.bfloat16


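# Example mappings for _torch_dtype_from_hf (hypothetical config values), following
# the fallback rules above:
#   {"torch_dtype": "float16"}        -> torch.float16
#   {"torch_dtype": "torch.float32"}  -> torch.float32
#   {} or an unrecognized string      -> torch.bfloat16 (default)

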
def _moe_dims_from_config(cfg: dict) -> tuple[int, int, int, int] | None:
    """
    Return (num_experts, intermediate, hidden, top_k) if the config looks like a MoE model, else None.
    """
    n_exp = cfg.get("n_routed_experts")
    if n_exp is None:
        n_exp = cfg.get("num_local_experts")
    if n_exp is None:
        return None

    inter = cfg.get("moe_intermediate_size")
    if inter is None:
        inter = cfg.get("intermediate_size")
    hidden = cfg.get("hidden_size")
    if inter is None or hidden is None:
        return None

    topk = cfg.get("num_experts_per_tok")
    if topk is None:
        topk = cfg.get("num_experts_per_token")
    if topk is None:
        topk = 1

    return (int(n_exp), int(inter), int(hidden), int(topk))


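# Illustrative only: with hypothetical MoE config values such as
#   {"n_routed_experts": 64, "moe_intermediate_size": 1408,
#    "hidden_size": 2048, "num_experts_per_tok": 6}
# _moe_dims_from_config returns (64, 1408, 2048, 6). A dense config without any
# expert-count key returns None, and the preflight helpers below become no-ops.

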
def _dtype_nbytes(torch_dtype) -> int:
    import torch

    return torch.tensor([], dtype=torch_dtype).element_size()


def _estimated_fused_moe_warmup_bytes(
    E: int, N: int, H: int, topk: int, num_tokens: int, torch_dtype
) -> int:
    """Rough peak device memory for dummy w1/w2 + activations (bf16/fp16 = 2 bytes per element)."""
    es = _dtype_nbytes(torch_dtype)
    w1 = E * (2 * N) * H * es
    w2 = E * H * N * es
    hs = num_tokens * H * es
    tw = num_tokens * topk * es
    # topk_ids are int32 (4 bytes each)
    tid = num_tokens * topk * 4
    return w1 + w2 + hs + tw + tid


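# Worked example for _estimated_fused_moe_warmup_bytes with hypothetical dims
# (E=64, N=1408, H=2048, topk=6, num_tokens=128, bf16 -> 2-byte elements):
#   w1 = 64 * 2816 * 2048 * 2 bytes ~= 738 MB, w2 = 64 * 2048 * 1408 * 2 bytes ~= 369 MB,
#   activations and topk buffers are well under 1 MB, so the total is ~1.1 GB and the
#   dummy expert weights dominate the warmup footprint.

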
def verify_vllm_fused_moe_config_for_checkpoint(model_path: str, device_index: int = 0) -> None:
    """
    Log whether vLLM's bundled fused_moe JSON exists for this model's (E, N) and GPU name.

    Mirrors vLLM's ``get_config_file_name`` / config search paths so messages match runtime.
    """
    dims = _moe_dims_from_config(_load_hf_config_json(model_path))
    if dims is None:
        return

    import torch

    E, N, _, _ = dims
    if not torch.cuda.is_available():
        print(
            f"[vllm_fused_moe] preflight: MoE E={E} N={N} (CUDA unavailable; skip config check)",
            flush=True,
        )
        return

    torch.cuda.set_device(device_index)
    try:
        import vllm.envs as vllm_envs
        import vllm.model_executor.layers.fused_moe.fused_moe as vllm_fused_moe_mod

        get_config_file_name = vllm_fused_moe_mod.get_config_file_name
    except ImportError:
        print(
            "[vllm_fused_moe] preflight: vLLM not importable; skip fused_moe config check",
            flush=True,
        )
        return

    json_name = get_config_file_name(E, N, None, None)
    fused_moe_dir = os.path.dirname(vllm_fused_moe_mod.__file__)
    default_path = os.path.join(fused_moe_dir, "configs", json_name)
    paths = []
    if vllm_envs.VLLM_TUNED_CONFIG_FOLDER:
        paths.append(os.path.join(vllm_envs.VLLM_TUNED_CONFIG_FOLDER, json_name))
    paths.append(default_path)

    found = next((p for p in paths if os.path.isfile(p)), None)
    if found:
        print(
            f"[vllm_fused_moe] preflight: using tuned config {found}",
            flush=True,
        )
    else:
        print(
            "[vllm_fused_moe] preflight: no tuned fused_moe JSON for this GPU "
            f"(E={E}, N={N}); vLLM will use defaults (see also {default_path})",
            flush=True,
        )


def warmup_vllm_fused_moe_from_checkpoint(model_path: str, device_index: int = 0) -> None:
    """
    Run one ``fused_experts`` call with shapes from ``config.json`` so Triton JIT and config
    resolution happen before TTFT timers (e.g. in ``jiuge.py``).

    Set ``INFINILM_SKIP_VLLM_FUSED_MOE_PREFLIGHT=1`` to disable.

    Optional ``INFINILM_VLLM_FUSED_WARMUP_MAX_BYTES`` (integer): skip allocating dummy expert
    weights when the estimate exceeds this budget (TTFT may then include one-time Triton JIT).
    """
    if os.environ.get("INFINILM_SKIP_VLLM_FUSED_MOE_PREFLIGHT") == "1":
        return

    cfg = _load_hf_config_json(model_path)
    dims = _moe_dims_from_config(cfg)
    if dims is None:
        return

    try:
        import torch
        from vllm.model_executor.layers.fused_moe import MoEActivation, fused_experts
    except ImportError:
        return

    E, N, H, topk = dims
    if not torch.cuda.is_available():
        return

    torch_dtype = _torch_dtype_from_hf(cfg)
    torch.cuda.set_device(device_index)
    device = torch.device("cuda", device_index)

    # Enough tokens to exercise typical block sizes without a large activation batch.
    num_tokens = min(128, max(32, topk * 8))
    est = _estimated_fused_moe_warmup_bytes(E, N, H, topk, num_tokens, torch_dtype)
    max_b = os.environ.get("INFINILM_VLLM_FUSED_WARMUP_MAX_BYTES")
    if max_b is not None:
        try:
            if est > int(max_b):
                print(
                    f"[vllm_fused_moe] preflight: skip fused_experts warmup "
                    f"(estimated {est} bytes > INFINILM_VLLM_FUSED_WARMUP_MAX_BYTES={max_b})",
                    flush=True,
                )
                return
        except ValueError:
            pass

    hidden_states = torch.randn(num_tokens, H, device=device, dtype=torch_dtype)
    w1 = torch.randn(E, 2 * N, H, device=device, dtype=torch_dtype)
    w2 = torch.randn(E, H, N, device=device, dtype=torch_dtype)
    topk_weights = torch.randn(num_tokens, topk, device=device, dtype=torch_dtype)
    topk_ids = torch.randint(0, E, (num_tokens, topk), device=device, dtype=torch.int32)

    torch.cuda.current_stream().synchronize()

    try:
        _ = fused_experts(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            inplace=False,
            activation=MoEActivation.SILU,
            apply_router_weight_on_input=False,
        )
        torch.cuda.synchronize()
    except torch.cuda.OutOfMemoryError as e:
        print(
            f"[vllm_fused_moe] preflight: fused_experts warmup OOM (skip): {e}",
            flush=True,
        )


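# Configuration sketch (hypothetical values): both environment variables read above can
# also be set from Python before the preflight runs, e.g. to cap the dummy-weight
# allocation at roughly 2 GiB or to skip the warmup entirely.
#
#   os.environ["INFINILM_VLLM_FUSED_WARMUP_MAX_BYTES"] = str(2 * 1024**3)
#   os.environ["INFINILM_SKIP_VLLM_FUSED_MOE_PREFLIGHT"] = "1"

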
def preflight_vllm_fused_moe_for_ttft(model_path: str, device_index: int = 0) -> None:
    """Verify fused_moe JSON presence and warm up kernels before the timed generate."""
    verify_vllm_fused_moe_config_for_checkpoint(model_path, device_index=device_index)
    warmup_vllm_fused_moe_from_checkpoint(model_path, device_index=device_index)
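

# Call-site sketch (the checkpoint path below is illustrative): run the preflight once
# per process before the first timed generate, so Triton JIT and fused_moe config
# resolution do not count against TTFT.
#
#   preflight_vllm_fused_moe_for_ttft("/path/to/moe-checkpoint", device_index=0)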