@@ -170,6 +170,14 @@ class RepFlowArgs:
170170 In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor`
171171 or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values,
172172 accommodating larger selection numbers.
173+ use_moe : bool, optional
174+ Whether to use Mixture-of-Experts for the MLP layers in each RepFlowLayer.
175+ n_routing_experts : int, optional
176+ Total number of routing experts across all GPUs.
177+ moe_topk : int, optional
178+ Number of experts selected per token.
179+ n_shared_experts : int, optional
180+ Number of shared experts (replicated on every GPU).
173181 """
174182
175183 def __init__ (
@@ -201,6 +209,10 @@ def __init__(
201209 use_exp_switch : bool = False ,
202210 use_dynamic_sel : bool = False ,
203211 sel_reduce_factor : float = 10.0 ,
212+ use_moe : bool = False ,
213+ n_routing_experts : int = 0 ,
214+ moe_topk : int = 0 ,
215+ n_shared_experts : int = 0 ,
204216 ) -> None :
205217 self .n_dim = n_dim
206218 self .e_dim = e_dim
@@ -231,6 +243,10 @@ def __init__(
231243 self .use_exp_switch = use_exp_switch
232244 self .use_dynamic_sel = use_dynamic_sel
233245 self .sel_reduce_factor = sel_reduce_factor
246+ self .use_moe = use_moe
247+ self .n_routing_experts = n_routing_experts
248+ self .moe_topk = moe_topk
249+ self .n_shared_experts = n_shared_experts
234250
235251 def __getitem__ (self , key : str ) -> Any :
236252 if hasattr (self , key ):
@@ -266,6 +282,10 @@ def serialize(self) -> dict:
266282 "use_exp_switch" : self .use_exp_switch ,
267283 "use_dynamic_sel" : self .use_dynamic_sel ,
268284 "sel_reduce_factor" : self .sel_reduce_factor ,
285+ "use_moe" : self .use_moe ,
286+ "n_routing_experts" : self .n_routing_experts ,
287+ "moe_topk" : self .moe_topk ,
288+ "n_shared_experts" : self .n_shared_experts ,
269289 }
270290
271291 @classmethod
0 commit comments