@@ -88,6 +88,7 @@ def get_checkpoint_meta_from_sharded_safetensor(
8888 router_name : str = "gate" , # e.g., named "gate" within block_sparse_moe
8989 expert_name : str = "experts" , # e.g., named "experts" within block_sparse_moe
9090 expert_map : Dict = None , # map -> [w1,w2,w3]
91+ lora_start : bool = False , # if lora is detected in prepare_scattermoe.py
9192 lora_utils : bool = False , # if lora is detected in checkpoint_utils.py
9293 target_modules : Dict = None , # target modules from prepare_scattermoe.py
9394) -> Dict [str , List [Tuple ]]:
@@ -176,12 +177,14 @@ def _insert(L: List, i: int, v):
176177 else :
177178 _map [KEY_SCATTERMOE_ROUTER ].append ((k , stfile ))
178179 elif m .group (1 ) in expert_name :
180+ index = m .group (2 )
181+ index = 0 if index is None else int (index )
182+ mod = None
183+
184+ # LoRA case
179185 if (
180186 "input_linear" in target_modules and "output_linear" in target_modules
181187 ) or lora_utils :
182- index = m .group (2 )
183- index = 0 if index is None else int (index )
184- mod = None
185188 if not lora_utils :
186189 for mod in expert_map .get (m .group (1 ), expert_map .get (m .group (3 ))):
187190 _insert (_map [f"{ mod } .weight" ], index , (k , stfile ))
@@ -190,7 +193,14 @@ def _insert(L: List, i: int, v):
190193 _insert (_map [f"{ mod } .lora_A" ], index , (k , stfile ))
191194 _insert (_map [f"{ mod } .lora_B" ], index , (k , stfile ))
192195
193- assert mod is not None , f"cannot map '{ rel_k } '"
196+ # Fine-tuning case
197+ elif not lora_utils and not lora_start :
198+ for mod in expert_map .get (m .group (1 ), expert_map .get (m .group (3 ))):
199+ _insert (_map [f"{ mod } .weight" ], index , (k , stfile ))
200+
201+ assert mod is not None , f"cannot map '{ rel_k } '"
202+
203+
194204
195205 if len (_map ) == 0 :
196206 raise ValueError (
0 commit comments