Skip to content

Commit 07d38b0

Browse files
committed
fix: ft logic
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
1 parent bff9b04 commit 07d38b0

2 files changed

Lines changed: 19 additions & 4 deletions

File tree

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,10 @@ def prepare_scattermoe(
128128
# pylint: disable=import-outside-toplevel
129129
from .scattermoe import ScatterMoE
130130

131+
lora = False
132+
if lora_config:
133+
lora = True
134+
131135
if disable_distributed and ep_degree > 1:
132136
raise ValueError(
133137
"expert sharding can not be deferred to top level sharding"
@@ -251,6 +255,7 @@ def prepare_scattermoe(
251255
module_name,
252256
router_name,
253257
"|".join(expert_name),
258+
lora_start=lora
254259
target_modules=lora_config.target_modules,
255260
)
256261

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,7 @@ def get_checkpoint_meta_from_sharded_safetensor(
8888
router_name: str = "gate", # e.g., named "gate" within block_sparse_moe
8989
expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe
9090
expert_map: Dict = None, # map -> [w1,w2,w3]
91+
lora_start: bool = False, # if lora is detected in prepare_scattermoe.py
9192
lora_utils: bool = False, # if lora is detected in checkpoint_utils.py
9293
target_modules: Dict = None, # target modules from prepare_scattermoe.py
9394
) -> Dict[str, List[Tuple]]:
@@ -176,12 +177,14 @@ def _insert(L: List, i: int, v):
176177
else:
177178
_map[KEY_SCATTERMOE_ROUTER].append((k, stfile))
178179
elif m.group(1) in expert_name:
180+
index = m.group(2)
181+
index = 0 if index is None else int(index)
182+
mod = None
183+
184+
# LoRA case
179185
if (
180186
"input_linear" in target_modules and "output_linear" in target_modules
181187
) or lora_utils:
182-
index = m.group(2)
183-
index = 0 if index is None else int(index)
184-
mod = None
185188
if not lora_utils:
186189
for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))):
187190
_insert(_map[f"{mod}.weight"], index, (k, stfile))
@@ -190,7 +193,14 @@ def _insert(L: List, i: int, v):
190193
_insert(_map[f"{mod}.lora_A"], index, (k, stfile))
191194
_insert(_map[f"{mod}.lora_B"], index, (k, stfile))
192195

193-
assert mod is not None, f"cannot map '{rel_k}'"
196+
# Fine-tuning case
197+
elif not lora_utils and not lora_start:
198+
for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))):
199+
_insert(_map[f"{mod}.weight"], index, (k, stfile))
200+
201+
assert mod is not None, f"cannot map '{rel_k}'"
202+
203+
194204

195205
if len(_map) == 0:
196206
raise ValueError(

0 commit comments

Comments
 (0)