11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3- import functools
43
54import torch
65import torch .nn as nn
1413)
1514from aphrodite .distributed .utils import divide
1615from aphrodite .lora .layers .base import BaseLayerWithLoRA
17- from aphrodite .lora .ops .triton_ops .utils import get_lora_op_configs
1816from aphrodite .model_executor .layers .fused_moe import FusedMoE
19- from aphrodite .model_executor .layers .fused_moe .config import (
20- _get_config_dtype_str ,
21- )
22- from aphrodite .model_executor .layers .fused_moe .experts .gpt_oss_triton_kernels_moe import (
23- UnfusedOAITritonExperts ,
24- )
25- from aphrodite .model_executor .layers .fused_moe .fused_marlin_moe import (
26- MarlinExperts ,
27- )
28- from aphrodite .model_executor .layers .fused_moe .fused_moe import (
29- TritonExperts ,
30- )
3117from aphrodite .model_executor .layers .fused_moe .fused_moe_modular_method import (
3218 FusedMoEModularMethod ,
3319)
34- from aphrodite .model_executor .layers .fused_moe .modular_kernel import (
35- FusedMoEKernel ,
36- )
20+ from aphrodite .model_executor .layers .fused_moe .lora_context import MoELoRAContext
21+ from aphrodite .model_executor .layers .fused_moe .modular_kernel import FusedMoEKernel
3722from aphrodite .model_executor .layers .fused_moe .prepare_finalize import (
3823 MoEPrepareAndFinalizeNoDPEPModular ,
3924)
4025
41- from .utils import _get_lora_device , try_get_optimal_moe_lora_config
26+ from .utils import _get_lora_device
4227
4328
4429class FusedMoEWithLoRA (BaseLayerWithLoRA ):
@@ -56,275 +41,46 @@ def __init__(self, base_layer: FusedMoE) -> None:
5641 # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
5742 # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
5843 self ._w13_slices = 2 if base_layer .moe_config .is_act_and_mul else 1
59- self ._inject_lora_into_fused_moe ()
60-
61- def _normalize_keys (self , config : dict [str , int | None ]) -> dict [str , int | None ]:
62- normalized_config = {}
63- for key , value in config .items ():
64- if key .islower ():
65- if key .startswith ("block_" ):
66- normalized_key = "BLOCK_SIZE_" + key .split ("_" )[- 1 ].upper ()
67- else :
68- normalized_key = key .upper ()
69- else :
70- normalized_key = key
71- normalized_config [normalized_key ] = value
72- return normalized_config
73-
74- def _get_lora_moe_configs (
75- self ,
76- op_prefix : str ,
77- num_loras : int ,
78- rank : int ,
79- num_slices : int ,
80- M : int ,
81- layer : FusedMoE ,
82- top_k : int ,
83- config_dtype : str ,
84- ):
85- if envs .APHRODITE_TUNED_CONFIG_FOLDER :
86- hidden_size = layer .hidden_size
87- intermediate_size = (
88- self .w2_lora_a_stacked [0 ].shape [- 1 ] if op_prefix == "w2" else self .w13_lora_b_stacked [0 ].shape [- 2 ]
89- )
90- shrink_config = get_lora_op_configs (
91- op_type = f"fused_moe_lora_{ op_prefix } _shrink" ,
92- max_loras = num_loras ,
93- batch = M ,
94- hidden_size = hidden_size ,
95- rank = rank ,
96- num_slices = num_slices ,
97- moe_intermediate_size = intermediate_size ,
98- )
99- expand_config = get_lora_op_configs (
100- op_type = f"fused_moe_lora_{ op_prefix } _expand" ,
101- max_loras = num_loras ,
102- batch = M ,
103- hidden_size = hidden_size , # lora_a_stacked.shape[-1],
104- rank = rank ,
105- num_slices = num_slices ,
106- moe_intermediate_size = intermediate_size , # lora_b_stacked.shape[-2],
107- )
108- else : # fall back to the default config
109- get_config_func = functools .partial (
110- try_get_optimal_moe_lora_config ,
111- w1_shape = layer .w13_weight .shape ,
112- w2_shape = layer .w2_weight .shape ,
113- rank = rank ,
114- top_k = top_k ,
115- dtype = config_dtype ,
116- M = M ,
117- block_shape = layer .quant_method .moe_quant_config .block_shape ,
118- )
119- shrink_config = get_config_func (op_type = f"fused_moe_lora_{ op_prefix } _shrink" )
120- expand_config = get_config_func (op_type = f"fused_moe_lora_{ op_prefix } _expand" )
121- shrink_config = self ._normalize_keys (shrink_config )
122- expand_config = self ._normalize_keys (expand_config )
123- return shrink_config , expand_config
124-
125- def _inject_lora_into_fused_moe (self ):
126- moe_state_dict = {}
127- top_k = self .base_layer .top_k
12844
12945 self .base_layer .ensure_moe_quant_config_init ()
130- quant_config = self .base_layer .quant_method .moe_quant_config
131-
13246 if getattr (self .base_layer .quant_method , "supports_internal_mk" , False ):
133- # Use the existing modular kernel from the quant method
134- m_fused_moe_fn = self .base_layer .quant_method .moe_kernel
47+ moe_kernel = self .base_layer .quant_method .moe_kernel
13548 # Don't let the kernel own shared experts so the runner can
13649 # overlap them with routed experts via a separate CUDA stream.
137- m_fused_moe_fn .shared_experts = None
50+ moe_kernel .shared_experts = None
13851 else :
139- # Create a new modular kernel via select_gemm_impl.
140- # Don't pass shared_experts to the kernel so the runner can
141- # overlap them with routed experts via a separate CUDA stream.
14252 prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular ()
143- m_fused_moe_fn = FusedMoEKernel (
53+ moe_kernel = FusedMoEKernel (
14454 prepare_finalize ,
14555 self .base_layer .quant_method .select_gemm_impl (prepare_finalize , self .base_layer ),
14656 )
147-
148- if quant_config .use_mxfp4_w4a16 :
149- assert isinstance (
150- m_fused_moe_fn .impl .fused_experts ,
151- (MarlinExperts , UnfusedOAITritonExperts ),
152- )
153- else :
154- assert isinstance (m_fused_moe_fn .impl .fused_experts , TritonExperts )
155-
156- def fwd_decorator (layer , func ):
157- def wrapper (* args , ** kwargs ):
158- moe_state_dict ["hidden_states" ] = kwargs ["hidden_states" ]
159- moe_state_dict ["topk_ids" ] = kwargs ["topk_ids" ]
160- moe_state_dict ["topk_weights" ] = kwargs ["topk_weights" ]
161- moe_state_dict ["expert_map" ] = kwargs ["expert_map" ]
162- moe_state_dict ["apply_router_weight_on_input" ] = kwargs ["apply_router_weight_on_input" ]
163- result = func (* args , ** kwargs )
164- return result
165-
166- return wrapper
167-
168- def act_decorator (layer , func ):
169- def wrapper (* args , ** kwargs ):
170- _ , output , input = args
171-
172- hidden_states = moe_state_dict ["hidden_states" ]
173- topk_weights = moe_state_dict ["topk_weights" ]
174- curr_topk_ids = moe_state_dict ["topk_ids" ]
175-
176- expert_map = moe_state_dict ["expert_map" ]
177-
178- config_dtype = _get_config_dtype_str (
179- dtype = hidden_states .dtype ,
180- use_fp8_w8a8 = False ,
181- use_int8_w8a16 = False ,
182- use_int4_w4a16 = False ,
183- )
184- num_tokens = hidden_states .size (0 )
185- M = num_tokens
186- max_lora_rank = self .w13_lora_a_stacked [0 ].shape [- 2 ]
187- shrink_config , expand_config = self ._get_lora_moe_configs (
188- op_prefix = "w13" ,
189- num_loras = self .max_loras ,
190- rank = max_lora_rank ,
191- num_slices = self ._w13_slices ,
192- M = M ,
193- layer = layer ,
194- top_k = top_k ,
195- config_dtype = config_dtype ,
196- )
197-
198- # SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
199- # activates only a small fraction of total experts * loras.
200- SPARSITY_FACTOR = 8
201- naive_block_assignment = (
202- expert_map is None
203- and num_tokens * top_k * SPARSITY_FACTOR <= self .base_layer .local_num_experts * self .max_loras
204- )
205-
206- # get the block size of m from customized config or default config
207- (
208- token_lora_mapping ,
209- sorted_token_ids_lora ,
210- expert_ids_lora ,
211- num_tokens_post_padded_lora ,
212- ) = self .punica_wrapper .moe_lora_align_block_size (
213- curr_topk_ids ,
214- num_tokens ,
215- shrink_config ["BLOCK_SIZE_M" ],
216- self .base_layer .local_num_experts ,
217- self .max_loras ,
218- self .adapter_enabled ,
219- expert_map ,
220- naive_block_assignment = naive_block_assignment ,
221- )
222-
223- moe_state_dict ["sorted_token_ids_lora" ] = sorted_token_ids_lora
224- moe_state_dict ["expert_ids_lora" ] = expert_ids_lora
225- moe_state_dict ["num_tokens_post_padded_lora" ] = num_tokens_post_padded_lora
226- moe_state_dict ["token_lora_mapping" ] = token_lora_mapping
227-
228- if sorted_token_ids_lora is not None :
229- expert_ids_lora = expert_ids_lora .view (self .max_loras , - 1 )
230- sorted_token_ids_lora = sorted_token_ids_lora .view (self .max_loras , - 1 )
231- #
232-
233- self .punica_wrapper .add_lora_fused_moe (
234- input .view (- 1 , top_k , input .shape [- 1 ]),
235- hidden_states ,
236- self .w13_lora_a_stacked ,
237- self .w13_lora_b_stacked ,
238- topk_weights ,
239- sorted_token_ids_lora ,
240- expert_ids_lora ,
241- num_tokens_post_padded_lora ,
242- max_lora_rank ,
243- top_k ,
244- shrink_config , ## pass the shrink config
245- expand_config , ## pass the expand config
246- self .adapter_enabled ,
247- fully_sharded = self .fully_sharded ,
248- token_lora_mapping = token_lora_mapping ,
249- )
250-
251- result = func (* args , ** kwargs )
252-
253- moe_state_dict ["intermediate_cache2" ] = output
254- return result
255-
256- return wrapper
257-
258- def moe_sum_decorator (layer , func ):
259- def wrapper (* args , ** kwargs ):
260- hidden_states = moe_state_dict ["hidden_states" ]
261- topk_weights = moe_state_dict ["topk_weights" ]
262-
263- config_dtype = _get_config_dtype_str (
264- dtype = hidden_states .dtype ,
265- use_fp8_w8a8 = False ,
266- use_int8_w8a16 = False ,
267- use_int4_w4a16 = False ,
268- )
269- num_tokens = hidden_states .size (0 )
270- M = num_tokens
271- max_lora_rank = self .w2_lora_a_stacked [0 ].shape [- 2 ]
272- shrink_config , expand_config = self ._get_lora_moe_configs (
273- op_prefix = "w2" ,
274- num_loras = self .max_loras ,
275- rank = max_lora_rank ,
276- num_slices = 1 ,
277- M = M ,
278- layer = layer ,
279- top_k = top_k ,
280- config_dtype = config_dtype ,
281- )
282-
283- sorted_token_ids_lora = moe_state_dict ["sorted_token_ids_lora" ]
284- expert_ids_lora = moe_state_dict ["expert_ids_lora" ]
285- num_tokens_post_padded_lora = moe_state_dict ["num_tokens_post_padded_lora" ]
286- token_lora_mapping = moe_state_dict .get ("token_lora_mapping" )
287-
288- if sorted_token_ids_lora is not None :
289- expert_ids_lora = expert_ids_lora .view (self .max_loras , - 1 )
290- sorted_token_ids_lora = sorted_token_ids_lora .view (self .max_loras , - 1 )
291- intermediate_cache2 = moe_state_dict ["intermediate_cache2" ]
292- intermediate_cache3 = args [0 ]
293-
294- shard_size_w2 = divide (self .base_layer .hidden_size , self .tp_size )
295-
296- self .punica_wrapper .add_lora_fused_moe (
297- intermediate_cache3 ,
298- intermediate_cache2 ,
299- self .w2_lora_a_stacked ,
300- self .w2_lora_b_stacked ,
301- topk_weights ,
302- sorted_token_ids_lora ,
303- expert_ids_lora ,
304- num_tokens_post_padded_lora ,
305- max_lora_rank ,
306- top_k ,
307- shrink_config , ## pass the shrink config
308- expand_config , ## pass the expand config
309- self .adapter_enabled ,
310- True ,
311- fully_sharded = self .fully_sharded ,
312- offset = shard_size_w2 * self .tp_rank if self .fully_sharded else 0 ,
313- token_lora_mapping = token_lora_mapping ,
314- )
315-
316- result = func (* args , ** kwargs )
317- return result
318-
319- return wrapper
320-
321- fused_experts = m_fused_moe_fn .impl .fused_experts
322-
323- m_fused_moe_fn .apply = fwd_decorator (self .base_layer , m_fused_moe_fn .apply )
324- fused_experts .activation = act_decorator (self .base_layer , fused_experts .activation )
325- fused_experts .moe_sum = moe_sum_decorator (self .base_layer , fused_experts .moe_sum )
326- # TODO(bnell): find a less intrusive way to handle this.
327- self .base_layer ._replace_quant_method (FusedMoEModularMethod (self .base_layer .quant_method , m_fused_moe_fn ))
57+ assert moe_kernel .supports_lora (), (
58+ f"{ type (moe_kernel .fused_experts ).__name__ } does not support LoRA. "
59+ "For unquantized MoE, set moe_backend='triton' or moe_backend='auto' "
60+ "(auto selects Triton automatically when LoRA is enabled). "
61+ "For quantized MoE, mix LoRAExpertsMixin into the experts class "
62+ "and consume self._lora_context in apply()."
63+ )
64+ self ._fused_experts = moe_kernel .fused_experts
65+ self .base_layer ._replace_quant_method (FusedMoEModularMethod (self .base_layer .quant_method , moe_kernel ))
66+
67+ def _build_lora_context (self ):
68+ return MoELoRAContext (
69+ w13_lora_a_stacked = self .w13_lora_a_stacked ,
70+ w13_lora_b_stacked = self .w13_lora_b_stacked ,
71+ w2_lora_a_stacked = self .w2_lora_a_stacked ,
72+ w2_lora_b_stacked = self .w2_lora_b_stacked ,
73+ adapter_enabled = self .adapter_enabled ,
74+ max_loras = self .max_loras ,
75+ top_k = self .base_layer .top_k ,
76+ w13_num_slices = self ._w13_slices ,
77+ fully_sharded = self .fully_sharded ,
78+ tp_rank = self .tp_rank ,
79+ tp_size = self .tp_size ,
80+ local_num_experts = self .base_layer .local_num_experts ,
81+ punica_wrapper = self .punica_wrapper ,
82+ use_tuned_config = bool (envs .APHRODITE_TUNED_CONFIG_FOLDER ),
83+ )
32884
32985 def _create_lora_a_weights (
33086 self ,
@@ -543,6 +299,10 @@ def set_lora(
543299 sliced_w2_lora_b , non_blocking = True
544300 )
545301
302+ def set_mapping (self , punica_wrapper ):
303+ super ().set_mapping (punica_wrapper )
304+ self ._fused_experts .set_lora_context (self ._build_lora_context ())
305+
546306 def forward (self , * args , ** kwargs ):
547307 return self .base_layer .forward (* args , ** kwargs )
548308
0 commit comments