Skip to content

Commit 87fd59b

Browse files
committed
refactor: MoE LoRA
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 8556a53 commit 87fd59b

13 files changed

Lines changed: 729 additions & 304 deletions

File tree

aphrodite/lora/layers/fused_moe.py

Lines changed: 37 additions & 277 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
import functools
43

54
import torch
65
import torch.nn as nn
@@ -14,31 +13,17 @@
1413
)
1514
from aphrodite.distributed.utils import divide
1615
from aphrodite.lora.layers.base import BaseLayerWithLoRA
17-
from aphrodite.lora.ops.triton_ops.utils import get_lora_op_configs
1816
from aphrodite.model_executor.layers.fused_moe import FusedMoE
19-
from aphrodite.model_executor.layers.fused_moe.config import (
20-
_get_config_dtype_str,
21-
)
22-
from aphrodite.model_executor.layers.fused_moe.experts.gpt_oss_triton_kernels_moe import (
23-
UnfusedOAITritonExperts,
24-
)
25-
from aphrodite.model_executor.layers.fused_moe.fused_marlin_moe import (
26-
MarlinExperts,
27-
)
28-
from aphrodite.model_executor.layers.fused_moe.fused_moe import (
29-
TritonExperts,
30-
)
3117
from aphrodite.model_executor.layers.fused_moe.fused_moe_modular_method import (
3218
FusedMoEModularMethod,
3319
)
34-
from aphrodite.model_executor.layers.fused_moe.modular_kernel import (
35-
FusedMoEKernel,
36-
)
20+
from aphrodite.model_executor.layers.fused_moe.lora_context import MoELoRAContext
21+
from aphrodite.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
3722
from aphrodite.model_executor.layers.fused_moe.prepare_finalize import (
3823
MoEPrepareAndFinalizeNoDPEPModular,
3924
)
4025

41-
from .utils import _get_lora_device, try_get_optimal_moe_lora_config
26+
from .utils import _get_lora_device
4227

4328

4429
class FusedMoEWithLoRA(BaseLayerWithLoRA):
@@ -56,275 +41,46 @@ def __init__(self, base_layer: FusedMoE) -> None:
5641
# For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
5742
# since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
5843
self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
59-
self._inject_lora_into_fused_moe()
60-
61-
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
62-
normalized_config = {}
63-
for key, value in config.items():
64-
if key.islower():
65-
if key.startswith("block_"):
66-
normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
67-
else:
68-
normalized_key = key.upper()
69-
else:
70-
normalized_key = key
71-
normalized_config[normalized_key] = value
72-
return normalized_config
73-
74-
def _get_lora_moe_configs(
75-
self,
76-
op_prefix: str,
77-
num_loras: int,
78-
rank: int,
79-
num_slices: int,
80-
M: int,
81-
layer: FusedMoE,
82-
top_k: int,
83-
config_dtype: str,
84-
):
85-
if envs.APHRODITE_TUNED_CONFIG_FOLDER:
86-
hidden_size = layer.hidden_size
87-
intermediate_size = (
88-
self.w2_lora_a_stacked[0].shape[-1] if op_prefix == "w2" else self.w13_lora_b_stacked[0].shape[-2]
89-
)
90-
shrink_config = get_lora_op_configs(
91-
op_type=f"fused_moe_lora_{op_prefix}_shrink",
92-
max_loras=num_loras,
93-
batch=M,
94-
hidden_size=hidden_size,
95-
rank=rank,
96-
num_slices=num_slices,
97-
moe_intermediate_size=intermediate_size,
98-
)
99-
expand_config = get_lora_op_configs(
100-
op_type=f"fused_moe_lora_{op_prefix}_expand",
101-
max_loras=num_loras,
102-
batch=M,
103-
hidden_size=hidden_size, # lora_a_stacked.shape[-1],
104-
rank=rank,
105-
num_slices=num_slices,
106-
moe_intermediate_size=intermediate_size, # lora_b_stacked.shape[-2],
107-
)
108-
else: # fall back to the default config
109-
get_config_func = functools.partial(
110-
try_get_optimal_moe_lora_config,
111-
w1_shape=layer.w13_weight.shape,
112-
w2_shape=layer.w2_weight.shape,
113-
rank=rank,
114-
top_k=top_k,
115-
dtype=config_dtype,
116-
M=M,
117-
block_shape=layer.quant_method.moe_quant_config.block_shape,
118-
)
119-
shrink_config = get_config_func(op_type=f"fused_moe_lora_{op_prefix}_shrink")
120-
expand_config = get_config_func(op_type=f"fused_moe_lora_{op_prefix}_expand")
121-
shrink_config = self._normalize_keys(shrink_config)
122-
expand_config = self._normalize_keys(expand_config)
123-
return shrink_config, expand_config
124-
125-
def _inject_lora_into_fused_moe(self):
126-
moe_state_dict = {}
127-
top_k = self.base_layer.top_k
12844

