Skip to content

Commit 7a42747

Browse files
authored
Fix Fleet Model SFT bug. (#3603)
1 parent e43ab0b commit 7a42747

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

paddleformers/transformers/glm4_moe/modeling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ def __new__(cls, config):
13951395
model_provider_class = GLMMoEModelProvider
13961396
model_provider = model_provider_class.from_config(config)
13971397
loss_fn = None
1398-
if hasattr(config, "dpo_config"):
1398+
if getattr(config, "dpo_config", None):
13991399
loss_fn = CriterionLayerPipe(config, use_infohub=True)
14001400
gpt_model = model_provider.provide(loss_fn=loss_fn)
14011401
gpt_model._gen_aoa_config = cls._gen_aoa_config
@@ -1569,7 +1569,7 @@ def __new__(cls, config):
15691569
model_provider_class = GLMMoEModelProvider
15701570
model_provider = model_provider_class.from_config(config)
15711571
loss_fn = None
1572-
if hasattr(config, "dpo_config"):
1572+
if getattr(config, "dpo_config", None):
15731573
loss_fn = CriterionLayerPipe(config, use_infohub=True)
15741574
gpt_model = model_provider.provide(loss_fn=loss_fn)
15751575
gpt_model._gen_aoa_config = cls._gen_aoa_config

paddleformers/transformers/qwen3_moe/modeling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,7 +1218,7 @@ def __new__(cls, config):
12181218
model_provider_class = Qwen3MoEModelProvider
12191219
model_provider = model_provider_class.from_config(config)
12201220
loss_fn = None
1221-
if hasattr(config, "dpo_config"):
1221+
if getattr(config, "dpo_config", None):
12221222
loss_fn = CriterionLayerPipe(config, use_infohub=True)
12231223
gpt_model = model_provider.provide(loss_fn=loss_fn)
12241224
gpt_model._gen_aoa_config = cls._gen_aoa_config
@@ -1372,7 +1372,7 @@ def __new__(cls, config):
13721372
model_provider_class = Qwen3MoEModelProvider
13731373
model_provider = model_provider_class.from_config(config)
13741374
loss_fn = None
1375-
if hasattr(config, "dpo_config"):
1375+
if getattr(config, "dpo_config", None):
13761376
loss_fn = CriterionLayerPipe(config, use_infohub=True)
13771377
gpt_model = model_provider.provide(loss_fn=loss_fn)
13781378
gpt_model._gen_aoa_config = cls._gen_aoa_config

0 commit comments

Comments
 (0)