
Commit 5bf2e94

fused stub validated

Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent ef4c150 commit 5bf2e94

6 files changed: 485 additions & 154 deletions


pyproject.toml

Lines changed: 5 additions & 0 deletions

@@ -9,6 +9,11 @@ description = "InfiniCore 是一个跨平台统一编程工具集,为不同芯
 readme = "README.md"
 dependencies = []
 requires-python = ">=3.8"
+
+[project.optional-dependencies]
+# Same interpreter as InfiniCore built with ``--aten=y`` (see vllm_kernel_reuse_evaluation.md).
+vllm = ["vllm==0.19.0"]
+
 classifiers = [
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: MIT License",

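The extra only helps if the interpreter also sees an ATen-enabled InfiniCore build (the ``--aten=y`` note in the comment above). A minimal sanity check, offered as a sketch rather than part of this commit: the ``_tensor_as_torch`` probe is borrowed from the bridge module added below, and the extra itself would typically be installed with something like ``pip install ".[vllm]"`` from the repository root.

import importlib.metadata

from infinicore.lib import _infinicore

# Confirms the pinned vLLM from the new extra is importable in this interpreter
# and that the ATen bridge symbol used by the new module below is present.
print("vllm:", importlib.metadata.version("vllm"))
print("ATen bridge:", getattr(_infinicore, "_tensor_as_torch", None) is not None)
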
python/infinicore/_preload.py

Lines changed: 40 additions & 0 deletions

@@ -1,4 +1,5 @@
 import ctypes
+import glob
 import os
 from typing import Iterable, List

@@ -96,13 +97,52 @@ def preload_device(device_type: str) -> None:
     # etc.


+def preload_flash_attn_for_cpp_api() -> None:
+    """
+    Best-effort load of flash-attn's CUDA extension with RTLD_GLOBAL.
+
+    ``libinfinicore_cpp_api.so`` (aten builds) may reference flash-attn symbols;
+    loading via ``LD_PRELOAD`` breaks unrelated subprocesses (e.g. vLLM/Triton).
+    """
+    if os.environ.get("INFINICORE_DISABLE_FLASH_ATTN_RTLD_GLOBAL", "") == "1":
+        return
+    if os.environ.get("INFINILM_DISABLE_FLASH_ATTN_RTLD_GLOBAL", "") == "1":
+        return
+
+    candidates: List[str] = []
+    try:
+        import flash_attn
+
+        base = os.path.dirname(flash_attn.__file__)
+        parent = os.path.dirname(base)
+        candidates.extend(glob.glob(os.path.join(parent, "flash_attn_2_cuda*.so")))
+        candidates.extend(glob.glob(os.path.join(base, "flash_attn_2_cuda*.so")))
+    except ImportError:
+        pass
+
+    # Typical wheel layout (image / CI); last resort after package discovery
+    candidates.append(
+        "/usr/local/lib/python3.12/dist-packages/flash_attn_2_cuda.cpython-312-x86_64-linux-gnu.so"
+    )
+
+    for fa in candidates:
+        if fa and os.path.isfile(fa):
+            try:
+                ctypes.CDLL(fa, mode=ctypes.RTLD_GLOBAL)
+                return
+            except OSError:
+                continue
+
+
 def preload() -> None:
     """
     Universal preload function that loops through device types and preloads when required.

     This function detects available device types and preloads their runtime libraries
     if the environment indicates they are needed.
     """
+    preload_flash_attn_for_cpp_api()
+
     # Device types that may require preload
     device_types = [
         "METAX",  # HPCC/METAX
Lines changed: 285 additions & 0 deletions

@@ -0,0 +1,285 @@
+"""
+Optional bridge: run vLLM ``fused_experts`` on ATen views of InfiniCore tensors.
+
+Requires InfiniCore built with ``--aten=y`` and a Python environment where vLLM
+(and Triton, for the default GPU path) are installed.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import TYPE_CHECKING
+
+from infinicore.lib import _infinicore
+from infinicore.tensor import from_torch, to_torch
+
+if TYPE_CHECKING:
+    from infinicore.tensor import Tensor
+
+
+def _require_aten_bridge() -> None:
+    if getattr(_infinicore, "_tensor_as_torch", None) is None:
+        raise RuntimeError(
+            "vllm_fused_moe_bridge requires InfiniCore with ATen enabled "
+            "(rebuild with --aten=y)."
+        )
+
+
+def fused_experts_ic(
+    hidden_states: Tensor,
+    w1: Tensor,
+    w2: Tensor,
+    topk_weights: Tensor,
+    topk_ids: Tensor,
+    *,
+    inplace: bool = False,
+    apply_router_weight_on_input: bool = False,
+):
+    """
+    Run vLLM fused MoE experts on InfiniCore tensors via ``to_torch`` / ``from_torch``.
+
+    Weight layout matches vLLM ``FusedMoE`` / ``fused_experts``:
+    ``w1`` shape ``[num_experts, 2 * intermediate_size, hidden_size]``,
+    ``w2`` shape ``[num_experts, hidden_size, intermediate_size]``,
+    last dimension contiguous (stride 1). ``hidden_states`` shape ``[num_tokens, hidden_size]``, contiguous.
+
+    Returns a new InfiniCore tensor (aliases the vLLM output torch tensor via ``from_torch``).
+    """
+    _require_aten_bridge()
+
+    try:
+        import torch
+        from vllm.model_executor.layers.fused_moe import MoEActivation, fused_experts
+    except ImportError as e:
+        raise RuntimeError(
+            "vllm_fused_moe_bridge requires vLLM to be installed in this interpreter."
+        ) from e
+
+    h = to_torch(hidden_states)
+    t_w1 = to_torch(w1)
+    t_w2 = to_torch(w2)
+    t_tw = to_torch(topk_weights)
+    t_ids = to_torch(topk_ids)
+
+    if not h.is_contiguous():
+        h = h.contiguous()
+    if t_w1.stride(-1) != 1:
+        t_w1 = t_w1.contiguous()
+    if t_w2.stride(-1) != 1:
+        t_w2 = t_w2.contiguous()
+
+    if torch.cuda.is_available():
+        torch.cuda.current_stream().synchronize()
+
+    out_t = fused_experts(
+        h,
+        t_w1,
+        t_w2,
+        t_tw,
+        t_ids,
+        inplace=inplace,
+        activation=MoEActivation.SILU,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+    )
+    return from_torch(out_t)
+
+
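As a usage sketch (not part of the commit): the import path below is hypothetical, since the header above does not show where the new module lives in the package, and the inputs are fabricated through ``from_torch`` only to keep the example self-contained; real callers would already hold InfiniCore tensors.

import torch

from infinicore.tensor import from_torch
from infinicore.vllm_fused_moe_bridge import fused_experts_ic  # hypothetical module path

E, N, H, topk, T = 8, 1024, 2048, 2, 16  # small illustrative MoE sizes
dev, dt = torch.device("cuda", 0), torch.bfloat16

hidden = from_torch(torch.randn(T, H, device=dev, dtype=dt))
w1 = from_torch(torch.randn(E, 2 * N, H, device=dev, dtype=dt))   # [E, 2N, H]
w2 = from_torch(torch.randn(E, H, N, device=dev, dtype=dt))       # [E, H, N]
weights = from_torch(torch.randn(T, topk, device=dev, dtype=dt))
ids = from_torch(torch.randint(0, E, (T, topk), device=dev, dtype=torch.int32))

out = fused_experts_ic(hidden, w1, w2, weights, ids)  # InfiniCore tensor wrapping vLLM's output
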
+def _load_hf_config_json(model_path: str) -> dict:
+    path = os.path.join(os.path.expanduser(model_path), "config.json")
+    with open(path, encoding="utf-8") as f:
+        return json.load(f)
+
+
+def _torch_dtype_from_hf(cfg: dict):
+    import torch
+
+    td = cfg.get("torch_dtype")
+    if td is None:
+        return torch.bfloat16
+    if isinstance(td, str):
+        name = td.replace("torch.", "", 1)
+        return getattr(torch, name, torch.bfloat16)
+    return torch.bfloat16
+
+
+def _moe_dims_from_config(cfg: dict) -> tuple[int, int, int, int] | None:
+    """
+    Return (num_experts, intermediate, hidden, top_k) if config looks like a MoE model, else None.
+    """
+    n_exp = cfg.get("n_routed_experts")
+    if n_exp is None:
+        n_exp = cfg.get("num_local_experts")
+    if n_exp is None:
+        return None
+
+    inter = cfg.get("moe_intermediate_size")
+    if inter is None:
+        inter = cfg.get("intermediate_size")
+    hidden = cfg.get("hidden_size")
+    if inter is None or hidden is None:
+        return None
+
+    topk = cfg.get("num_experts_per_tok")
+    if topk is None:
+        topk = cfg.get("num_experts_per_token")
+    if topk is None:
+        topk = 1
+
+    return (int(n_exp), int(inter), int(hidden), int(topk))
+
+
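To make the key fallbacks concrete, a small sketch with made-up ``config.json`` fragments (values are illustrative only, not taken from any particular checkpoint); it assumes ``_moe_dims_from_config`` from this module is in scope.

# DeepSeek-style keys take priority; Mixtral-style keys are the fallback.
deepseek_style = {
    "n_routed_experts": 64,
    "moe_intermediate_size": 1408,
    "hidden_size": 2048,
    "num_experts_per_tok": 6,
}
mixtral_style = {
    "num_local_experts": 8,
    "intermediate_size": 14336,
    "hidden_size": 4096,
    "num_experts_per_tok": 2,
}
assert _moe_dims_from_config(deepseek_style) == (64, 1408, 2048, 6)
assert _moe_dims_from_config(mixtral_style) == (8, 14336, 4096, 2)
assert _moe_dims_from_config({"hidden_size": 4096}) is None  # dense model: no expert count
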
+def _dtype_nbytes(torch_dtype) -> int:
+    import torch
+
+    return torch.tensor([], dtype=torch_dtype).element_size()
+
+
+def _estimated_fused_moe_warmup_bytes(
+    E: int, N: int, H: int, topk: int, num_tokens: int, torch_dtype
+) -> int:
+    """Rough peak device memory for dummy w1/w2 + activations (bf16/fp16 = 2 bytes)."""
+    es = _dtype_nbytes(torch_dtype)
+    w1 = E * (2 * N) * H * es
+    w2 = E * H * N * es
+    hs = num_tokens * H * es
+    tw = num_tokens * topk * es
+    # topk_ids int32
+    tid = num_tokens * topk * 4
+    return w1 + w2 + hs + tw + tid
+
+
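As a worked example of the estimate (a sketch using the illustrative DeepSeek-style sizes from the previous note, with bf16 at 2 bytes per element):

E, N, H, topk = 64, 1408, 2048, 6
num_tokens = min(128, max(32, topk * 8))   # 48, matching the warmup below
es = 2                                     # bf16 element size
w1 = E * (2 * N) * H * es                  # 738,197,504 bytes (~704 MiB)
w2 = E * H * N * es                        # 369,098,752 bytes (~352 MiB)
act = num_tokens * H * es + num_tokens * topk * es + num_tokens * topk * 4
print(w1 + w2 + act)                       # 1,107,494,592 bytes (~1.03 GiB)

A budget such as ``INFINILM_VLLM_FUSED_WARMUP_MAX_BYTES=500000000`` would therefore skip this warmup (see ``warmup_vllm_fused_moe_from_checkpoint`` below).
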
+def verify_vllm_fused_moe_config_for_checkpoint(model_path: str, device_index: int = 0) -> None:
+    """
+    Log whether vLLM's bundled fused_moe JSON exists for this model's (E, N) and GPU name.
+
+    Mirrors vLLM's ``get_config_file_name`` / config search paths so messages match runtime.
+    """
+    dims = _moe_dims_from_config(_load_hf_config_json(model_path))
+    if dims is None:
+        return
+
+    import torch
+
+    E, N, _, _ = dims
+    if not torch.cuda.is_available():
+        print(
+            f"[vllm_fused_moe] preflight: MoE E={E} N={N} (CUDA unavailable; skip config check)",
+            flush=True,
+        )
+        return
+
+    torch.cuda.set_device(device_index)
+    try:
+        import vllm.envs as vllm_envs
+        import vllm.model_executor.layers.fused_moe.fused_moe as vllm_fused_moe_mod
+
+        get_config_file_name = vllm_fused_moe_mod.get_config_file_name
+    except ImportError:
+        print(
+            "[vllm_fused_moe] preflight: vLLM not importable; skip fused_moe config check",
+            flush=True,
+        )
+        return
+
+    json_name = get_config_file_name(E, N, None, None)
+    fused_moe_dir = os.path.dirname(vllm_fused_moe_mod.__file__)
+    default_path = os.path.join(fused_moe_dir, "configs", json_name)
+    paths = []
+    if vllm_envs.VLLM_TUNED_CONFIG_FOLDER:
+        paths.append(os.path.join(vllm_envs.VLLM_TUNED_CONFIG_FOLDER, json_name))
+    paths.append(default_path)
+
+    found = next((p for p in paths if os.path.isfile(p)), None)
+    if found:
+        print(
+            f"[vllm_fused_moe] preflight: using tuned config {found}",
+            flush=True,
+        )
+    else:
+        print(
+            "[vllm_fused_moe] preflight: no tuned fused_moe JSON for this "
+            f"(E={E}, N={N}); vLLM will use defaults (see also {default_path})",
+            flush=True,
+        )
+
+
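``VLLM_TUNED_CONFIG_FOLDER`` is read through ``vllm.envs`` above; assuming vLLM exposes it under an environment variable of the same name (an assumption, check your vLLM version) and that it is set before vLLM reads it, an externally tuned config directory can be exercised like this:

import os

os.environ["VLLM_TUNED_CONFIG_FOLDER"] = "/opt/tuned_moe_configs"  # hypothetical path

# Prints either "using tuned config <path>" or the "no tuned fused_moe JSON" fallback.
verify_vllm_fused_moe_config_for_checkpoint("/models/my-moe-checkpoint", device_index=0)
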
+def warmup_vllm_fused_moe_from_checkpoint(model_path: str, device_index: int = 0) -> None:
+    """
+    Run one ``fused_experts`` call with shapes from ``config.json`` so Triton JIT and config
+    resolution happen before TTFT timers (e.g. in ``jiuge.py``).
+
+    Set ``INFINILM_SKIP_VLLM_FUSED_MOE_PREFLIGHT=1`` to disable.
+
+    Optional ``INFINILM_VLLM_FUSED_WARMUP_MAX_BYTES`` (integer): skip allocating dummy expert
+    weights when the estimate exceeds this budget (TTFT may then include one-time Triton JIT).
+    """
+    if os.environ.get("INFINILM_SKIP_VLLM_FUSED_MOE_PREFLIGHT") == "1":
+        return
+
+    cfg = _load_hf_config_json(model_path)
+    dims = _moe_dims_from_config(cfg)
+    if dims is None:
+        return
+
+    try:
+        import torch
+        from vllm.model_executor.layers.fused_moe import MoEActivation, fused_experts
+    except ImportError:
+        return
+
+    E, N, H, topk = dims
+    if not torch.cuda.is_available():
+        return
+
+    torch_dtype = _torch_dtype_from_hf(cfg)
+    torch.cuda.set_device(device_index)
+    device = torch.device("cuda", device_index)
+
+    # Enough tokens to exercise typical block sizes without large activation batch.
+    num_tokens = min(128, max(32, topk * 8))
+    est = _estimated_fused_moe_warmup_bytes(E, N, H, topk, num_tokens, torch_dtype)
+    max_b = os.environ.get("INFINILM_VLLM_FUSED_WARMUP_MAX_BYTES")
+    if max_b is not None:
+        try:
+            if est > int(max_b):
+                print(
+                    f"[vllm_fused_moe] preflight: skip fused_experts warmup "
+                    f"(estimated {est} bytes > INFINILM_VLLM_FUSED_WARMUP_MAX_BYTES={max_b})",
+                    flush=True,
+                )
+                return
+        except ValueError:
+            pass
+
+    hidden_states = torch.randn(num_tokens, H, device=device, dtype=torch_dtype)
+    w1 = torch.randn(E, 2 * N, H, device=device, dtype=torch_dtype)
+    w2 = torch.randn(E, H, N, device=device, dtype=torch_dtype)
+    topk_weights = torch.randn(num_tokens, topk, device=device, dtype=torch_dtype)
+    topk_ids = torch.randint(0, E, (num_tokens, topk), device=device, dtype=torch.int32)
+
+    torch.cuda.current_stream().synchronize()
+
+    try:
+        _ = fused_experts(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            inplace=False,
+            activation=MoEActivation.SILU,
+            apply_router_weight_on_input=False,
+        )
+        torch.cuda.synchronize()
+    except torch.cuda.OutOfMemoryError as e:
+        print(
+            f"[vllm_fused_moe] preflight: fused_experts warmup OOM (skip): {e}",
+            flush=True,
+        )
+
+
+def preflight_vllm_fused_moe_for_ttft(model_path: str, device_index: int = 0) -> None:
+    """Verify fused_moe JSON presence and warm up kernels before timed generate."""
+    verify_vllm_fused_moe_config_for_checkpoint(model_path, device_index=device_index)
+    warmup_vllm_fused_moe_from_checkpoint(model_path, device_index=device_index)
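Finally, a sketch of the intended call site: run the preflight once before the timed generate path (the warmup docstring above mentions ``jiuge.py``, which is not part of this diff). The import path and checkpoint path are hypothetical.

from infinicore.vllm_fused_moe_bridge import preflight_vllm_fused_moe_for_ttft  # hypothetical path

model_path = "/models/my-moe-checkpoint"  # hypothetical checkpoint directory

# One call resolves fused_moe configs and triggers Triton JIT outside the TTFT window;
# setting INFINILM_SKIP_VLLM_FUSED_MOE_PREFLIGHT=1 turns it into a no-op.
preflight_vllm_fused_moe_for_ttft(model_path, device_index=0)

# ... timed generation / TTFT measurement happens after this point ...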