12945
self.base_layer.ensure_moe_quant_config_init()
130-
quant_config = self.base_layer.quant_method.moe_quant_config
131-
13246
if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
133-
# Use the existing modular kernel from the quant method
134-
m_fused_moe_fn = self.base_layer.quant_method.moe_kernel
47+
moe_kernel = self.base_layer.quant_method.moe_kernel
13548
# Don't let the kernel own shared experts so the runner can
13649
# overlap them with routed experts via a separate CUDA stream.
137-
m_fused_moe_fn.shared_experts = None
50+
moe_kernel.shared_experts = None
13851
else:
139-
# Create a new modular kernel via select_gemm_impl.
140-
# Don't pass shared_experts to the kernel so the runner can
141-
# overlap them with routed experts via a separate CUDA stream.
14252
prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
143-
m_fused_moe_fn = FusedMoEKernel(
53+
moe_kernel = FusedMoEKernel(
14454
prepare_finalize,
14555
self.base_layer.quant_method.select_gemm_impl(prepare_finalize, self.base_layer),
14656
)
147-
148-
if quant_config.use_mxfp4_w4a16:
149-
assert isinstance(
150-
m_fused_moe_fn.impl.fused_experts,
151-
(MarlinExperts, UnfusedOAITritonExperts),
152-
)
153-
else:
154-
assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts)
155-
156-
def fwd_decorator(layer, func):
157-
def wrapper(*args, **kwargs):
158-
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
159-
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
160-
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
161-
moe_state_dict["expert_map"] = kwargs["expert_map"]
162-
moe_state_dict["apply_router_weight_on_input"] = kwargs["apply_router_weight_on_input"]
163-
result = func(*args, **kwargs)
164-
return result
165-
166-
return wrapper
167-
168-
def act_decorator(layer, func):
169-
def wrapper(*args, **kwargs):
170-
_, output, input = args
171-
172-
hidden_states = moe_state_dict["hidden_states"]
173-
topk_weights = moe_state_dict["topk_weights"]
174-
curr_topk_ids = moe_state_dict["topk_ids"]
175-
176-
expert_map = moe_state_dict["expert_map"]
177-
178-
config_dtype = _get_config_dtype_str(
179-
dtype=hidden_states.dtype,
180-
use_fp8_w8a8=False,
181-
use_int8_w8a16=False,
182-
use_int4_w4a16=False,
183-
)
184-
num_tokens = hidden_states.size(0)
185-
M = num_tokens
186-
max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
187-
shrink_config, expand_config = self._get_lora_moe_configs(
188-
op_prefix="w13",
189-
num_loras=self.max_loras,
190-
rank=max_lora_rank,
191-
num_slices=self._w13_slices,
192-
M=M,
193-
layer=layer,
194-
top_k=top_k,
195-
config_dtype=config_dtype,
196-
)
197-
198-
# SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
199-
# activates only a small fraction of total experts * loras.
200-
SPARSITY_FACTOR = 8
201-
naive_block_assignment = (
202-
expert_map is None
203-
and num_tokens * top_k * SPARSITY_FACTOR <= self.base_layer.local_num_experts * self.max_loras
204-
)
205-
206-
# get the block size of m from customized config or default config
207-
(
208-
token_lora_mapping,
209-
sorted_token_ids_lora,
210-
expert_ids_lora,
211-
num_tokens_post_padded_lora,
212-
) = self.punica_wrapper.moe_lora_align_block_size(
213-
curr_topk_ids,
214-
num_tokens,
215-
shrink_config["BLOCK_SIZE_M"],
216-
self.base_layer.local_num_experts,
217-
self.max_loras,
218-
self.adapter_enabled,
219-
expert_map,
220-
naive_block_assignment=naive_block_assignment,
221-
)
222-
223-
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
224-
moe_state_dict["expert_ids_lora"] = expert_ids_lora
225-
moe_state_dict["num_tokens_post_padded_lora"] = num_tokens_post_padded_lora
226-
moe_state_dict["token_lora_mapping"] = token_lora_mapping
227-
228-
if sorted_token_ids_lora is not None:
229-
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
230-
sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
231-
#
232-
233-
self.punica_wrapper.add_lora_fused_moe(
234-
input.view(-1, top_k, input.shape[-1]),
235-
hidden_states,
236-
self.w13_lora_a_stacked,
237-
self.w13_lora_b_stacked,
238-
topk_weights,
239-
sorted_token_ids_lora,
240-
expert_ids_lora,
241-
num_tokens_post_padded_lora,
242-
max_lora_rank,
243-
top_k,
244-
shrink_config, ## pass the shrink config
245-
expand_config, ## pass the expand config
246-
self.adapter_enabled,
247-
fully_sharded=self.fully_sharded,
248-
token_lora_mapping=token_lora_mapping,
249-
)
250-
251-
result = func(*args, **kwargs)
252-
253-
moe_state_dict["intermediate_cache2"] = output
254-
return result
255-
256-
return wrapper
257-
258-
def moe_sum_decorator(layer, func):
259-
def wrapper(*args, **kwargs):
260-
hidden_states = moe_state_dict["hidden_states"]
261-
topk_weights = moe_state_dict["topk_weights"]
262-
263-
config_dtype = _get_config_dtype_str(
264-
dtype=hidden_states.dtype,
265-
use_fp8_w8a8=False,
266-
use_int8_w8a16=False,
267-
use_int4_w4a16=False,
268-
)
269-
num_tokens = hidden_states.size(0)
270-
M = num_tokens
271-
max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
272-
shrink_config, expand_config = self._get_lora_moe_configs(
273-
op_prefix="w2",
274-
num_loras=self.max_loras,
275-
rank=max_lora_rank,
276-
num_slices=1,
277-
M=M,
278-
layer=layer,
279-
top_k=top_k,
280-
config_dtype=config_dtype,
281-
)
282-
283-
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
284-
expert_ids_lora = moe_state_dict["expert_ids_lora"]
285-
num_tokens_post_padded_lora = moe_state_dict["num_tokens_post_padded_lora"]
286-
token_lora_mapping = moe_state_dict.get("token_lora_mapping")
287-
288-
if sorted_token_ids_lora is not None:
289-
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
290-
sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
291-
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
292-
intermediate_cache3 = args[0]
293-
294-
shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
295-
296-
self.punica_wrapper.add_lora_fused_moe(
297-
intermediate_cache3,
298-
intermediate_cache2,
299-
self.w2_lora_a_stacked,
300-
self.w2_lora_b_stacked,
301-
topk_weights,
302-
sorted_token_ids_lora,
303-
expert_ids_lora,
304-
num_tokens_post_padded_lora,
305-
max_lora_rank,
306-
top_k,
307-
shrink_config, ## pass the shrink config
308-
expand_config, ## pass the expand config
309-
self.adapter_enabled,
310-
True,
311-
fully_sharded=self.fully_sharded,
312-
offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
313-
token_lora_mapping=token_lora_mapping,
314-
)
315-
316-
result = func(*args, **kwargs)
317-
return result
318-
319-
return wrapper
320-
321-
fused_experts = m_fused_moe_fn.impl.fused_experts
322-
323-
m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply)
324-
fused_experts.activation = act_decorator(self.base_layer, fused_experts.activation)
325-
fused_experts.moe_sum = moe_sum_decorator(self.base_layer, fused_experts.moe_sum)
326-
# TODO(bnell): find a less intrusive way to handle this.
327-
self.base_layer._replace_quant_method(FusedMoEModularMethod(self.base_layer.quant_method, m_fused_moe_fn))
57+
assert moe_kernel.supports_lora(), (
58+
f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. "
59+
"For unquantized MoE, set moe_backend='triton' or moe_backend='auto' "
60+
"(auto selects Triton automatically when LoRA is enabled). "
61+
"For quantized MoE, mix LoRAExpertsMixin into the experts class "
62+
"and consume self._lora_context in apply()."
63+
)
64+
self._fused_experts = moe_kernel.fused_experts
65+
self.base_layer._replace_quant_method(FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel))
66+
67+
def _build_lora_context(self):
    """Assemble the ``MoELoRAContext`` describing this layer's LoRA state.

    The context is given references to the stacked w13/w2 LoRA A and B
    weights, the adapter-enabled flags, the LoRA / top-k / w13-slice
    counts, the tensor-parallel settings (``fully_sharded``, ``tp_rank``,
    ``tp_size``), the base layer's local expert count, the punica wrapper,
    and a flag telling whether a tuned-config folder is configured.
    """
    layer = self.base_layer
    # use_tuned_config is simply the truthiness of the configured folder value.
    tuned = bool(envs.APHRODITE_TUNED_CONFIG_FOLDER)
    return MoELoRAContext(
        w13_lora_a_stacked=self.w13_lora_a_stacked,
        w13_lora_b_stacked=self.w13_lora_b_stacked,
        w2_lora_a_stacked=self.w2_lora_a_stacked,
        w2_lora_b_stacked=self.w2_lora_b_stacked,
        adapter_enabled=self.adapter_enabled,
        max_loras=self.max_loras,
        top_k=layer.top_k,
        w13_num_slices=self._w13_slices,
        fully_sharded=self.fully_sharded,
        tp_rank=self.tp_rank,
        tp_size=self.tp_size,
        local_num_experts=layer.local_num_experts,
        punica_wrapper=self.punica_wrapper,
        use_tuned_config=tuned,
    )
