Commit c4b463b

[bugfix] fix gemma4 31b (#9080)
1 parent 6824031 commit c4b463b

14 files changed

Lines changed: 57 additions & 26 deletions

docs/source/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 2 deletions
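@@ -215,8 +215,7 @@ ENV:
 
 - 🔥output_dir: The output directory where model predictions and checkpoints will be written. Defaults to None, in which case it is set to `'output/<model_name>'`.
 - 🔥gradient_checkpointing: Whether to use gradient_checkpointing. Defaults to True. This parameter can significantly reduce GPU memory usage, but slows down training.
-- 🔥vit_gradient_checkpointing: For multimodal model training, whether to enable gradient_checkpointing for the vit part. Defaults to None, i.e. it is set to the value of `gradient_checkpointing`. See the example [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/vit_gradient_checkpointing.sh).
-- Note: For multimodal models trained with LoRA, when `--freeze_vit false` is set and the following warning appears on the command line: `UserWarning: None of the inputs have requires_grad=True. Gradients will be None`, set `--vit_gradient_checkpointing false` or open a related issue. This problem does not occur with full-parameter training. (If the warning is raised by the ref_model during RLHF LoRA training, it is normal.)
+- 🔥vit_gradient_checkpointing: For multimodal model training, whether to enable gradient_checkpointing for the vit part. Defaults to None, i.e. it is enabled when `--freeze_vit` is `false`. See the example [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/vit_gradient_checkpointing.sh).
 - 🔥deepspeed: Defaults to None. Can be set to 'zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload' to use the deepspeed configuration files built into ms-swift. You can also pass the path to a custom deepspeed configuration file.
 - zero_hpz_partition_size: Defaults to None. This parameter is a ZeRO++ feature: model sharding within a node and data sharding across nodes. If you encounter grad_norm NaN, try `--torch_dtype float16`.
 - deepspeed_autotp_size: DeepSpeed tensor parallel size, defaults to 1. When using DeepSpeed AutoTP, `--deepspeed` must be set to 'zero0', 'zero1', or 'zero2'. (Note: this feature only supports full-parameter training.)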

docs/source/Instruction/Supported-models-and-datasets.md

Lines changed: 4 additions & 4 deletions
@@ -1098,10 +1098,10 @@
 |[google/gemma-3n-E4B](https://modelscope.cn/models/google/gemma-3n-E4B)|gemma3n|gemma3n|transformers>=4.53.1|&#x2718;|-|[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)|
 |[google/gemma-3n-E2B-it](https://modelscope.cn/models/google/gemma-3n-E2B-it)|gemma3n|gemma3n|transformers>=4.53.1|&#x2718;|-|[google/gemma-3n-E2B-it](https://huggingface.co/google/gemma-3n-E2B-it)|
 |[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|&#x2718;|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
-|[google/gemma-4-E2B](https://modelscope.cn/models/google/gemma-4-E2B)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-E2B](https://huggingface.co/google/gemma-4-E2B)|
-|[google/gemma-4-E2B-it](https://modelscope.cn/models/google/gemma-4-E2B-it)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-E2B-it](https://huggingface.co/google/gemma-4-E2B-it)|
-|[google/gemma-4-E4B](https://modelscope.cn/models/google/gemma-4-E4B)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-E4B](https://huggingface.co/google/gemma-4-E4B)|
-|[google/gemma-4-E4B-it](https://modelscope.cn/models/google/gemma-4-E4B-it)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it)|
+|[google/gemma-4-E2B](https://modelscope.cn/models/google/gemma-4-E2B)|gemma4|gemma4_nothinking|transformers>=4.53|&#x2718;|-|[google/gemma-4-E2B](https://huggingface.co/google/gemma-4-E2B)|
+|[google/gemma-4-E2B-it](https://modelscope.cn/models/google/gemma-4-E2B-it)|gemma4|gemma4_nothinking|transformers>=4.53|&#x2718;|-|[google/gemma-4-E2B-it](https://huggingface.co/google/gemma-4-E2B-it)|
+|[google/gemma-4-E4B](https://modelscope.cn/models/google/gemma-4-E4B)|gemma4|gemma4_nothinking|transformers>=4.53|&#x2718;|-|[google/gemma-4-E4B](https://huggingface.co/google/gemma-4-E4B)|
+|[google/gemma-4-E4B-it](https://modelscope.cn/models/google/gemma-4-E4B-it)|gemma4|gemma4_nothinking|transformers>=4.53|&#x2718;|-|[google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it)|
 |[google/gemma-4-31B](https://modelscope.cn/models/google/gemma-4-31B)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-31B](https://huggingface.co/google/gemma-4-31B)|
 |[google/gemma-4-31B-it](https://modelscope.cn/models/google/gemma-4-31B-it)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-31B-it](https://huggingface.co/google/gemma-4-31B-it)|
 |[google/gemma-4-26B-A4B](https://modelscope.cn/models/google/gemma-4-26B-A4B)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-26B-A4B](https://huggingface.co/google/gemma-4-26B-A4B)|

docs/source/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
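@@ -253,7 +253,7 @@ LoRA training:
 - 🔥offload_bridge: Store the HF-format weights exported by Megatron for vLLM updates in CPU main memory, to reduce GPU memory usage. Defaults to False. (Takes effect in the GRPO/GKD algorithms)
 
 **Multimodal parameters**:
-- vit_gradient_checkpointing: For multimodal model training, whether to enable gradient_checkpointing for the vit part. Defaults to True. (**The vit implementation in Megatron-SWIFT uses the transformers implementation**)
+- vit_gradient_checkpointing: For multimodal model training, whether to enable gradient_checkpointing for the vit part. Defaults to None, i.e. it is enabled when `--freeze_vit` is `false`. (**The vit implementation in Megatron-SWIFT uses the transformers implementation**)
 - vit_gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example, set `--vit_gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to None. This parameter only takes effect for `vit_gradient_checkpointing`.
 - vit_attn_impl: For multimodal model training, sets the attn_impl implementation for the vit part. Defaults to 'flash_attn'.
 - vit_lr: When training multimodal large models, this parameter specifies the learning rate of the vit. Defaults to None, which equals learning_rate. Usually used together with the `--freeze_vit` and `--freeze_aligner` parameters.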

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 2 deletions
@@ -220,8 +220,7 @@ This list inherits from the Transformers `Seq2SeqTrainingArguments`, with ms-swi
 
 - 🔥output_dir: The output directory where the model predictions and checkpoints will be written. Default is `None`, automatically set to `'output/<model_name>'`.
 - 🔥gradient_checkpointing: Whether to use gradient checkpointing. Default is `True`. This significantly reduces GPU memory usage but slows down training.
-- 🔥vit_gradient_checkpointing: For multimodal model training, whether to enable gradient checkpointing for the ViT (Vision Transformer) component. Default is `None`, meaning it follows the value of `gradient_checkpointing`. For an example, please refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/vit_gradient_checkpointing.sh).
-- Note: When training multimodal models with LoRA and `--freeze_vit false`, if you see the warning: `UserWarning: None of the inputs have requires_grad=True. Gradients will be None`, try setting `--vit_gradient_checkpointing false` or open an issue. This issue does not occur in full-parameter training. (If this warning comes from the `ref_model` during RLHF LoRA training, it is normal.)
+- 🔥vit_gradient_checkpointing: Whether to enable gradient checkpointing for the ViT component during multimodal model training. Defaults to `None`, which means it is enabled when `--freeze_vit` is `false`. For an example, please refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/vit_gradient_checkpointing.sh).
 - 🔥deepspeed: Default is `None`. Can be set to `'zero0'`, `'zero1'`, `'zero2'`, `'zero3'`, `'zero2_offload'`, `'zero3_offload'` to use built-in DeepSpeed configurations in ms-swift. You can also pass a path to a custom DeepSpeed config file.
 - zero_hpz_partition_size: Default is `None`. This enables ZeRO++ functionality: model sharding within nodes and data sharding across nodes. If encountering `grad_norm NaN`, try using `--torch_dtype float16`.
 - deepspeed_autotp_size: DeepSpeed tensor parallelism size. Default is 1. To use DeepSpeed AutoTP, set `--deepspeed` to `'zero0'`, `'zero1'`, or `'zero2'`. (Note: Only supports full-parameter training)

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 4 additions & 4 deletions
@@ -1099,10 +1099,10 @@ The table below introduces the models integrated with ms-swift:
 |[google/gemma-3n-E4B](https://modelscope.cn/models/google/gemma-3n-E4B)|gemma3n|gemma3n|transformers>=4.53.1|&#x2718;|-|[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)|
 |[google/gemma-3n-E2B-it](https://modelscope.cn/models/google/gemma-3n-E2B-it)|gemma3n|gemma3n|transformers>=4.53.1|&#x2718;|-|[google/gemma-3n-E2B-it](https://huggingface.co/google/gemma-3n-E2B-it)|
 |[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|&#x2718;|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
-|[google/gemma-4-E2B](https://modelscope.cn/models/google/gemma-4-E2B)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-E2B](https://huggingface.co/google/gemma-4-E2B)|
-|[google/gemma-4-E2B-it](https://modelscope.cn/models/google/gemma-4-E2B-it)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-E2B-it](https://huggingface.co/google/gemma-4-E2B-it)|
-|[google/gemma-4-E4B](https://modelscope.cn/models/google/gemma-4-E4B)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-E4B](https://huggingface.co/google/gemma-4-E4B)|
-|[google/gemma-4-E4B-it](https://modelscope.cn/models/google/gemma-4-E4B-it)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it)|
+|[google/gemma-4-E2B](https://modelscope.cn/models/google/gemma-4-E2B)|gemma4|gemma4_nothinking|transformers>=4.53|&#x2718;|-|[google/gemma-4-E2B](https://huggingface.co/google/gemma-4-E2B)|
+|[google/gemma-4-E2B-it](https://modelscope.cn/models/google/gemma-4-E2B-it)|gemma4|gemma4_nothinking|transformers>=4.53|&#x2718;|-|[google/gemma-4-E2B-it](https://huggingface.co/google/gemma-4-E2B-it)|
+|[google/gemma-4-E4B](https://modelscope.cn/models/google/gemma-4-E4B)|gemma4|gemma4_nothinking|transformers>=4.53|&#x2718;|-|[google/gemma-4-E4B](https://huggingface.co/google/gemma-4-E4B)|
+|[google/gemma-4-E4B-it](https://modelscope.cn/models/google/gemma-4-E4B-it)|gemma4|gemma4_nothinking|transformers>=4.53|&#x2718;|-|[google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it)|
 |[google/gemma-4-31B](https://modelscope.cn/models/google/gemma-4-31B)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-31B](https://huggingface.co/google/gemma-4-31B)|
 |[google/gemma-4-31B-it](https://modelscope.cn/models/google/gemma-4-31B-it)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-31B-it](https://huggingface.co/google/gemma-4-31B-it)|
 |[google/gemma-4-26B-A4B](https://modelscope.cn/models/google/gemma-4-26B-A4B)|gemma4|gemma4|transformers>=4.53|&#x2718;|-|[google/gemma-4-26B-A4B](https://huggingface.co/google/gemma-4-26B-A4B)|

docs/source_en/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ LoRA Training:
 - 🔥offload_bridge: Use CPU main memory to store HF format weights exported by Megatron for vLLM updates, to reduce GPU memory usage. Defaults to False. (Takes effect in GRPO/GKD algorithms)
 
 **Multimodal Parameters**:
-- vit_gradient_checkpointing: Whether to enable gradient checkpointing for the ViT (Vision Transformer) component during multimodal model training. Defaults to `True`. (**The ViT implementation in Megatron-SWIFT uses the Hugging Face `transformers` library.**)
+- vit_gradient_checkpointing: Whether to enable gradient checkpointing for the ViT component during multimodal model training. Defaults to `None`, which means it is enabled when `--freeze_vit` is `false`. (**The ViT implementation in Megatron-SWIFT uses the Hugging Face `transformers` library.**)
 - vit_gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--vit_gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled.
 - vit_attn_impl: When training a multimodal model, sets the `attn_impl` implementation used for the ViT part. Defaults to `'flash_attn'`.
 - vit_lr: Specifies the learning rate for the ViT module when training multimodal models. Default is `None`, same as `learning_rate`. Typically used together with `--freeze_vit` and `--freeze_aligner`.

swift/arguments/sft_args.py

Lines changed: 4 additions & 2 deletions
@@ -4,7 +4,7 @@
 from transformers.utils.versions import require_version
 from typing import Literal, Optional
 
-from swift.trainers import Seq2SeqTrainingArguments, TrainArgumentsMixin, TrainerFactory
+from swift.trainers import Seq2SeqTrainingArguments, TrainerFactory
 from swift.utils import (add_version_to_work_dir, get_device_count, get_logger, get_pai_tensorboard_dir, is_mp,
                          is_pai_training_job, is_swanlab_available, json_parse_to_dict, to_abspath)
 from .base_args import BaseArguments
@@ -124,7 +124,7 @@ class SftArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTrain
     """Arguments pertaining to the training process.
 
     SftArguments is a dataclass that inherits from multiple argument classes: SwanlabArguments, TunerArguments,
-    BaseArguments, TrainArgumentsMixin, Seq2SeqTrainingArguments.
+    BaseArguments, Seq2SeqTrainingArguments.
 
     Args:
         add_version (bool): Whether to add a versioned subdirectory like '<version>-<timestamp>' to the `output_dir` to
@@ -205,6 +205,8 @@ def __post_init__(self) -> None:
         self._init_override()
         TunerArguments.__post_init__(self)
         self._check_padding_free()
+        if self.vit_gradient_checkpointing is None:
+            self.vit_gradient_checkpointing = not self.freeze_vit
         if self.optimizer is None:
             if self.lorap_lr_ratio:
                 self.optimizer = 'lorap'
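
The default resolution added to `__post_init__` can be illustrated in isolation. The following is a minimal sketch, not the actual `SftArguments` class: `ToyVitArgs` is a hypothetical stand-in, and the `freeze_vit=True` default is an assumption made only for this example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyVitArgs:
    """Hypothetical stand-in for the two fields touched by the hunk above."""
    freeze_vit: bool = True  # assumed default, for illustration only
    vit_gradient_checkpointing: Optional[bool] = None

    def __post_init__(self) -> None:
        # Same rule as the added lines: an unset value (None) resolves to
        # "enabled exactly when the ViT is not frozen".
        if self.vit_gradient_checkpointing is None:
            self.vit_gradient_checkpointing = not self.freeze_vit


frozen = ToyVitArgs()                      # ViT frozen, checkpointing off
trainable = ToyVitArgs(freeze_vit=False)   # ViT trainable, checkpointing on
explicit = ToyVitArgs(freeze_vit=False, vit_gradient_checkpointing=False)
```

An explicit `True` or `False` is never overwritten; only the `None` default is resolved from `freeze_vit`.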
Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 {
     "^<think>\\s*</think>\\s*": [0.0],
-    "^<seed:think><seed:cot_budget_reflect>The current thinking budget is 0, so I will directly start answering the question.</seed:cot_budget_reflect>\n</seed:think>\\s*": [0.0],
-    "^</think>\\s*": [0.0]
+    "^<seed:think><seed:cot_budget_reflect>The current thinking budget is 0, so I will directly start answering the question.</seed:cot_budget_reflect>\\n</seed:think>\\s*": [0.0],
+    "^</think>\\s*": [0.0],
+    "^<\\|channel>thought\\n<channel\\|>": [0.0]
 }
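
To see what the new pattern matches, the sketch below loads a subset of the JSON above and tries each key as a regular expression. It assumes the keys are consumed by Python's `re` module, which this diff does not show; `first_match` is a hypothetical helper, and the meaning of the `[0.0]` values is not stated in the diff, so the sketch only checks matching.

```python
import json
import re

# Subset of the patterns from the hunk above, as raw JSON text.
RAW_JSON = r'''
{
    "^<think>\\s*</think>\\s*": [0.0],
    "^</think>\\s*": [0.0],
    "^<\\|channel>thought\\n<channel\\|>": [0.0]
}
'''
patterns = json.loads(RAW_JSON)


def first_match(text):
    """Return (pattern, matched_prefix) for the first key that matches, else None."""
    for pattern in patterns:
        match = re.match(pattern, text)
        if match:
            return pattern, match.group(0)
    return None


gemma_style = first_match('<|channel>thought\n<channel|>Hello')
think_style = first_match('<think>\n\n</think>\n\nHello')
plain = first_match('Hello')
```

Under the same assumption, a literal newline character and the escape `\n` both match a newline in `re`, so the `\n` to `\\n` change in the seed pattern looks like a consistency fix rather than a change in what is matched.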

swift/model/models/gemma.py

Lines changed: 5 additions & 2 deletions
@@ -216,14 +216,17 @@ def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
                 Model('google/gemma-4-E2B-it', 'google/gemma-4-E2B-it'),
                 Model('google/gemma-4-E4B', 'google/gemma-4-E4B'),
                 Model('google/gemma-4-E4B-it', 'google/gemma-4-E4B-it'),
+            ],
+                       template=TemplateType.gemma4_nothinking),
+            ModelGroup([
                 Model('google/gemma-4-31B', 'google/gemma-4-31B'),
                 Model('google/gemma-4-31B-it', 'google/gemma-4-31B-it'),
                 Model('google/gemma-4-26B-A4B', 'google/gemma-4-26B-A4B'),
                 Model('google/gemma-4-26B-A4B-it', 'google/gemma-4-26B-A4B-it'),
-            ], ),
+            ],
+                       template=TemplateType.gemma4),
         ],
         Gemma4Loader,
-        template=TemplateType.gemma4,
         architectures=['Gemma4ForConditionalGeneration'],
         model_arch=ModelArch.gemma3n,
         requires=['transformers>=4.53'],
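
The effect of splitting the single group in two is that each model ID now resolves to a group-specific template. A rough sketch of that lookup follows; `ToyModelGroup` and `template_for` are hypothetical names for illustration and do not reproduce how ms-swift resolves templates internally.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyModelGroup:
    """Hypothetical stand-in for a model group that carries its own template."""
    models: List[str]
    template: Optional[str] = None


# Grouping as it appears after this commit (model IDs taken from the diff).
GROUPS = [
    ToyModelGroup(
        ['google/gemma-4-E2B', 'google/gemma-4-E2B-it',
         'google/gemma-4-E4B', 'google/gemma-4-E4B-it'],
        template='gemma4_nothinking'),
    ToyModelGroup(
        ['google/gemma-4-31B', 'google/gemma-4-31B-it',
         'google/gemma-4-26B-A4B', 'google/gemma-4-26B-A4B-it'],
        template='gemma4'),
]


def template_for(model_id: str) -> Optional[str]:
    """Return the template of the group containing model_id, or None."""
    for group in GROUPS:
        if model_id in group.models:
            return group.template
    return None
```

This mirrors the updated rows in the supported-models tables: the E2B and E4B variants map to `gemma4_nothinking`, while the 31B and 26B-A4B variants keep `gemma4`.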

swift/template/base.py

Lines changed: 4 additions & 4 deletions
@@ -1050,7 +1050,7 @@ def _is_add_non_thinking_round(self, messages, i: int, start_idx: int):
         message = messages[i]
         return i >= start_idx and message['role'] == 'assistant'
 
-    def _add_non_thinking_prefix(self, inputs) -> None:
+    def _add_non_thinking_prefix(self, inputs, thinking_prefix='<think>') -> None:
         messages = inputs.messages
         non_thinking_prefix = self.template_meta.non_thinking_prefix
         if non_thinking_prefix:
@@ -1063,14 +1063,14 @@ def _add_non_thinking_prefix(self, inputs) -> None:
                 start_idx = -1
             for i, message in enumerate(messages):
                 if (self._is_add_non_thinking_round(messages, i, start_idx) and isinstance(message['content'], str)
-                        and not message['content'].startswith(('<think>', non_thinking_prefix))):
+                        and not message['content'].startswith((thinking_prefix, non_thinking_prefix))):
                     # During multi-turn SFT training/validation:
                     # If the message has no <think> block and does not start with the non_thinking_prefix,
                     # prepend the non_thinking_prefix to the content.
                     message['content'] = non_thinking_prefix + message['content']
 
-    def _remove_thinking_content(self, content: str) -> str:
-        content = content.split('</think>')[-1].strip()
+    def _remove_thinking_content(self, content: str, thinking_suffix='</think>') -> str:
+        content = content.split(thinking_suffix)[-1].strip()
         return self.template_meta.history_thinking_prefix + content
 
     def _remove_history_thinking(self, inputs) -> None:
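
The two new keyword arguments let a template with different thinking delimiters reuse the base helpers. Below is a standalone sketch of the same string logic, written as free functions rather than `Template` methods. The Gemma-style delimiters (`'<|channel>thought'` and `'<channel|>'`) are an assumption inferred from the regex added to the JSON file in this commit; the diff does not show which values the Gemma 4 template actually passes.

```python
def remove_thinking_content(content: str,
                            thinking_suffix: str = '</think>',
                            history_thinking_prefix: str = '') -> str:
    """Keep only the text after the last thinking_suffix, then re-add a prefix."""
    content = content.split(thinking_suffix)[-1].strip()
    return history_thinking_prefix + content


def add_non_thinking_prefix(content: str,
                            non_thinking_prefix: str,
                            thinking_prefix: str = '<think>') -> str:
    """Prepend non_thinking_prefix unless content already starts with a thinking marker."""
    if content.startswith((thinking_prefix, non_thinking_prefix)):
        return content
    return non_thinking_prefix + content


# Default delimiters, matching the behavior before this commit.
default_answer = remove_thinking_content('<think>reasoning</think>\n\nAnswer')

# Assumed Gemma-style delimiters (illustrative only).
GEMMA_EMPTY = '<|channel>thought\n<channel|>'
gemma_answer = remove_thinking_content(
    '<|channel>thought\nreasoning<channel|>Answer', thinking_suffix='<channel|>')
prefixed = add_non_thinking_prefix('Hi', GEMMA_EMPTY, thinking_prefix='<|channel>thought')
untouched = add_non_thinking_prefix(
    '<|channel>thought\nreasoning<channel|>Hi', GEMMA_EMPTY,
    thinking_prefix='<|channel>thought')
```

Because both parameters default to the previous hard-coded values, existing callers that pass nothing keep the `<think>` / `</think>` behavior.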
