Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddleformers/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
"MoECorrectionBiasAdjustCallback",
"MoeExpertsGradScaleCallback",
"MoEGateSpGradSyncCallBack",
"GlobalRNGCallback",
],
"trainer_utils": [
"get_last_checkpoint",
Expand Down
39 changes: 35 additions & 4 deletions paddleformers/trainer/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dataclasses
import json
import os
import random
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
Expand All @@ -41,8 +42,19 @@
# Conditionally import paddlefleet modules
if is_paddlefleet_available():
    from paddlefleet.models.gpt import GPTModel
    from paddlefleet.transformer.moe.moe_layer import MoELayer
    from paddlefleet.transformer.moe.moe_router import StandardMoERouter
else:
    # Fallback stubs when paddlefleet is not installed.  Defining empty
    # placeholder classes (instead of binding the names to ``None``) lets
    # later ``isinstance(obj, GPTModel)`` / ``isinstance(obj, MoELayer)``
    # checks run unguarded — they simply return False, so call sites do
    # not need an extra ``is not None`` test.

    class GPTModel:
        pass

    class MoELayer:
        pass

    class StandardMoERouter:
        pass


from tqdm.auto import tqdm

Expand Down Expand Up @@ -715,7 +727,7 @@ def on_step_begin(self, args, state, control, **kwargs):
):
self.moe_weights_name = []
self.use_fp8 = True
if GPTModel is not None and isinstance(model, GPTModel):
if isinstance(model, GPTModel):
self.use_fp8 = model.use_fp8()
if not self.use_fp8:
return
Expand Down Expand Up @@ -774,7 +786,9 @@ def on_optimizer_end(self, args, state, control, **kwargs):
usages = []

def get_stat(layer):
if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
if (
isinstance(layer, PretrainedMoEGate) or isinstance(layer, StandardMoERouter)
) and layer.topk_method == "noaux_tc":
biases.append(layer.e_score_correction_bias)
usages.append(layer.expert_usage)

Expand Down Expand Up @@ -809,7 +823,9 @@ def get_stat(layer):
# print('on_optimizer_end update:', update.tolist())

def update_bias(layer):
if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
if (
isinstance(layer, PretrainedMoEGate) or isinstance(layer, StandardMoERouter)
) and layer.topk_method == "noaux_tc":
with paddle.no_grad():
if not layer.weight.stop_gradient:
biases.pop(0).add_(update_list.pop(0))
Expand Down Expand Up @@ -933,3 +949,18 @@ def on_train_begin(self, args, state, control, **kwargs):
for name, param in self.model.state_dict().items():
if "weight1" in name:
self.interleave_gate_up_proj(param)


class GlobalRNGCallback(TrainerCallback):
    """Hook that installs the correct global random number generator into the network.

    At the end of every optimizer step, a ``random.Random`` instance seeded
    with the current global step is attached (as ``layer.rng``) to every
    ``MoELayer`` found in the model.  Seeding from the step keeps the RNG
    deterministic with respect to training progress.
    """

    def on_step_end(self, args, state, control, model, **kwargs):
        # One generator per step, seeded by the step counter so the
        # sequence is reproducible across runs.
        step_rng = random.Random(state.global_step)

        def _attach_rng(layer):
            if isinstance(layer, MoELayer):
                layer.rng = step_rng

        # Recursively visit every sub-layer of the model.
        model.apply(_attach_rng)
1 change: 1 addition & 0 deletions tests/transformers/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def test_from_pretrained_cache_dir(self):
# check against double appending model_name in cache_dir
self.assertFalse(os.path.exists(os.path.join(tempdir, model_id, model_id)))

@slow
def test_load_from_hf(self):
"""test load config from hf"""
config = Qwen3Config.from_pretrained("Qwen/Qwen3-0.6B", download_hub="huggingface")
Expand Down
Loading