32884

32985
def _create_lora_a_weights(
33086
self,
@@ -543,6 +299,10 @@ def set_lora(
543299
sliced_w2_lora_b, non_blocking=True
544300
)
545301

302+
def set_mapping(self, punica_wrapper):
    """Run the base-class mapping update, then refresh the experts' LoRA context.

    After delegating to ``super().set_mapping``, a new context is built with
    ``_build_lora_context`` and handed to the fused-experts object through
    its ``set_lora_context`` method.
    """
    super().set_mapping(punica_wrapper)
    # NOTE(review): the context is built after the super() call, presumably
    # so it picks up the wrapper that call stores on the layer -- confirm
    # against BaseLayerWithLoRA.set_mapping.
    lora_context = self._build_lora_context()
    self._fused_experts.set_lora_context(lora_context)
305+
546306
def forward(self, *args, **kwargs):
    """Delegate the call, arguments unchanged, to the wrapped base layer's ``forward``."""
    base_forward = self.base_layer.forward
    return base_forward(*args, **kwargs)
548308

aphrodite/lora/layers/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,12 @@ def try_get_optimal_moe_lora_config(
8888
top_k: int,
8989
dtype: str | None,
9090
M: int,
91-
block_shape: list[int] | None = None,
9291
) -> dict[str, int | None]:
93-
config = try_get_optimal_moe_config(w1_shape, w2_shape, top_k, dtype, M, block_shape).copy()
92+
# LoRA shrink/expand operates on bf16/fp16 adapters regardless of the
93+
# base MoE weight's block-wise quantization, so block_shape is omitted
94+
# from the config lookup — the non-quantized branch in get_default_config
95+
# ignores it anyway.
96+
config = try_get_optimal_moe_config(w1_shape, w2_shape, top_k, dtype, M).copy()
9497
if op_type in [
9598
"fused_moe_lora_w13_shrink",
9699
"fused_moe_lora_w2_shrink",

aphrodite/lora/ops/triton_ops/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,20 @@ def supports_pdl(device: torch.device | None = None) -> bool:
296296
def supports_tma(device: torch.device | None = None) -> bool:
    """Report whether the current platform can use TMA.

    ``device`` is accepted but not consulted in this body; the answer is
    taken from ``current_platform`` alone.
    """
    # TMA requires compute capability SM90 or above
    on_cuda = current_platform.is_cuda()
    # Short-circuit: the capability query only runs when the platform is CUDA.
    return on_cuda and current_platform.has_device_capability(90)
299+
300+
301+
def _normalize_lora_config_keys(
302+
config: dict[str, int | None],
303+
) -> dict[str, int | None]:
304+
"""Normalize Triton config dict keys to uppercase BLOCK_SIZE_* format."""
305+
out: dict[str, int | None] = {}
306+
for key, val in config.items():
307+
if key.islower():
308+
if key.startswith("block_"):
309+
nk = "BLOCK_SIZE_" + key.split("_")[-1].upper()
310+
else:
311+
nk = key.upper()
312+
else:
313+
nk = key
314+
out[nk] = val
315+
return out

0 commit comments

Comments
 (0)