28284. Unified EPLB integration for backends that support it
2929"""
3030
31+ import copy
3132from typing import Dict , List , Optional , Tuple , Union
3233
3334import torch
@@ -162,21 +163,34 @@ def __init__(
162163 self .apply_router_weight_on_input = apply_router_weight_on_input
163164
164165 # ========== Create MoE Backend (Default: Cutlass) ==========
165- from tensorrt_llm ._torch .modules .fused_moe .create_moe import create_moe_backend , get_moe_cls
166+ from tensorrt_llm ._torch .modules .fused_moe .create_moe import (
167+ create_moe_backend ,
168+ resolve_moe_cls ,
169+ )
170+
171+ # Get MoE backend class based on override_quant_config, routing_method, and model_config
172+ moe_cls = resolve_moe_cls (
173+ model_config ,
174+ routing_method ,
175+ self .dtype ,
176+ override_quant_config = override_quant_config ,
177+ )
166178
167- # Get MoE backend class based on override_quant_config or model_config
168- moe_cls = get_moe_cls (model_config , override_quant_config = override_quant_config )
179+ backend_model_config = model_config
180+ if override_quant_config is not None :
181+ backend_model_config = copy .deepcopy (model_config )
182+ backend_model_config .quant_config = override_quant_config
169183
170184 # Call create_moe_backend with all necessary parameters
171185 # init_load_balancer=False: Prevents backend from registering itself with load balancer
172186 # without_comm=True: Prevents backend from initializing communication (ConfigurableMoE handles it)
173187 # skip_create_weights_in_init=True: Prevents backend from creating weights in __init__
174188 # because backend uses layer_idx=None and may have different expert assignments
175189 # We will create weights after syncing attributes from ConfigurableMoE
176- tmp_skip_create_weights_in_init = model_config .skip_create_weights_in_init
177- model_config ._frozen = False
178- model_config .skip_create_weights_in_init = True
179- model_config ._frozen = True
190+ tmp_skip_create_weights_in_init = backend_model_config .skip_create_weights_in_init
191+ backend_model_config ._frozen = False
192+ backend_model_config .skip_create_weights_in_init = True
193+ backend_model_config ._frozen = True
180194
181195 backend = create_moe_backend (
182196 moe_cls = moe_cls ,
@@ -186,7 +200,7 @@ def __init__(
186200 intermediate_size = self .intermediate_size ,
187201 dtype = self .dtype ,
188202 reduce_results = self .reduce_results ,
189- model_config = model_config ,
203+ model_config = backend_model_config ,
190204 aux_stream_dict = self .aux_stream_dict ,
191205 weight_loading_mode = self .weight_loading_mode ,
192206 bias = kwargs .get ("bias" , False ),
@@ -221,10 +235,10 @@ def __init__(
221235 self .backend .expert_size_per_partition = self .expert_size_per_partition
222236
223237 # Create weights here, because the backend needs the layer_load_balancer info to create weights
224- model_config ._frozen = False
225- model_config .skip_create_weights_in_init = tmp_skip_create_weights_in_init
226- model_config ._frozen = True
227- if not model_config .skip_create_weights_in_init :
238+ backend_model_config ._frozen = False
239+ backend_model_config .skip_create_weights_in_init = tmp_skip_create_weights_in_init
240+ backend_model_config ._frozen = True
241+ if not backend_model_config .skip_create_weights_in_init :
228242 self .backend .create_weights ()
229243
230244 # ========== Create Communication Strategy ==========
0 commit comments