diff --git a/git_command b/git_command
new file mode 100644
index 0000000000..f8c3dc5be7
--- /dev/null
+++ b/git_command
@@ -0,0 +1,13 @@
+# 1. View the diff between the two branches (confirm what is to be merged)
+git diff main..new_althorithm/sympo
+# 2. Check which branch you are currently on
+git branch
+# 3. Switch branches
+git checkout new_althorithm/sympo
+git status # make sure there are no uncommitted changes
+# If there are still changes, commit them first:
+git add .
+git commit -m "feat: add SymPO algorithm"
+
+# 4. Push to the remote (important!)
+git push origin new_althorithm/sympo
\ No newline at end of file
diff --git a/swift/rlhf_trainers/grpo_trainer.py b/swift/rlhf_trainers/grpo_trainer.py
index a4e79446ec..ac5bc655ac 100644
--- a/swift/rlhf_trainers/grpo_trainer.py
+++ b/swift/rlhf_trainers/grpo_trainer.py
@@ -1,60 +1,39 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
+# Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from huggingface/trl.
-
-# fmt: off
-# apply patch before importing trl, which may internally reference GuidedDecodingParams
-try:
-    import vllm
-    try:
-        from vllm.sampling_params import GuidedDecodingParams
-    except ImportError:
-        import vllm.sampling_params
-
-        # removed in https://github.com/vllm-project/vllm/pull/22772
-        vllm.sampling_params.GuidedDecodingParams = vllm.sampling_params.StructuredOutputsParams
-except ImportError:
-    pass
-# fmt: on
-
-import asyncio
-import atexit
 import concurrent.futures
 import inspect
 import os
 import time
+from collections import defaultdict, deque
+from contextlib import contextmanager, nullcontext
+from copy import copy, deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import transformers
 from accelerate.utils import gather, gather_object, is_peft_model, set_seed
-from collections import defaultdict, deque
-from contextlib import contextmanager, nullcontext
-from copy import copy, deepcopy
 from packaging import version
 from transformers import PreTrainedModel
-from transformers.trainer import Trainer as HfTrainer
+from transformers.trainer import Trainer
 from trl import GRPOTrainer as HFGRPOTrainer
 from trl.models import prepare_deepspeed
 from trl.trainer import grpo_trainer
 from trl.trainer.callbacks import SyncRefModelCallback
-from trl.trainer.grpo_trainer import RepeatSampler, nanmax, nanmin
+from trl.trainer.grpo_trainer import RepeatSampler, nanmax, nanmin, nanstd
 from trl.trainer.utils import selective_log_softmax
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from swift.dataset import RowPreprocessor
-from swift.infer_engine import TransformersEngine
-from swift.rewards import orms, rm_plugins
-from swift.sequence_parallel import GatherLoss, sequence_parallel
-from swift.template import Template, TemplateInputs
-from swift.trainers import SwiftMixin, disable_gradient_checkpointing
-from swift.utils import (JsonlWriter, get_cu_seqlens_from_position_ids, get_logger, is_swanlab_available,
-                         is_wandb_available, remove_response, seed_worker, shutdown_event_loop_in_daemon,
-                         start_event_loop_in_daemon, to_device, unwrap_model_for_generation)
-from .arguments import GRPOConfig
+from swift.llm import RowPreprocessor, Template, to_device
+from swift.llm.template.template_inputs import TemplateInputs
+from swift.plugin import orms, rm_plugins
+from swift.utils import (JsonlWriter, get_logger, is_swanlab_available, is_wandb_available, remove_response,
+                         seed_worker, unwrap_model_for_generation)
+from ..mixin import SwiftMixin
+from .rollout_mixin import DataType, RolloutTrainerMixin
+from .utils import 
(_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator, - load_pil_img, make_chord_sft_dataset, nanstd, pad_logps_back_to_batch, patch_save_last_checkpoint, - profiling_context, profiling_decorator, replace_assistant_response_with_ids) + load_pil_img, make_chord_sft_dataset, patch_profiling_context, patch_profiling_decorator, + patch_save_last_checkpoint, replace_assistant_response_with_ids) try: from trl.trainer.utils import entropy_from_logits @@ -83,30 +62,33 @@ def __init__(self, *_args, **kwargs): patch_save_last_checkpoint() + from swift.trainers.rlhf_arguments import GRPOConfig args: GRPOConfig = kwargs['args'] self.args = args self.ref_adapter_name = getattr(args, 'ref_adapter_name', None) self.model_adapter_name = None self.is_multimodal = model.model_meta.is_multimodal + + model.warnings_issued['estimate_tokens'] = True + kwargs['data_collator'] = identity_data_collator # No data collation is needed in GRPO + self.model_kwarg_keys = ( inspect.signature(model.forward).parameters.keys() if not hasattr(model, 'get_base_model') else inspect.signature(model.get_base_model().forward).parameters.keys()) self.vllm_client = kwargs.pop('vllm_client', None) self.chord_sft_dataset = kwargs.pop('chord_sft_dataset', None) - reward_templates = kwargs.pop('reward_template', None) self._prepare_algorithm_params() super().__init__(model, ref_model, *_args, **kwargs) - self._prepare_chord_dataset() self.prepare_rollout() - self._prepare_rewards(reward_funcs, reward_model, reward_templates) + self._prepare_rewards(reward_funcs, reward_model, **kwargs) if not self.reward_funcs and not self.use_gym_env: raise ValueError('You must specify reward_funcs or reward_model') if self.args.eval_strategy != 'no': total_eval_batch_size = self.args.per_device_eval_batch_size * \ - self.accelerator.num_processes // self.num_generations_eval + self.accelerator.num_processes // self.args.num_generations assert len(self.eval_dataset) >= total_eval_batch_size, ( f'eval_dataset size {len(self.eval_dataset)} is smaller than ' f'total_eval_batch_size {total_eval_batch_size}. ' @@ -121,10 +103,11 @@ def __init__(self, set_seed(args.seed, device_specific=True) if not self.args.use_vllm: + from swift.llm import PtEngine infer_template = copy(self.template) infer_template.padding_free = False infer_template.sequence_parallel_size = 1 - self.engine = TransformersEngine(self.model, template=infer_template, max_batch_size=0) # 0: no limit + self.engine = PtEngine.from_model_template(self.model, infer_template, max_batch_size=0) # 0: no limit # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set @@ -140,6 +123,7 @@ def __init__(self, self.eval_flag = False if self.template.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel self.args.gradient_accumulation_steps = self.args.gradient_accumulation_steps * sequence_parallel.world_size # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs @@ -152,13 +136,10 @@ def __init__(self, # Buffer the batch to reuse generated outputs across multiple updates. For more details, see # `_get_train_sampler` and `_prepare_inputs`. 
self._buffered_inputs = None - self._current_train_step_time = 0.0 - - def _get_data_collator(self, args, template): - return identity_data_collator def _get_train_sampler(self, train_dataset=None): if self.template.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel return RepeatSampler( data_source=train_dataset or self.train_dataset, mini_repeat_count=self.num_generations, @@ -170,7 +151,7 @@ def _get_train_sampler(self, train_dataset=None): else: return super()._get_train_sampler(train_dataset) - @profiling_decorator + @patch_profiling_decorator def _prepare_inputs(self, generation_batch: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: # Prepares inputs for model training/evaluation by managing completion generation and batch handling. @@ -199,6 +180,16 @@ def _prepare_inputs(self, generation_batch: Dict[str, Union[torch.Tensor, inputs = self._generate_and_score_completions(generation_batch) return inputs + @contextmanager + def _template_context(self, template: Template): + # The max_length for prompt and completion has already been restricted, so there is no need for max_length here. + max_length = template.max_length + template.max_length = None + try: + yield + finally: + template.max_length = max_length + def _generate_completions(self, inputs: DataType) -> DataType: # add prompt ids and system prompts inputs = self._preprocess_inputs(inputs) @@ -213,12 +204,12 @@ def _generate_completions(self, inputs: DataType) -> DataType: results = self._infer_single_or_multi_turn(inputs, self.request_config) if mode == 'train': # In training mode, ensure the model is returned to train() mode after inference - # This is necessary as transformers engines set the model to eval mode during generation + # This is necessary as pt engines set the model to eval mode during generation self.model.train() return results - @profiling_decorator + @patch_profiling_decorator def _generate_and_score_completions(self, inputs: DataType) -> DataType: # resample for encoding failed data when set truncation_strategy 'delete' if self.template.truncation_strategy == 'raise': @@ -249,11 +240,15 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: f'Mismatch: {len(gas_chunks)} chunks vs {len(batch_encoded_inputs)} batches' for batch, batch_encoded in zip(gas_chunks, batch_encoded_inputs): - # Advantages are always [batch_size], will be broadcast to [batch_size, seq_len] in loss computation - all_advantages = torch.stack([data['advantages'] for data in batch]) + if self.template.padding_free: + lengths = batch_encoded['seq_lengths'] + advantages_stacked = torch.stack([data['advantages'] for data in batch]) + all_advantages = torch.repeat_interleave(advantages_stacked, lengths) + else: + all_advantages = torch.stack([data['advantages'] for data in batch]) batch_encoded['advantages'] = all_advantages - with profiling_context(self, 'log_metrics'): + with patch_profiling_context(self, 'log_metrics'): # --- logs (prompts + completions) --- messages = [inp['messages'][:-1] for inp in inputs] completions = [deepcopy(inp['messages'][-1]['content']) for inp in inputs] @@ -290,7 +285,7 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: return batch_encoded_inputs - @profiling_decorator + @patch_profiling_decorator def _score_completions(self, inputs: DataType) -> torch.Tensor: """Score completions using all reward functions. 
@@ -328,64 +323,28 @@ def _compute_rewards_per_func(self, inputs: DataType) -> torch.Tensor: device = self.accelerator.device rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device) completions = [inp['messages'][-1]['content'] for inp in inputs] - - # Common reward kwargs - reward_kwargs = {'trainer_state': self.state} - reward_inputs = [{k: v for k, v in inp.items() if k != 'add_eos'} for inp in inputs] - if self.enable_server_multi_turn: - trajectory_inputs = self._get_trajectory_inputs(inputs) - reward_kwargs.update({'trajectory_inputs': trajectory_inputs}) - reward_kwargs.update(RowPreprocessor.rows_to_batched(reward_inputs)) - - # Use pre-computed indices for async reward functions - async_indices_set = set(self._async_reward_func_indices) - for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate( zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)): - template = None if not hasattr(reward_model_plugin, 'template') else reward_model_plugin.template - with profiling_context(self, reward_func_name), self._disable_sp_context(template): - # Reward model (nn.Module) + with patch_profiling_context(self, reward_func_name): + # reward model + reward_kwargs = {'trainer_state': self.state} + if self.enable_server_multi_turn: + trajectory_inputs = self._get_trajectory_inputs(inputs) + reward_kwargs.update({'trajectory_inputs': trajectory_inputs}) if isinstance(reward_func, nn.Module): - output_reward_func = reward_model_plugin(inputs=reward_inputs, **reward_kwargs) - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - # Async reward function - skip here, will be executed in parallel later - elif i in async_indices_set: - pass - # Synchronous reward function + output_reward_func = reward_model_plugin(inputs=inputs, **reward_kwargs) + # reward function else: + # Repeat all input columns (but "messages" and "completion") to match the number of generations + reward_kwargs.update(RowPreprocessor.rows_to_batched(inputs)) output_reward_func = reward_func(completions, **reward_kwargs) - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - - # Execute async reward functions in parallel using asyncio.gather - # Process in original order to maintain correspondence with reward_func_names - if self._async_reward_func_indices: - - async def _invoke_async_reward(index): - func = self.reward_funcs[index] - func_name = self.reward_func_names[index] - with profiling_context(self, func_name): - output = await func(completions, **reward_kwargs) - output = [r if r is not None else torch.nan for r in output] - return index, output - - async def _run_async_funcs(): - # Maintain order by processing indices in sequence - coros = [_invoke_async_reward(idx) for idx in self._async_reward_func_indices] - return await asyncio.gather(*coros) - - async_results = asyncio.run_coroutine_threadsafe(_run_async_funcs(), self.async_reward_loop).result() - for idx, output_reward_func in async_results: - rewards_per_func[:, idx] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, 
device=device) # If all reward functions return None for a given row, issue a detailed warning if torch.isnan(rewards_per_func).all(dim=1).any(): nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] - row_reward_kwargs = { - key: value[nan_row_idx] - for key, value in reward_kwargs.items() if key != 'trainer_state' - } + row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()} row_reward_kwargs['completion'] = completions[nan_row_idx] logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. ' 'Please ensure that at least one reward function returns a valid reward.') @@ -425,24 +384,22 @@ def normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tensor) -> def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: torch.Tensor): """Log reward statistics for monitoring. Only log once per unique request_id.""" - # rewards: [prompt_batch_size, num_generations] - # rewards_per_func_for_metrics: [prompt_batch_size*num_generations, self.num_reward_funcs] + # rewards: [prompt_batch_size, self.num_generations] + # rewards_per_func_for_metrics: [prompt_batch_size*self.num_generations, self.num_reward_funcs] mode = 'train' if self.model.training else 'eval' - num_generations = self.num_generations if mode == 'train' else self.num_generations_eval - group_rewards = rewards.view(-1, num_generations) - rewards_mean = group_rewards.mean(-1).mean().item() - if self.scale_rewards in ['group', 'none', 'gdpo']: - # Handle edge case when num_generations_eval=1 - if num_generations > 1: + if self.num_generations > 1: + group_rewards = rewards.view(-1, self.num_generations) + rewards_mean = group_rewards.mean(-1).mean().item() + if self.scale_rewards in ['group', 'none']: rewards_std = group_rewards.std(-1).mean().item() - else: - rewards_std = 0.0 - elif self.scale_rewards == 'batch': - rewards_std = rewards.std().item() if rewards.numel() > 1 else 0.0 - if num_generations > 1: + elif self.scale_rewards == 'batch': + rewards_std = rewards.std().item() is_std_zero = torch.isclose(group_rewards.std(dim=1), torch.zeros_like(group_rewards.std(dim=1))) else: - is_std_zero = torch.ones(group_rewards.size(0), dtype=torch.bool, device=group_rewards.device) + # Single generation mode (REINFORCE) + rewards_mean = rewards.mean().item() + rewards_std = rewards.std().item() if len(rewards) > 1 else 1.0 + is_std_zero = torch.tensor([False]) self._metrics[mode]['reward'].append(rewards_mean) self._metrics[mode]['reward_std'].append(rewards_std) @@ -469,24 +426,40 @@ def log_rewards_all(rewards_per_func: torch.Tensor): old_per_token_logps = batch_encoded['old_per_token_logps'] ref_per_token_logps = batch_encoded['ref_per_token_logps'] completion_mask = batch_encoded['completion_mask'] - per_token_kl = old_per_token_logps - ref_per_token_logps - kl = (per_token_kl * completion_mask).sum(-1) + if self.template.padding_free: + lengths = batch_encoded['seq_lengths'] + per_token_kl = torch.split(old_per_token_logps - ref_per_token_logps, lengths.tolist(), dim=1) + completion_masks = torch.split(completion_mask, lengths.tolist(), dim=1) + kl = torch.cat([(kl * mask).sum(-1) for kl, mask in zip(per_token_kl, completion_masks)]) + else: + per_token_kl = old_per_token_logps - ref_per_token_logps + kl = (per_token_kl * completion_mask).sum(-1) kl_list.append(kl) kl = torch.cat(kl_list, dim=0) kl = gather(kl) mode = 'train' if self.model.training else 'eval' 
             self._metrics[mode]['kl'].append(kl.nanmean().item())
-            rewards = rewards - self.beta * kl
+            # rewards = rewards - self.beta * kl
+            rewards = rewards
+        # Handle single generation case (REINFORCE)
+        if self.num_generations == 1:
+            # For REINFORCE, advantages are just the rewards themselves
+            #
+            advantages = rewards
+
+            # Log metrics
+            log_rewards_metrics(rewards=rewards, rewards_per_func_for_metrics=rewards_per_func)
+            log_rewards_all(rewards_per_func)
+
+            return advantages
         # --------------------------------------------------
         # Case 1: Default grouped mode
        # --------------------------------------------------
-        mode = 'train' if self.model.training else 'eval'
-        num_generations = self.num_generations if mode == 'train' else self.num_generations_eval
         if not self.dynamic_num_samples:
-            grouped_rewards = rewards.view(-1, num_generations)
-            K = num_generations
+            grouped_rewards = rewards.view(-1, self.num_generations)
+            K = self.num_generations
             # Compute group statistics
             group_rewards_mean = grouped_rewards.mean(dim=1)
@@ -499,11 +472,7 @@ def log_rewards_all(rewards_per_func: torch.Tensor):
                 # RLOO: Leave-One-Out baseline
                 # A_i = r_i - mean(r_j for j != i)
                 #     = r_i * K/(K-1) - mean_all * K/(K-1)
-                # Edge case: when K=1 (e.g., num_generations_eval=1), fall back to simple advantage
-                if K > 1:
-                    advantages = rewards * K / (K - 1) - group_rewards_mean * K / (K - 1)
-                else:
-                    advantages = rewards - group_rewards_mean
+                advantages = rewards * K / (K - 1) - group_rewards_mean * K / (K - 1)
             else:  # 'grpo' or 'reinforce_plus_plus'
                 # Both use group mean as baseline
                 advantages = rewards - group_rewards_mean
@@ -514,17 +483,11 @@ def log_rewards_all(rewards_per_func: torch.Tensor):
                 if self.scale_rewards == 'batch':
                     # Global whitening: std computed on advantages
                     # Note: advantages.mean() is mathematically 0, no need to subtract
-                    if advantages.numel() > 1:
-                        advantages_std = advantages.std().expand_as(advantages)
-                    else:  # edge case: num_generations_eval=batch_size=1
-                        advantages_std = torch.zeros_like(advantages)
+                    advantages_std = advantages.std().expand_as(advantages)
                 elif self.scale_rewards == 'group':
                     # Group-level whitening on advantages
                     advantages_grouped = advantages.view(-1, K)
-                    if K > 1:
-                        advantages_std = advantages_grouped.std(dim=1).repeat_interleave(K)
-                    else:  # edge case: num_generations_eval=1
-                        advantages_std = torch.zeros_like(advantages)
+                    advantages_std = advantages_grouped.std(dim=1).repeat_interleave(K)
                 else:  # 'none'
                     advantages_std = None
                 if advantages_std is not None:
@@ -532,26 +495,9 @@ def log_rewards_all(rewards_per_func: torch.Tensor):
             else:  # 'grpo' or 'rloo'
                 # GRPO/RLOO: Use std of original rewards
                 if self.scale_rewards == 'batch':
-                    if rewards.numel() > 1:
-                        rewards_std = rewards.std().expand_as(rewards)
-                    else:  # edge case: num_generations_eval=batch_size=1
-                        rewards_std = torch.zeros_like(rewards)
+                    rewards_std = rewards.std().expand_as(rewards)
                 elif self.scale_rewards == 'group':
-                    if K > 1:
-                        rewards_std = grouped_rewards.std(dim=1).repeat_interleave(K)
-                    else:  # edge case: num_generations_eval=1
-                        rewards_std = torch.zeros_like(rewards)
-                elif self.scale_rewards == 'gdpo':
-                    grouped = rewards_per_func.view(-1, K, rewards_per_func.shape[1])
-                    group_mean = torch.nanmean(grouped, dim=1, keepdim=True)
-                    group_std = nanstd(grouped, dim=1, keepdim=True) if K > 1 else torch.zeros_like(group_mean)
-                    normalized = (grouped - group_mean) / (group_std + 1e-8)
-                    normalized = torch.nan_to_num(normalized, nan=0.0)
-                    normalized = normalized.view(-1, rewards_per_func.shape[1])
-
advantages = (normalized * self.reward_weights.unsqueeze(0)).sum(dim=1) - batch_std = advantages.std() + 1e-8 - advantages = (advantages - advantages.mean()) / batch_std - rewards_std = None + rewards_std = grouped_rewards.std(dim=1).repeat_interleave(K) else: # 'none' rewards_std = None if rewards_std is not None: @@ -606,11 +552,7 @@ def log_rewards_all(rewards_per_func: torch.Tensor): idx_tensor = torch.tensor(idxs, device=device) r_group = unique_rewards[idx_tensor] # A_i = r_i * K/(K-1) - mean * K/(K-1) - # Edge case: when K=1, fall back to simple advantage - if K > 1: - request_advantages[idx_tensor] = (r_group * K / (K - 1) - r_group.mean() * K / (K - 1)) - else: - request_advantages[idx_tensor] = r_group - r_group.mean() + request_advantages[idx_tensor] = (r_group * K / (K - 1) - r_group.mean() * K / (K - 1)) else: # 'grpo' or 'reinforce_plus_plus' # Both use group mean as baseline request_advantages = unique_rewards - prompt_means @@ -621,10 +563,7 @@ def log_rewards_all(rewards_per_func: torch.Tensor): if self.scale_rewards == 'batch': # Global whitening: std computed on advantages # Note: advantages.mean() is mathematically 0, no need to subtract - if request_advantages.numel() > 1: - advantages_std = request_advantages.std() - else: - advantages_std = torch.tensor(0.0, device=device) + advantages_std = request_advantages.std() prompt_stds = torch.full_like(request_advantages, advantages_std) elif self.scale_rewards == 'group': # Group-level whitening on advantages @@ -632,8 +571,7 @@ def log_rewards_all(rewards_per_func: torch.Tensor): for pid, idxs in prompt_to_indices.items(): idx_tensor = torch.tensor(idxs, device=device) adv_group = request_advantages[idx_tensor] - # Edge case: when group size is 1 - prompt_stds[idx_tensor] = adv_group.std() if len(idxs) > 1 else 0.0 + prompt_stds[idx_tensor] = adv_group.std() else: # 'none' prompt_stds = None if prompt_stds is not None: @@ -641,18 +579,14 @@ def log_rewards_all(rewards_per_func: torch.Tensor): else: # 'grpo' or 'rloo' # GRPO/RLOO: Use std of original rewards if self.scale_rewards == 'batch': - if unique_rewards.numel() > 1: - rewards_std = unique_rewards.std() - else: - rewards_std = torch.tensor(0.0, device=device) + rewards_std = unique_rewards.std() prompt_stds = torch.full_like(unique_rewards, rewards_std) elif self.scale_rewards == 'group': prompt_stds = torch.zeros(len(unique_rewards), device=device) for pid, idxs in prompt_to_indices.items(): idx_tensor = torch.tensor(idxs, device=device) r_group = unique_rewards[idx_tensor] - # Edge case: when group size is 1 - prompt_stds[idx_tensor] = r_group.std() if len(idxs) > 1 else 0.0 + prompt_stds[idx_tensor] = r_group.std() else: # 'none' prompt_stds = None if prompt_stds is not None: @@ -671,7 +605,7 @@ def log_rewards_all(rewards_per_func: torch.Tensor): return advantages - @profiling_decorator + @patch_profiling_decorator def _dynamic_sampling(self, inputs, rewards_per_func): """ Perform dynamic sampling to replace samples with zero-reward-variance groups. 
@@ -705,7 +639,7 @@ def _dynamic_sampling(self, inputs, rewards_per_func): inputs = next(self.dynamic_resample_iterator) if self.template.truncation_strategy == 'raise': inputs = self.resample_encode_failed_inputs(inputs) - inputs = HfTrainer._prepare_inputs(self, inputs) + inputs = Trainer._prepare_inputs(self, inputs) inputs = self._generate_completions(inputs) rewards_per_func = self._score_completions(inputs) resample_count += 1 @@ -728,15 +662,9 @@ def compute_std(self, inputs: DataType, rewards_per_func: torch.Tensor) -> torch device = self.accelerator.device rewards = (rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) - mode = 'train' if self.model.training else 'eval' - num_generations = self.num_generations if mode == 'train' else self.num_generations_eval if not self.dynamic_num_samples: - grouped_rewards = rewards.view(-1, num_generations) - # Handle edge case when num_generations_eval=1 - if num_generations > 1: - group_rewards_std = grouped_rewards.std(dim=1).repeat_interleave(num_generations) - else: - group_rewards_std = torch.zeros_like(rewards) + grouped_rewards = rewards.view(-1, self.num_generations) + group_rewards_std = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations) return group_rewards_std else: prompt_ids = gather_object([inp['prompt_id'] for inp in inputs]) @@ -755,8 +683,7 @@ def compute_std(self, inputs: DataType, rewards_per_func: torch.Tensor) -> torch for pid, idxs in prompt_to_indices.items(): idx_tensor = torch.tensor(idxs, device=device) r_group = unique_rewards[idx_tensor] - # Edge case: when group size is 1 - prompt_stds[idx_tensor] = r_group.std() if len(idxs) > 1 else 0.0 + prompt_stds[idx_tensor] = r_group.std() rid_to_idx = {rid: idx for idx, rid in enumerate(unique_request_ids)} indices_in_unique = torch.tensor([rid_to_idx[r] for r in request_ids], device=device) rewards_std = prompt_stds[indices_in_unique] @@ -799,6 +726,7 @@ def split_by_mini_batches(self, inputs: DataType) -> List[DataType]: return spg_chunks else: + from swift.trainers.sequence_parallel import sequence_parallel """Split by mini batches for GRPO sequence parallel training""" output = [None] * sequence_parallel.sp_world_size # gather inputs within a sp group @@ -846,7 +774,7 @@ def null_ref_context(self): if self.ref_adapter_name: self.model.set_adapter(self.model_adapter_name or 'default') - @profiling_decorator + @patch_profiling_decorator def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: """ Prepare the final batch inputs with ref/old_policy logps and other fields for RL training. 
@@ -878,14 +806,9 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: # Process labels and masks labels = batch_encoded_inputs.pop('labels') logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() - batch_size = len(batch) - - # Create completion_mask - # In padding_free mode: labels shape is [1, total_seq_len] (rmpad format) - # In non-padding_free mode: labels shape is [batch_size, seq_len] (batch format) - completion_mask_raw = labels[:, -logits_to_keep:] != -100 - extra_kwargs = { + 'completion_mask': + labels[:, -logits_to_keep:] != -100, 'truncated_mask': torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool, device=self.accelerator.device), 'logits_to_keep': @@ -904,109 +827,34 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: # The first sentence has its prompt portion removed due to logits_to_keep lengths[0] = lengths[0] - (total_lengths - logits_to_keep) extra_kwargs.update({'seq_lengths': lengths}) - - # In padding_free mode, completion_mask_raw is [1, logits_to_keep] (rmpad format) - # Pad it back to [batch_size, logits_to_keep] for consistency with per_token_logps - completion_mask, _ = pad_logps_back_to_batch( - logps_rmpad=completion_mask_raw.float(), - logits_to_keep=logits_to_keep, - batch_size=batch_size, - seq_lengths=lengths, - pad_value=0.0) - completion_mask = completion_mask.bool() - else: - # In non-padding_free mode, completion_mask is already [batch_size, logits_to_keep] - completion_mask = completion_mask_raw - - extra_kwargs['completion_mask'] = completion_mask batch_encoded_inputs.update(extra_kwargs) - with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + with torch.no_grad(): batch_encoded_inputs['old_per_token_logps'] = ( - self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0]) + self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0] + if self.old_policy() or self.kl_in_reward else None) if self.beta == 0.0: ref_per_token_logps = None elif self.ref_model is not None: - with disable_gradient_checkpointing(self.ref_model, self.args.gradient_checkpointing_kwargs): - ref_per_token_logps = \ - self._get_per_token_logps_and_entropies(self.ref_model, batch_encoded_inputs)[0] + ref_per_token_logps = \ + self._get_per_token_logps_and_entropies(self.ref_model, batch_encoded_inputs)[0] else: with self.null_ref_context(): ref_per_token_logps = \ self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0] batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps - # Extract rollout logprobs if available for importance sampling - # rollout_logprobs is List[List[float]] - nested list where each inner list corresponds to - # one assistant response turn. We need to align these with completion_mask positions. 
- batch_encoded_inputs['rollout_per_token_logps'] = None - should_compute_rollout_logprobs = ( - self.rollout_importance_sampling_mode is not None or self.log_rollout_offpolicy_metrics) - - if self.use_fast_infer and should_compute_rollout_logprobs: - rollout_logprobs_list = [] - for data in batch: - if 'rollout_logprobs' in data and data['rollout_logprobs']: - rollout_logprobs_list.append(data['rollout_logprobs']) - else: - rollout_logprobs_list.append(None) - - # Convert to tensor if all samples have rollout_logprobs - completion_mask = batch_encoded_inputs['completion_mask'] - if all(lp is not None for lp in rollout_logprobs_list): - # Validate that logprobs count matches completion tokens count - valid_logprobs = True - for i, nested_lp in enumerate(rollout_logprobs_list): - total_logprobs = sum(len(turn_lps) for turn_lps in nested_lp) - completion_count = int(completion_mask[i].sum().item()) - - if total_logprobs != completion_count: - logger.warning(f'Rollout logprobs count ({total_logprobs}) does not match ' - f'completion tokens count ({completion_count}). ' - f'Skipping rollout importance sampling for this batch.') - valid_logprobs = False - break - - if valid_logprobs: - # Align rollout_logprobs with completion_mask for each sample - batch_size = completion_mask.shape[0] - seq_len = completion_mask.shape[1] - - # Initialize with zeros (for prompt positions) - rollout_logps_tensor = torch.zeros( - batch_size, seq_len, dtype=torch.float32, device=self.accelerator.device) - - for i, nested_lp in enumerate(rollout_logprobs_list): - # Flatten logprobs for this sample - flat_lps = [lp for turn_lps in nested_lp for lp in turn_lps] - if flat_lps: - # Check for None values in flat_lps - if any(lp is None for lp in flat_lps): - logger.warning('Found None values in rollout_logprobs. 
' - 'Skipping rollout importance sampling for this batch.') - rollout_logps_tensor = None - break - # Get indices where completion_mask is True - completion_indices = completion_mask[i].nonzero(as_tuple=True)[0] - # Scatter logprobs to completion positions - rollout_logps_tensor[i, completion_indices] = torch.tensor( - flat_lps, dtype=torch.float32, device=self.accelerator.device) - - batch_encoded_inputs['rollout_per_token_logps'] = rollout_logps_tensor - ga_batch_encoded_inputs.append(batch_encoded_inputs) # --- log completion lengths --- mode = 'train' if self.model.training else 'eval' device = self.accelerator.device - local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs] + if self.template.padding_free: + local_lengths = [inp['seq_lengths'].tolist() for inp in ga_batch_encoded_inputs] + else: + local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs] total_lengths = self._gather_and_flatten(local_lengths, dtype=torch.float32, device=device, flatten_level=1) - # Store num_items_in_batch for DAPO loss (total completion tokens across all processes) - num_items_in_batch = total_lengths.sum() - for batch_encoded in ga_batch_encoded_inputs: - batch_encoded['num_items_in_batch'] = num_items_in_batch - self._metrics[mode]['completions/mean_length'].append(total_lengths.mean().item()) self._metrics[mode]['completions/min_length'].append(total_lengths.min().item()) self._metrics[mode]['completions/max_length'].append(total_lengths.max().item()) @@ -1049,7 +897,7 @@ def _apply_chat_template_to_messages_list(self, messages_list: DataType): prompts_text.append(self.template.safe_decode(res['input_ids'])) return prompts_text - @profiling_decorator + @patch_profiling_decorator def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # Compute the per-token log probabilities for the model, return_outputs=True in mini-batch training if isinstance(inputs, list): @@ -1084,8 +932,11 @@ def _compute_loss_single(self, model, inputs): def _compute_loss_and_metrics(self, model, inputs): """Core loss computation without metrics recording.""" mode = 'train' if self.model.training else 'eval' + completion_mask = inputs['completion_mask'] truncated_mask = inputs['truncated_mask'] + if self.template.padding_free: + lengths = inputs['seq_lengths'] per_token_logps, entropies = self._get_per_token_logps_and_entropies( model, inputs, compute_entropy=self.compute_entropy) @@ -1096,7 +947,11 @@ def _compute_loss_and_metrics(self, model, inputs): # fill the padded token with NaN entropies = entropies.masked_fill(completion_mask == 0, float('nan')) if self.args.log_entropy: - per_completion_entropies_mean = torch.nanmean(entropies, dim=1) + if self.template.padding_free: + entropy_list = torch.split(entropies, lengths.tolist()) + per_completion_entropies_mean = torch.stack([torch.nanmean(e) for e in entropy_list]) + else: + per_completion_entropies_mean = torch.nanmean(entropies, dim=1) global_per_completion_entropies_mean = gather(per_completion_entropies_mean) entropy_metrics = { 'entropy_logs': global_per_completion_entropies_mean.tolist(), @@ -1116,7 +971,11 @@ def _compute_loss_and_metrics(self, model, inputs): if all(truncated_mask): logger.info('All completions are overlong and truncated, ' 'resulting in NaN some values for some metrics (e.g., KL)') - truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask) + if self.template.padding_free: + truncated_mask = torch.repeat_interleave(truncated_mask, 
lengths).unsqueeze(0) + assert truncated_mask.shape == completion_mask.shape + else: + truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask) completion_mask = completion_mask & (~truncated_mask) # Compute the KL divergence between the model and the reference model @@ -1135,106 +994,78 @@ def _compute_loss_and_metrics(self, model, inputs): old_per_token_logps = ( per_token_logps.detach() if inputs['old_per_token_logps'] is None else inputs['old_per_token_logps']) - # Compute rollout diagnostic metrics and apply IS correction if enabled - rollout_correction_metrics = {} - should_compute_rollout_metrics = ( - self.rollout_importance_sampling_mode is not None or self.log_rollout_offpolicy_metrics) - - local_has_rollout_per_token_logps = inputs.get('rollout_per_token_logps') is not None - all_has_rollout_per_token_logps = gather_object([local_has_rollout_per_token_logps]) - - should_compute_rollout_metrics = should_compute_rollout_metrics and all(all_has_rollout_per_token_logps) - if (not self.disable_rollout_importance_sampling and should_compute_rollout_metrics): - rollout_per_token_logps = inputs['rollout_per_token_logps'] - - # Compute diagnostic metrics (KL, PPL, etc.) for monitoring off-policy gap - rollout_correction_metrics = self._compute_rollout_offpolicy_metrics(old_per_token_logps, - rollout_per_token_logps, - completion_mask) - - rollout_log_ratio, rollout_is_weights = self._get_rollout_is_correction(old_per_token_logps, - rollout_per_token_logps, - completion_mask) - if rollout_log_ratio is not None: - is_metrics = self._compute_is_correction_metrics(rollout_log_ratio, rollout_is_weights, completion_mask) - rollout_correction_metrics.update(is_metrics) - - inputs['rollout_is_weights'] = rollout_is_weights - else: - inputs['rollout_is_weights'] = None - log_ratio = per_token_logps - old_per_token_logps if self.importance_sampling_level == 'token': log_importance_weights = log_ratio elif self.importance_sampling_level in ['sequence', 'sequence_token']: - seq_level_log_weights = ((log_ratio * completion_mask).sum(-1) - / completion_mask.sum(-1).clamp(min=1.0)).unsqueeze(-1) - if self.importance_sampling_level == 'sequence': - log_importance_weights = seq_level_log_weights + if self.template.padding_free: + # split to batch, compute seq-level normalization + log_ratio_list = torch.split(log_ratio.squeeze(0), lengths.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) for lr, m in zip(log_ratio_list, mask_list)] + seq_level_log_weights = torch.stack(seq_weights).to(log_ratio.dtype).unsqueeze(-1) + if self.importance_sampling_level == 'sequence': + log_importance_weights = seq_level_log_weights + else: + seq_level_log_weight = seq_level_log_weights.detach() + seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, lengths).unsqueeze(0) + log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight else: - # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] - seq_level_log_weight = seq_level_log_weights.detach() - log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight + seq_level_log_weights = ((log_ratio * completion_mask).sum(-1) + / completion_mask.sum(-1).clamp(min=1.0)).unsqueeze(-1) + if self.importance_sampling_level == 'sequence': + log_importance_weights = seq_level_log_weights + else: + # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] + seq_level_log_weight = seq_level_log_weights.detach() + 
log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
+
         else:
             raise ValueError(
                 f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
                 "and 'sequence'.")
         coef_1 = torch.exp(log_importance_weights)
+        # Remove the clipping part; keep only the core GRPO loss
+        # coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
         if self.args.delta is not None:
             coef_1 = torch.clamp(coef_1, max=self.args.delta)
-        if self.loss_type == 'cispo':
-            clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
-            per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
-        elif self.loss_type == 'sapo':
-            advantages_expanded = advantages.unsqueeze(1)
-            gate_pos = torch.sigmoid(self.tau_pos * (coef_1 - 1)) * (4.0 / self.tau_pos)
-            gate_neg = torch.sigmoid(self.tau_neg * (coef_1 - 1)) * (4.0 / self.tau_neg)
-            is_positive = advantages_expanded > 0
-            soft_gate = torch.where(is_positive, gate_pos, gate_neg)
-
-            per_token_loss = -soft_gate * advantages_expanded
-        elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']:
-            coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
-            if self.args.delta is not None:
-                coef_1 = torch.clamp(coef_1, max=self.args.delta)
-
-            per_token_loss1 = coef_1 * advantages.unsqueeze(1)
-            per_token_loss2 = coef_2 * advantages.unsqueeze(1)
-            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        if self.template.padding_free:
+            if self.importance_sampling_level == 'sequence':
+                # Expand sequence-level weights to token-level
+                coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), lengths).unsqueeze(0)
+                # coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), lengths).unsqueeze(0)
+
+            advantages = advantages[-coef_1.shape[1]:]
+            per_token_loss = -coef_1 * advantages.unsqueeze(0)
+        else:
+            per_token_loss = -coef_1 * advantages.unsqueeze(1)
         if entropy_mask is not None:
             per_token_loss = per_token_loss * entropy_mask
-        if per_token_kl is not None:
-            per_token_loss = per_token_loss + self.beta * per_token_kl
-
-        # Apply vLLM importance sampling weights if available
-        if inputs.get('rollout_is_weights') is not None and self.rollout_importance_sampling_mode is not None:
-            rollout_is_weights = inputs['rollout_is_weights']
-            per_token_loss = per_token_loss * rollout_is_weights
-
-        # Apply off-policy sequence masking if enabled
-        # Mask out sequences where delta > threshold AND advantage < 0
-        if self.off_policy_sequence_mask_delta is not None:
-            rollout_per_token_logps = inputs.get('rollout_per_token_logps')
-            old_policy_per_token_logps = rollout_per_token_logps if rollout_per_token_logps is not None \
-                else old_per_token_logps
-            off_policy_seq_mask = self._compute_off_policy_sequence_mask(per_token_logps, old_policy_per_token_logps,
-                                                                         completion_mask, advantages)
-            # Expand sequence mask to token level and apply to completion_mask
-            off_policy_seq_mask_expanded = off_policy_seq_mask.unsqueeze(-1).expand_as(completion_mask)
-            completion_mask = completion_mask & off_policy_seq_mask_expanded
-
-        if self.loss_type in ['grpo', 'sapo']:
-            # completion_mask is now always [batch_size, seq_len] after pad_back
-            loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+        # Remove the KL loss; keep only the core GRPO loss
+        # if per_token_kl is not None:
+        #     per_token_loss = per_token_loss + self.beta * per_token_kl
+
+        print(f"DEBUG: loss_type={self.loss_type}, padding_free={self.template.padding_free}")
+        print(f"DEBUG: per_token_loss.shape={per_token_loss.shape}, 
completion_mask.shape={completion_mask.shape}") + if self.template.padding_free: + print(f"DEBUG: lengths.shape={lengths.shape}") + + if self.loss_type == 'grpo': + if self.template.padding_free: + loss_list = torch.split(per_token_loss.squeeze(0), lengths.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0) + for loss, mask in zip(loss_list, mask_list)] + loss = torch.stack(sample_loss).mean() + else: + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() elif self.loss_type == 'bnpo': loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif self.loss_type == 'dr_grpo': - batch_size = completion_mask.shape[0] + batch_size = lengths.shape[0] if self.template.padding_free else inputs['input_ids'].shape[0] loss = (per_token_loss * completion_mask).sum() / (batch_size * self.max_completion_length) - elif self.loss_type in ['cispo', 'dapo']: - # CISPO and DAPO: Normalize by total completion tokens across all processes - normalizer = inputs['num_items_in_batch'] / self.accelerator.num_processes - loss = (per_token_loss * completion_mask).sum() / normalizer else: raise ValueError(f'Unknown loss type: {self.loss_type}') @@ -1259,39 +1090,26 @@ def masked_batch_mean(x): mean_kl = masked_batch_mean(per_token_kl) metrics_data['kl'] = self.accelerator.gather_for_metrics(mean_kl).nanmean().item() - # Add rollout correction metrics - if rollout_correction_metrics: - metrics_data['rollout_correction'] = rollout_correction_metrics - # Compute the clipped probability ratios - if self.loss_type == 'cispo': - # CISPO: Only track upper bound clipping - is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0) - cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) - gathered_cispo_clip_ratio = self.accelerator.gather_for_metrics(cispo_clip_ratio) - metrics_data['clipping'] = {'cispo_clip_ratio': gathered_cispo_clip_ratio.nanmean().item()} - elif self.loss_type == 'sapo': - pass - else: - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) - is_region_clipped = is_low_clipped | is_high_clipped - - low_clip = masked_batch_mean(is_low_clipped.float()) - high_clip = masked_batch_mean(is_high_clipped.float()) - clip_ratio = masked_batch_mean(is_region_clipped.float()) - - gathered_low_clip = self.accelerator.gather_for_metrics(low_clip) - gathered_high_clip = self.accelerator.gather_for_metrics(high_clip) - gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio) - - metrics_data['clipping'] = { - 'low_clip_mean': gathered_low_clip.nanmean().item(), - 'low_clip_min': nanmin(gathered_low_clip).item(), - 'high_clip_mean': gathered_high_clip.nanmean().item(), - 'high_clip_max': nanmax(gathered_high_clip).item(), - 'region_clip_mean': gathered_clip_ratio.nanmean().item() - } + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + gathered_low_clip = self.accelerator.gather_for_metrics(low_clip) + gathered_high_clip = 
self.accelerator.gather_for_metrics(high_clip) + gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio) + + metrics_data['clipping'] = { + 'low_clip_mean': gathered_low_clip.nanmean().item(), + 'low_clip_min': nanmin(gathered_low_clip).item(), + 'high_clip_mean': gathered_high_clip.nanmean().item(), + 'high_clip_max': nanmax(gathered_high_clip).item(), + 'region_clip_mean': gathered_clip_ratio.nanmean().item() + } if mode == 'train' and self.chord_sft_iterator is not None: loss = compute_chord_loss(self, grpo_loss=loss) @@ -1316,24 +1134,14 @@ def _update_metrics(self, metrics_data): if 'kl' in metrics_data: self._metrics[mode]['kl'].append(metrics_data['kl']) - # Update vLLM correction metrics - if 'rollout_correction' in metrics_data: - rollout_metrics = metrics_data['rollout_correction'] - for key, value in rollout_metrics.items(): - self._metrics[mode][f'rollout_correction/{key}'].append(value) - # Update clipping metrics if 'clipping' in metrics_data: clipping = metrics_data['clipping'] - if 'cispo_clip_ratio' in clipping: - # CISPO - self._metrics[mode]['cispo_clip_ratio'].append(clipping['cispo_clip_ratio']) - else: - self._metrics[mode]['clip_ratio/low_mean'].append(clipping['low_clip_mean']) - self._metrics[mode]['clip_ratio/low_min'].append(clipping['low_clip_min']) - self._metrics[mode]['clip_ratio/high_mean'].append(clipping['high_clip_mean']) - self._metrics[mode]['clip_ratio/high_max'].append(clipping['high_clip_max']) - self._metrics[mode]['clip_ratio/region_mean'].append(clipping['region_clip_mean']) + self._metrics[mode]['clip_ratio/low_mean'].append(clipping['low_clip_mean']) + self._metrics[mode]['clip_ratio/low_min'].append(clipping['low_clip_min']) + self._metrics[mode]['clip_ratio/high_mean'].append(clipping['high_clip_mean']) + self._metrics[mode]['clip_ratio/high_max'].append(clipping['high_clip_max']) + self._metrics[mode]['clip_ratio/region_mean'].append(clipping['region_clip_mean']) def _compute_loss_chunked(self, model, inputs: DataType): """ @@ -1362,31 +1170,24 @@ def _compute_loss_chunked(self, model, inputs: DataType): start_idx = chunk_idx * new_chunk_size end_idx = min(start_idx + new_chunk_size, batch_size) - is_dummy = False if start_idx < batch_size: chunk_inputs = self.get_chunked_inputs(inputs, start_idx, end_idx) - chunk_weight = end_idx - start_idx - else: - is_dummy = True - chunk_weight = 0 - # Compute loss and metrics for this chunk + # Compute loss and metrics for this chunk (without updating global metrics) chunk_loss, chunk_metrics_data = self._compute_loss_and_metrics(model, chunk_inputs) + chunk_weight = end_idx - start_idx - if not is_dummy: + if start_idx < batch_size: losses.append(chunk_loss * chunk_weight) weights.append(chunk_weight) all_metrics_data.append((chunk_metrics_data, chunk_weight)) - else: - # # Add dummy loss to computation graph to trigger ZeRO-3 backward hooks - losses.append(chunk_loss * 0.0) # Compute weighted average loss total_weight = sum(weights) if total_weight > 0: final_loss = torch.stack(losses).sum() / total_weight else: - final_loss = torch.stack(losses).sum() + final_loss = torch.tensor(0.0, device=model.device) # Aggregate metrics across all chunks self._aggregate_and_update_metrics(all_metrics_data, mode) @@ -1401,7 +1202,6 @@ def _aggregate_and_update_metrics(self, all_metrics_data, mode): # Separate metrics by type for aggregation entropy_logs, entropy_stats, kl_values = [], [], [] clip_values = {'low': [], 'high': [], 'region': [], 'low_min': [], 'high_max': []} - cispo_clip_values = [] 
entropy_thresholds = [] for chunk_metrics, chunk_weight in all_metrics_data: @@ -1428,14 +1228,11 @@ def _aggregate_and_update_metrics(self, all_metrics_data, mode): if 'clipping' in chunk_metrics: clipping = chunk_metrics['clipping'] weight = chunk_tokens.item() if hasattr(chunk_tokens, 'item') else chunk_tokens - if 'cispo_clip_ratio' in clipping: - cispo_clip_values.append((clipping['cispo_clip_ratio'], weight)) - else: - clip_values['low'].append((clipping['low_clip_mean'], weight)) - clip_values['high'].append((clipping['high_clip_mean'], weight)) - clip_values['region'].append((clipping['region_clip_mean'], weight)) - clip_values['low_min'].append(clipping['low_clip_min']) - clip_values['high_max'].append(clipping['high_clip_max']) + clip_values['low'].append((clipping['low_clip_mean'], weight)) + clip_values['high'].append((clipping['high_clip_mean'], weight)) + clip_values['region'].append((clipping['region_clip_mean'], weight)) + clip_values['low_min'].append(clipping['low_clip_min']) + clip_values['high_max'].append(clipping['high_clip_max']) # Build aggregated metrics aggregated_metrics = {'mode': mode, 'entropy': {}} @@ -1443,8 +1240,8 @@ def _aggregate_and_update_metrics(self, all_metrics_data, mode): # Aggregate entropy if entropy_logs: # Directly update entropy logs + self._logs['entropy'].extend(entropy_logs) aggregated_metrics['entropy'] = { - 'entropy_logs': entropy_logs, 'entropy_mean': sum(s['mean'] for s in entropy_stats) / len(entropy_stats), 'entropy_max': max(s['max'] for s in entropy_stats), 'entropy_min': min(s['min'] for s in entropy_stats) @@ -1457,14 +1254,11 @@ def _aggregate_and_update_metrics(self, all_metrics_data, mode): aggregated_metrics['kl'] = sum(kl_values) / len(kl_values) # Aggregate clipping (token-weighted averages) - def weighted_avg(values): - return sum(v * w for v, w in values) / sum(w for _, w in values) - - if cispo_clip_values: - # CISPO specific metric - aggregated_metrics['clipping'] = {'cispo_clip_ratio': weighted_avg(cispo_clip_values)} - elif clip_values['low']: - # Two-sided clipping metrics + if clip_values['low']: + + def weighted_avg(values): + return sum(v * w for v, w in values) / sum(w for _, w in values) + aggregated_metrics['clipping'] = { 'low_clip_mean': weighted_avg(clip_values['low']), 'low_clip_min': min(clip_values['low_min']), @@ -1476,48 +1270,33 @@ def weighted_avg(values): # Update metrics self._update_metrics(aggregated_metrics) - def _unpad_logps_and_entropies(self, - logps: torch.Tensor, - entropies: Optional[torch.Tensor], - logits_to_keep: int, - batch_size: int, - seq_lengths: torch.Tensor, - compute_entropy: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Restore logps and entropies from rmpad format [1, total_nnz] to batch format [batch_size, max_seq_len]. 
- - Args: - logps: Per-token log probabilities in rmpad format [1, total_nnz] - entropies: Per-token entropies in rmpad format [1, total_nnz] or None - logits_to_keep: Number of tokens to keep per sequence - batch_size: Number of sequences in the batch - seq_lengths: Actual sequence lengths [batch_size] - compute_entropy: Whether entropy was computed - - Returns: - logps: Restored log probabilities [batch_size, logits_to_keep] - entropies: Restored entropies [batch_size, logits_to_keep] or None - """ - logps, _ = pad_logps_back_to_batch( - logps_rmpad=logps, logits_to_keep=logits_to_keep, batch_size=batch_size, seq_lengths=seq_lengths) - - if compute_entropy and entropies is not None: - entropies, _ = pad_logps_back_to_batch( - logps_rmpad=entropies, logits_to_keep=logits_to_keep, batch_size=batch_size, seq_lengths=seq_lengths) + def _get_per_token_logps_and_entropies_sp( + self, + model: torch.nn.Module, + inputs: 'DataType', + compute_entropy: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Get per token logps for GRPO sequence parallel training""" + try: + from trl.trainer.utils import selective_log_softmax + except ImportError: + raise ImportError('trl is required for GRPO training. Please install it with: pip install trl') - return logps, entropies + from swift.trainers.sequence_parallel.utils import GatherLoss + from swift.trainers.sequence_parallel import sequence_parallel - def _get_logps_via_sp(self, - model: torch.nn.Module, - inputs: 'DataType', - logits_to_keep: int, - input_ids: torch.Tensor, - compute_entropy: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Get per token logps via sequence parallel, returns rmpad format [1, total_nnz] for padding_free mode""" - model_inputs = self._prepare_model_inputs(inputs) - sequence_parallel.prepare_inputs(model_inputs) - with self._template_context(self.template, inputs): - output = model(**model_inputs) + # original logits to keep + logits_to_keep = inputs['logits_to_keep'] + input_ids = inputs['input_ids'] + inputs = { + k: v + for k, v in inputs.items() if k not in [ + 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', + 'truncated_mask' + ] + } + sequence_parallel.prepare_inputs(inputs) + with self._template_context(self.template): + output = model(**inputs) logits = output.logits # split input_ids to labels position_ids = sequence_parallel.real_position_ids @@ -1533,118 +1312,13 @@ def _get_logps_via_sp(self, entropies = entropy_from_logits(logits) entropies, _ = GatherLoss.apply(entropies, labels, 1, position_ids) - if self.template.padding_free: - # In padding_free mode, we need to extract completion tokens from gathered data. - # The behavior differs based on rp_world_size: - # - rp_world_size > 1: Each sequence is padded to world_size * 2 multiple (per-sequence padding) - # - rp_world_size == 1: Entire data is padded to world_size multiple (end padding only) - seq_lengths = inputs['seq_lengths'] - batch_size = seq_lengths.shape[0] - rp_world_size = sequence_parallel.rp_world_size - - if rp_world_size > 1: - # With ring parallel: GatherLoss pads each sequence to world_size * 2 multiple - # Data layout after gather: [seq1_data, seq1_padding, seq2_data, seq2_padding, ...] 
- # - Original data is at [offset:offset+orig_len] - # - Padding is at [offset+orig_len:offset+padded_len] - - # Get original sequence boundaries (before padding) - cu_seqlens_orig = get_cu_seqlens_from_position_ids(position_ids) - - # Get padded sequence boundaries (for offset calculation) - padded_position_ids = sequence_parallel.pad(position_ids, padding_value=-1, position_ids=position_ids) - cu_seqlens_padded = get_cu_seqlens_from_position_ids(padded_position_ids) - - result_logps = [] - result_entropies = [] if compute_entropy else None - gathered_logps = per_token_logps.squeeze(0) - gathered_entropies = entropies.squeeze(0) if compute_entropy else None - - offset = 0 - for i in range(batch_size): - # Original sequence length (before SP padding) - orig_len = (cu_seqlens_orig[i + 1] - cu_seqlens_orig[i]).item() - # Padded sequence length (multiple of world_size * 2) - padded_len = (cu_seqlens_padded[i + 1] - cu_seqlens_padded[i]).item() - # Actual completion tokens for this sequence - actual_len = seq_lengths[i].item() - - # Extract the last `actual_len` tokens from this sequence's ORIGINAL data region - # Due to label shifting (roll -1), per_token_logps[i] predicts token i+1 - # So completion tokens [prompt_len, total_len) have logps at [prompt_len-1, total_len-1) - seq_start = offset + orig_len - actual_len - 1 - seq_end = offset + orig_len - 1 - result_logps.append(gathered_logps[seq_start:seq_end]) - if compute_entropy: - result_entropies.append(gathered_entropies[seq_start:seq_end]) - - # Use padded_len for offset because gathered data includes padding - offset += padded_len - - per_token_logps = torch.cat(result_logps).unsqueeze(0) - if compute_entropy: - entropies = torch.cat(result_entropies).unsqueeze(0) - else: - # Without ring parallel (rp_world_size == 1): Simple gather with end padding only - # Use input_ids length directly as the authoritative original length - original_total_len = input_ids.shape[-1] - # Due to label shifting (roll -1), per_token_logps[i] predicts token i+1. 
- start_idx = original_total_len - logits_to_keep - 1 - end_idx = original_total_len - 1 - per_token_logps = per_token_logps[:, start_idx:end_idx] - if compute_entropy: - entropies = entropies[:, start_idx:end_idx] - else: - per_token_logps = per_token_logps[:, -logits_to_keep - 1:-1] - if compute_entropy: - entropies = entropies[:, -logits_to_keep - 1:-1] - + per_token_logps = per_token_logps[:, -logits_to_keep - 1:-1] + if compute_entropy: + entropies = entropies[:, -logits_to_keep - 1:-1] + # ignore the last token return per_token_logps, entropies - def _get_logps_via_local_forward(self, - model: torch.nn.Module, - inputs: 'DataType', - logits_to_keep: int, - input_ids: torch.Tensor, - compute_entropy: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Get per token logps via local forward pass, returns rmpad format [1, total_nnz] for padding_free mode""" - model_inputs = self._prepare_model_inputs(inputs) - if 'logits_to_keep' in self.model_kwarg_keys: - model_inputs['logits_to_keep'] = logits_to_keep + 1 - - # Forward pass - logits = model(**model_inputs).logits - - # Extract relevant portion and apply temperature - logits = logits[:, -(logits_to_keep + 1):-1, :] / self.temperature - input_ids_for_logps = input_ids[:, -logits_to_keep:] - - is_padding_free = self.template.padding_free - if is_padding_free: - # In padding_free mode, compute logps on flattened tensors - logits_rmpad = logits.squeeze(0) # [total_nnz, vocab_size] - input_ids_rmpad = input_ids_for_logps.squeeze(0) # [total_nnz] - - # Compute logps on rmpad tensors - logps = selective_log_softmax(logits_rmpad, input_ids_rmpad) # [total_nnz] - logps = logps.unsqueeze(0) # [1, total_nnz] - - # Compute entropy if needed - if compute_entropy: - entropies = entropy_from_logits(logits_rmpad) # [total_nnz] - entropies = entropies.unsqueeze(0) # [1, total_nnz] - else: - entropies = None - else: - logps = selective_log_softmax(logits, input_ids_for_logps) - if compute_entropy: - entropies = entropy_from_logits(logits) - else: - entropies = None - - return logps, entropies - - @profiling_decorator + @patch_profiling_decorator def _get_per_token_logps_and_entropies(self, model, inputs, @@ -1668,16 +1342,10 @@ def _get_per_token_logps_and_entropies_single(self, model, inputs, compute_entropy=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if self.template.sequence_parallel_size > 1: + return self._get_per_token_logps_and_entropies_sp(model, inputs, compute_entropy=compute_entropy) logits_to_keep = inputs['logits_to_keep'] input_ids = inputs['input_ids'] - is_padding_free = self.template.padding_free - use_sp = self.template.sequence_parallel_size > 1 - - # Store metadata for padding_free restoration - if is_padding_free: - original_seq_lengths = inputs.get('seq_lengths') - batch_size = original_seq_lengths.shape[0] - unwrapped_model = self.accelerator.unwrap_model(model) if is_peft_model(unwrapped_model): parameters = inspect.signature(unwrapped_model.base_model.model.forward).parameters @@ -1685,34 +1353,38 @@ def _get_per_token_logps_and_entropies_single(self, parameters = inspect.signature(unwrapped_model.forward).parameters use_local_entropy = not hasattr(super(), '_get_per_token_logps_and_entropies') and compute_entropy - # can_use_super only when not padding_free and not using SP - can_use_super = (not self.is_multimodal and 'logits_to_keep' in parameters and not use_local_entropy - and not is_padding_free and not use_sp) + can_use_super = (not self.is_multimodal and 'logits_to_keep' in parameters and not 
use_local_entropy) + if 'attention_mask' not in inputs: + # when set padding_free true, the attention_mask is not in inputs + can_use_super = False if can_use_super: - # Path 1: Use super() method (non-padding_free, non-SP) + # save memory if hasattr(super(), '_get_per_token_logps_and_entropies'): logps, entropies = super()._get_per_token_logps_and_entropies( model, input_ids, inputs['attention_mask'], logits_to_keep, compute_entropy=compute_entropy) else: logps = super()._get_per_token_logps(model, input_ids, inputs['attention_mask'], logits_to_keep) entropies = None - elif use_sp: - # Path 2: Use sequence parallel - # In padding_free mode: returns [1, logits_to_keep] format (rmpad, needs unpad) - # In non-padding_free mode: returns [batch_size, logits_to_keep] format - logps, entropies = self._get_logps_via_sp( - model, inputs, logits_to_keep, input_ids, compute_entropy=compute_entropy) else: - # Path 3: Local forward pass (padding_free or multimodal or no logits_to_keep support) - # Returns [1, total_nnz] in padding_free mode, or [batch_size, logits_to_keep] otherwise - logps, entropies = self._get_logps_via_local_forward( - model, inputs, logits_to_keep, input_ids, compute_entropy=compute_entropy) - - # Unpad for padding_free mode (both SP and non-SP paths need this) - if is_padding_free: - logps, entropies = self._unpad_logps_and_entropies(logps, entropies, logits_to_keep, batch_size, - original_seq_lengths, compute_entropy) + inputs = { + k: v + for k, v in inputs.items() if k not in [ + 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', + 'truncated_mask', 'seq_lengths' + ] + } + if 'logits_to_keep' in self.model_kwarg_keys: + inputs['logits_to_keep'] = logits_to_keep + 1 + logits = model(**inputs).logits + # exclude the last logit: it corresponds to the next token pred + logits = logits[:, -(logits_to_keep + 1):-1, :] + logits = logits / self.temperature + input_ids = input_ids[:, -logits_to_keep:] + logps = selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + entropies = None + if compute_entropy: + entropies = entropy_from_logits(logits) return logps, entropies @@ -1780,7 +1452,7 @@ def _get_per_token_logps_and_entropies_chunked(self, return final_logps, final_entropies - @profiling_decorator + @patch_profiling_decorator def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep): # unwrap the model to access the model.model if is_peft_model(unwrapped_model): @@ -1789,32 +1461,25 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep): last_hidden_state = unwrapped_model.model( input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']).last_hidden_state else: - model_inputs = self._prepare_model_inputs(inputs) + inputs = { + k: v + for k, v in inputs.items() if k not in [ + 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', + 'truncated_mask' + ] + } if 'logits_to_keep' in self.model_kwarg_keys: - model_inputs['logits_to_keep'] = logits_to_keep + 1 + inputs['logits_to_keep'] = logits_to_keep + 1 - last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state + last_hidden_state = unwrapped_model.model(**inputs).last_hidden_state last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) if logits_to_keep is not None: last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) return last_hidden_state - def _get_rollout_is_correction(self, old_per_token_logps, 
rollout_per_token_logps, completion_mask): - """Compute rollout importance sampling log-ratio and IS weights. - - Returns: - (rollout_log_ratio, rollout_is_weights) if rollout IS correction is applicable, - (None, None) otherwise. - """ - if self.rollout_importance_sampling_mode is None or self.disable_rollout_importance_sampling: - return None, None - - rollout_log_ratio = old_per_token_logps - rollout_per_token_logps - rollout_is_weights = self._apply_rollout_importance_sampling(rollout_log_ratio, completion_mask) - return rollout_log_ratio, rollout_is_weights - def compute_liger_loss(self, unwrapped_model, inputs): + # Compute the per-token log probabilities for the model assert not self.template.padding_free assert self.advantage_estimator == 'grpo' input_ids = inputs['input_ids'] @@ -1822,16 +1487,9 @@ def compute_liger_loss(self, unwrapped_model, inputs): completion_ids = input_ids[:, -logits_to_keep:] completion_mask = inputs['completion_mask'] + # get the last hidden state of the model last_hidden_state = self._get_last_hidden_state(unwrapped_model, inputs, logits_to_keep) - - old_per_token_logps = inputs.get('old_per_token_logps') - local_has = inputs.get('rollout_per_token_logps') is not None - vllm_is_ratio = None - if all(gather_object([local_has])): - rollout_per_token_logps = inputs['rollout_per_token_logps'] - _, vllm_is_ratio = self._get_rollout_is_correction(old_per_token_logps, rollout_per_token_logps, - completion_mask) - + # compute loss and metrics using liger grpo loss loss, metrics = self.liger_grpo_loss( _input=last_hidden_state, lin_weight=unwrapped_model.lm_head.weight, @@ -1839,11 +1497,11 @@ def compute_liger_loss(self, unwrapped_model, inputs): attention_mask=completion_mask, advantages=inputs['advantages'], bias=unwrapped_model.lm_head.bias, - old_per_token_logps=old_per_token_logps, + old_per_token_logps=inputs.get('old_per_token_logps'), ref_per_token_logps=inputs.get('ref_per_token_logps'), - vllm_is_ratio=vllm_is_ratio, ) - + # Extract metrics from the liger_grpo_loss output + # KL divergence is the first metric when beta is non-zero mean_kl = metrics[0] if self.beta != 0.0 else None clip_ratio = metrics[-1] @@ -1876,6 +1534,7 @@ def old_policy(self): return (self.num_iterations > 1 or self.args.gradient_accumulation_steps % self.args.steps_per_generation != 0) else: + from swift.trainers.sequence_parallel import sequence_parallel return (self.num_iterations > 1 or self.args.gradient_accumulation_steps % (self.args.steps_per_generation * sequence_parallel.world_size) != 0) @@ -1899,6 +1558,66 @@ def offload_context(self): if getattr(self, 'optimizer', None) and self.args.offload_optimizer: self.load_optimizer() + @patch_profiling_decorator + def resample_encode_failed_inputs(self, inputs: DataType, n_try_fetch: int = 10) -> DataType: + """ + Attempt to encode each input using the template. If encoding fails, + resample from a backup iterator until successful or until the maximum + number of retries is reached. + + Args: + inputs (DataType): A list of input data samples, each containing a `messages` field. + n_try_fetch (int, optional): Maximum number of retries to fetch a new sample + when encoding fails. Defaults to 10. + + Returns: + DataType: A list of successfully encoded input samples. + + Raises: + RuntimeError: If encoding fails after `n_try_fetch` resampling attempts. + """ + template = self.template + last_messages = None + last_valid_data = None + + for i, data in enumerate(inputs): + # Skip samples with the same `messages` as the previous one. 
+ # If the last sample was successfully encoded, reuse it. + if last_messages is not None and data['messages'] == last_messages: + if last_valid_data is not None: + inputs[i] = last_valid_data + continue + + current_data = data + n_try = 0 + + while True: + try: + # Attempt to encode the current sample. + remove_response(current_data['messages']) + template.encode(current_data) + # If successful, store the result and update the last valid data. + inputs[i] = current_data + last_messages = current_data['messages'] + last_valid_data = current_data + break + + except Exception as e: + # Encoding failed — attempt to resample a new input. + logger.warning(f'Encoding failed for one sample; resampling a new input. {e}') + n_try += 1 + + # Stop if the maximum retry limit is exceeded. + if n_try > n_try_fetch: + raise RuntimeError('Failed to obtain a valid sample after multiple attempts. ' + 'Consider increasing `max_length` or adjusting the ' + '`truncation_strategy` to avoid excessive truncation.') + + # Fetch a new sample from the resampling iterator. + current_data = next(self.truncated_resample_iterator)[0] + + return inputs + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: mode = 'train' if self.model.training else 'eval' metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics @@ -1945,7 +1664,6 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non if report_to_wandb: import pandas as pd - # Create a copy to avoid modifying the original table used by other loggers. wandb_table = table.copy() if self._logs.get('image'): @@ -2086,7 +1804,7 @@ def get_chunked_inputs(self, inputs, start_idx, end_idx): # for LLM, slice the inputs for key, val in inputs.items(): if isinstance(val, torch.Tensor): - chunk_inputs[key] = val if val.ndim == 0 else val[start_idx:end_idx] + chunk_inputs[key] = val[start_idx:end_idx] else: chunk_inputs[key] = val if self.is_multimodal: @@ -2103,18 +1821,18 @@ def _prepare_liger_loss(self): self.use_liger_loss = self.args.use_liger_kernel if self.use_liger_loss: from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss + kwargs = {} + if 'importance_sampling_level' in inspect.signature(LigerFusedLinearGRPOLoss.__init__).parameters: + kwargs['importance_sampling_level'] = self.importance_sampling_level self.liger_grpo_loss = LigerFusedLinearGRPOLoss( beta=self.beta, - compiled=False, epsilon_low=self.epsilon_low, epsilon_high=self.epsilon_high, temperature=self.temperature, use_ref_model=self.beta != 0.0, loss_type=self.loss_type, max_completion_length=self.max_completion_length, - importance_sampling_level=self.importance_sampling_level, - sapo_temperature_pos=self.tau_pos, - sapo_temperature_neg=self.tau_neg, + **kwargs, ) self._forward_redirection = _ForwardRedirection() @@ -2141,9 +1859,6 @@ def _collect_config_info(self) -> Dict[str, str]: 'importance_sampling_level': str(self.importance_sampling_level), 'advantage_estimator': str(self.advantage_estimator), 'chord_sft_enabled': str(self.chord_sft_dataset is not None), - 'offpolicy_sequence_mask': 'enable' if self.off_policy_sequence_mask_delta is not None else 'disable', - 'rollout_importance_sampling': 'enable' if self.rollout_importance_sampling_mode is not None else 'disable', - 'loss_type': str(self.loss_type), } return config @@ -2167,35 +1882,19 @@ def _prepare_algorithm_params(self): # Entropy Mask, https://arxiv.org/abs/2506.01939 self.top_entropy_quantile = args.top_entropy_quantile - # GSPO, 
https://arxiv.org/abs/2507.18071 + # GSPO, https://www.arxiv.org/abs/2507.18071 self.importance_sampling_level = args.importance_sampling_level - # SAPO, https://arxiv.org/abs/2511.20347 - self.tau_pos = args.tau_pos - self.tau_neg = args.tau_neg - # RLOO, self.advantage_estimator = args.advantage_estimator self.kl_in_reward = args.kl_in_reward - if self.scale_rewards == 'gdpo' and self.kl_in_reward: - logger.warning('GDPO mode does not support kl_in_reward=True. Setting kl_in_reward=False.') - self.kl_in_reward = False - # Rollout Importance Sampling Correction - self.rollout_importance_sampling_mode = args.rollout_importance_sampling_mode - self.rollout_importance_sampling_threshold = args.rollout_importance_sampling_threshold - self.log_rollout_offpolicy_metrics = args.log_rollout_offpolicy_metrics - - # Off-Policy Sequence Masking - self.off_policy_sequence_mask_delta = args.off_policy_sequence_mask_delta - - def _prepare_chord_dataset(self): # CHORD, https://arxiv.org/abs/2508.11408 self.chord_sft_iterator = None if self.chord_sft_dataset: self.chord_sft_iterator = make_chord_sft_dataset(self, self.chord_sft_dataset) - def _prepare_rewards(self, reward_funcs, reward_model=None, reward_templates=None): + def _prepare_rewards(self, reward_funcs, reward_model=None, **kwargs): args = self.args device = self.accelerator.device @@ -2206,9 +1905,16 @@ def _prepare_rewards(self, reward_funcs, reward_model=None, reward_templates=Non for i, reward_func in enumerate(reward_funcs): if reward_func in orms: reward_func_class = orms[reward_func] - reward_funcs[i] = reward_func_class(args=args) + reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters) + reward_func_kwargs = { + key: getattr(args, key) + for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) + } + if 'tokenizer' in reward_func_args: + reward_func_kwargs['tokenizer'] = self.processing_class + reward_funcs[i] = reward_func_class(**reward_func_kwargs) elif not callable(reward_func): - raise ValueError(f'reward_function {reward_func} is not implemented in swift.rewards') + raise ValueError(f'reward_function {reward_func} is not implemented in swift.plugin') self.reward_funcs = reward_funcs self.reward_func_names = [] @@ -2222,6 +1928,7 @@ def _prepare_rewards(self, reward_funcs, reward_model=None, reward_templates=Non self.reward_model_plugins = [None] * len(self.reward_funcs) if reward_model is not None: + reward_template = kwargs.pop('reward_template') reward_plugins = args.reward_model_plugin if reward_plugins is None: reward_plugins = ['default'] * len(reward_model) @@ -2229,20 +1936,17 @@ def _prepare_rewards(self, reward_funcs, reward_model=None, reward_templates=Non f"The number of 'reward_model_plugin' ({len(reward_plugins)}) does not match " f"the number of 'reward_model' ({len(reward_model)}). " "Please provide a corresponding 'reward_model_plugin' for each 'reward_model'.") - for rm, rm_plugin, rm_template in zip(reward_model, reward_plugins, reward_templates): + for rm, rm_plugin, rm_template in zip(reward_model, reward_plugins, reward_template): # Set encoding mode train(see details in Template.encode). # Set max_length to None to disable truncation, as the input length has already been truncated earlier. 
rm_template.set_mode('train') rm_template.max_length = None if rm_plugin not in rm_plugins: - raise ValueError(f'rm_plugin {rm_plugin} is not implemented in swift.rewards') + raise ValueError(f'rm_plugin {rm_plugin} is not implemented in swift.llm.plugin') self.reward_model_plugins.append(rm_plugins[rm_plugin](model=rm, template=rm_template)) self.reward_funcs.append(rm) self.reward_func_names.append(rm.config._name_or_path.split('/')[-1]) - if self.use_gym_env and not self.reward_func_names: - self.reward_func_names = ['gym_reward'] - # Reward weights if args.reward_weights is not None: if len(args.reward_weights) != len(reward_funcs): @@ -2257,27 +1961,10 @@ def _prepare_rewards(self, reward_funcs, reward_model=None, reward_templates=Non if isinstance(reward_func, PreTrainedModel): if self.is_deepspeed_enabled: self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) - elif self.is_fsdp_enabled: - from .utils import prepare_fsdp - self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) else: self.reward_funcs[i] = self.accelerator.prepare_model( reward_func, evaluation_mode=True, device_placement=True) - self._async_reward_func_indices = [] - for i, func in enumerate(self.reward_funcs): - if not isinstance(func, PreTrainedModel): - if asyncio.iscoroutinefunction(func) or asyncio.iscoroutinefunction(getattr(func, '__call__', None)): - self._async_reward_func_indices.append(i) - - # Initialize event loop for async reward functions if needed - if self._async_reward_func_indices: - self.async_reward_loop_thread, self.async_reward_loop, self.async_reward_loop_ready_event = ( - start_event_loop_in_daemon(name='GRPOTrainer-AsyncRewardLoop')) - # Wait until the event loop is running in the daemon thread - self.async_reward_loop_ready_event.wait() - atexit.register(shutdown_event_loop_in_daemon, self.async_reward_loop_thread, self.async_reward_loop) - def _prepare_resample_data_iterator(self): def cyclic_iter(iterable): @@ -2318,312 +2005,3 @@ def single_sample_context(): with single_sample_context(): self.truncated_resample_iterator = cyclic_iter(self.get_train_dataloader()) - - def _compute_sequence_level_ratios(self, is_ratio: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor: - """ - Helper function to compute sequence-level importance sampling ratios. - - Args: - is_ratio: Token-level IS ratios, shape [B, T] - completion_mask: Boolean mask for completion tokens, shape [B, T] - - Returns: - Sequence-level ratios as geometric mean of token-level ratios - """ - log_ratio = torch.log(is_ratio.clamp(min=1e-10)) - seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_ratios = torch.exp(seq_log_ratios) - - return seq_ratios - - def _apply_rollout_importance_sampling(self, rollout_log_ratio: torch.Tensor, - completion_mask: torch.Tensor) -> torch.Tensor: - """ - Apply vLLM importance sampling correction using one of four modes. 
- - Args: - rollout_log_ratio: log(π_θ / π_rollout) per token, shape [B, T] - completion_mask: Boolean mask for completion tokens, shape [B, T] - - Returns: - IS weights to multiply with loss, same shape as rollout_log_ratio - """ - mode = self.rollout_importance_sampling_mode - threshold = self.rollout_importance_sampling_threshold - - # Clamp log_ratio to prevent numerical overflow from padding values (-1e10) - # A log_ratio of 20 corresponds to exp(20) ≈ 485 million, which is already extreme - SAFETY_BOUND = 20.0 - rollout_log_ratio_safe = torch.clamp(rollout_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) - - # Compute importance sampling ratios: exp(log_ratio) - is_ratio = torch.exp(rollout_log_ratio_safe) - - if mode == 'token_truncate': - # Token-level truncated IS: clip ratios from above at threshold - is_weights = torch.clamp(is_ratio, max=threshold) - - elif mode == 'token_mask': - # Token-level masked IS: mask out tokens with ratio > threshold - is_weights = torch.where(is_ratio <= threshold, is_ratio, torch.zeros_like(is_ratio)) - - elif mode == 'sequence_truncate': - # Sequence-level truncated IS: compute sequence-level ratio and clip - seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) - clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) - - is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio) - - elif mode == 'sequence_mask': - # Sequence-level masked IS: mask entire sequences with ratio > threshold - seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) - seq_mask = (seq_ratios <= threshold).float() - - # Apply mask to original token-level ratios - is_weights = is_ratio * seq_mask.unsqueeze(-1) - else: - return is_ratio - - return is_weights - - def _compute_off_policy_sequence_mask( - self, - per_token_logps: torch.Tensor, - old_policy_per_token_logps: torch.Tensor, - completion_mask: torch.Tensor, - advantages: torch.Tensor, - ) -> torch.Tensor: - """ - Compute off-policy sequence mask to filter out sequences that deviate too much - from the old/rollout policy AND have negative advantage. - - This implements the Off-Policy Sequence Masking technique from DeepSeek-V3.2 - (https://arxiv.org/abs/2512.02556). The mask filters sequences where: - 1. mean(old_policy_logps - policy_logps) > off_policy_sequence_mask_delta - 2. AND advantage < 0 - - Args: - per_token_logps: Log probs from current policy, shape [B, T] - old_policy_per_token_logps: Log probs from old/rollout policy, shape [B, T]. - Uses rollout_per_token_logps if available, otherwise old_per_token_logps. 
- completion_mask: Boolean mask for completion tokens, shape [B, T] - advantages: Advantage values per sample, shape [B] - - Returns: - Sequence mask, shape [B], True = keep sequence, False = mask out - """ - # Compute per-token log ratio: log(π_old / π_current) - # Following DeepSeek-V3.2: positive delta means old policy assigns higher prob - log_ratio = old_policy_per_token_logps - per_token_logps - - # Compute sequence-level mean of log ratio - seq_mean_log_ratio = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - - # Mask condition: delta > threshold AND advantage < 0 - # Keep sequences that do NOT meet this condition - exceeds_threshold = seq_mean_log_ratio > self.off_policy_sequence_mask_delta - negative_advantage = advantages < 0 - should_mask = exceeds_threshold & negative_advantage - - # Return mask: True = keep, False = mask out - return ~should_mask - - def _compute_rollout_offpolicy_metrics( - self, - per_token_logps: torch.Tensor, - rollout_per_token_logps: torch.Tensor, - completion_mask: torch.Tensor, - ) -> Dict[str, float]: - """ - Compute off-policy diagnostic metrics (always computed for monitoring). - reference: verl/verl/trainer/ppo/rollout_corr_helper.py - - These metrics help diagnose the off-policy gap between rollout and training policies, - which can arise from policy mismatch (e.g., vLLM BF16 vs FSDP FP32), model staleness, - or general distribution shifts. - - Key metrics: - - kl: Direct KL divergence estimator KL(π_rollout || π_training) - - k3_kl: K3 KL estimator for stability (more stable for small KL) - - training_ppl: Perplexity of training policy - - rollout_ppl: Perplexity of rollout policy - - log_ppl_diff: Difference in log perplexities - - ppl_ratio: Ratio of training PPL to rollout PPL - - chi2_token: Token-level χ² divergence E[ρ²] - 1 - - chi2_seq: Sequence-level χ² divergence E[(∏ρ_t)²] - 1 - - Args: - per_token_logps: Log probs from training policy model, shape [B, T] - rollout_per_token_logps: Log probs from rollout policy, shape [B, T] - completion_mask: Boolean mask for completion tokens, shape [B, T] - - Returns: - Dictionary with off-policy diagnostic metrics - """ - SAFETY_BOUND = 20.0 - metrics = {} - - # Helper function for masked mean - def masked_mean(x, mask, axis=None): - if axis is None: - return (x * mask).sum() / mask.sum().clamp(min=1.0) - else: - return (x * mask).sum(axis) / mask.sum(axis).clamp(min=1.0) - - # 1. Training policy perplexity (always computed) - # Formula: exp(-1/|T| * Σ log π_training(y_t|y_ Dict[str, float]: - """ - Compute importance sampling correction metrics (ess, clipped_frac, is_weight_mean). - Only called when rollout_importance_sampling_mode is enabled. 
- - Args: - rollout_log_ratio: Log ratio log(π_policy / π_rollout), shape [B, T] - is_weights: Importance sampling weights after correction, shape [B, T] - completion_mask: Boolean mask for completion tokens, shape [B, T] - - Returns: - Dictionary with IS-specific metrics: - - is_weight_mean: Mean of IS weights - - ess: Effective Sample Size = 1 / E[(w_i / E[w_i])²] - - clipped_frac: Fraction of clipped/masked samples - """ - metrics = {} - SAFETY_BOUND = 20.0 - threshold = self.rollout_importance_sampling_threshold - threshold_lower = 1.0 / threshold # Default lower threshold (reciprocal of upper) - - # Helper function for masked mean - def masked_mean(x, mask): - return (x * mask).sum() / mask.sum().clamp(min=1.0) - - # Compute IS ratio with safety bounds - log_ratio_safe = torch.clamp(rollout_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) - is_ratio = torch.exp(log_ratio_safe) - - # 1. IS weight statistics - mean_is_weight = masked_mean(is_weights, completion_mask) - metrics['is_weight_mean'] = self.accelerator.gather_for_metrics(mean_is_weight).nanmean().item() - - # 2. Compute Effective Sample Size (ESS) for IS weights - # ESS = 1 / E[(w_i / E[w_i])²] (using clamped weights for stability) - # This measures how many "effective" independent samples we have after IS weighting - weights_for_ess = is_weights.clamp(min=threshold_lower, max=threshold) - mean_for_ess = masked_mean(weights_for_ess, completion_mask) - is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8) # Avoid division by zero - ess = 1.0 / masked_mean(is_weights_normalized.square(), completion_mask).clamp(min=1e-10) - metrics['ess'] = self.accelerator.gather_for_metrics(ess).nanmean().item() - - # 3. Fraction of clipped/masked samples - if self.rollout_importance_sampling_mode in ['token_truncate', 'token_mask']: - # Token-level - if self.rollout_importance_sampling_mode == 'token_truncate': - clipped_frac = masked_mean((is_ratio > threshold).float(), completion_mask) - else: # token_mask - clipped_frac = masked_mean((is_weights == 0).float(), completion_mask) - metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item() - else: - # Sequence-level (both truncate and mask) - seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) - clipped_frac = (seq_ratios > threshold).float().mean() - metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item() - - return metrics - - def _prepare_model_inputs(self, inputs: 'DataType') -> Dict[str, Any]: - """Filters inputs to create model_inputs, removing GRPO-specific keys.""" - return { - k: v - for k, v in inputs.items() if k not in [ - 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', - 'truncated_mask', 'seq_lengths', 'num_items_in_batch', 'rollout_per_token_logps' - ] - } - - def _get_eval_sampler(self, eval_dataset): - return RepeatSampler( - data_source=eval_dataset, - mini_repeat_count=self.num_generations_eval, - seed=self.args.seed, - ) diff --git a/train_2b_grpo.sh b/train_2b_grpo.sh new file mode 100644 index 0000000000..70eb7fac6a --- /dev/null +++ b/train_2b_grpo.sh @@ -0,0 +1,34 @@ +#!/bin/bash +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +MAX_PIXELS=1003520 \ +NPROC_PER_NODE=4 \ +swift rlhf \ + --rlhf_type grpo \ + --model /root/autodl-tmp/model/Qwen3_grpo \ + --dynamic_sample true \ + --max_resample_times 3 \ + --train_type lora \ + --dataset /root/autodl-tmp/sft_data/train/dataset_rlhf_new.json \ + --use_vllm true \ + --vllm_mode server \ + 
--vllm_gpu_memory_utilization 0.55 \ + --vllm_tensor_parallel_size 2 \ + --vllm_mm_processor_cache_gb 0 \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --learning_rate 1e-6 \ + --save_total_limit 20 \ + --logging_steps 5 \ + --output_dir /root/autodl-tmp/sft_data/train/grpo_result \ + --gradient_accumulation_steps 1 \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --max_completion_length 1024 \ + --external_plugins /root/autodl-tmp/sft_data/train/safety_reward_plugin1.py \ + --reward_funcs multi_label_safety_penalty format \ + --num_generations 8 \ + --sleep_level 0 \ + --temperature 0.7 \ + --top_p 0.85 \ + --deepspeed zero3 \ No newline at end of file diff --git a/train_2b_sft.sh b/train_2b_sft.sh new file mode 100644 index 0000000000..1319f3da52 --- /dev/null +++ b/train_2b_sft.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +MODEL_DIR="/root/autodl-tmp/model/Qwen/Qwen3-VL-2B-Instruct" +OUTPUT_DIR="/root/autodl-tmp/sft_data/train/grpo_result" +DATASET="/root/autodl-tmp/sft_data/train/dataset_multi_vio.json" + + +NPROC_PER_NODE=4 +TOTAL_BATCH_SIZE=16 +PER_DEVICE_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=$((TOTAL_BATCH_SIZE / (NPROC_PER_NODE * PER_DEVICE_BATCH_SIZE))) + +FORCE_TORCHRUN=1 \ +NPROC_PER_NODE=$NPROC_PER_NODE \ +swift sft \ + --model $MODEL_DIR \ + --model_type qwen3_vl \ + --train_type lora \ + --dataset $DATASET \ + --torch_dtype bfloat16 \ + --num_train_epochs 2 \ + --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ + --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \ + --learning_rate 1e-4 \ + --save_steps 30 \ + --save_total_limit 50 \ + --logging_steps 10 \ + --max_length 8192 \ + --output_dir $OUTPUT_DIR \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --deepspeed zero3 \ + --bf16 true \ + --save_safetensors true \ + --gradient_checkpointing true \ diff --git a/train_2b_sympo.sh b/train_2b_sympo.sh new file mode 100644 index 0000000000..f631ed2ca9 --- /dev/null +++ b/train_2b_sympo.sh @@ -0,0 +1,35 @@ +#!/bin/bash +export PYTHONPATH=/root/autodl-tmp/sft_data:$PYTHONPATH +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +MAX_PIXELS=1003520 \ +NPROC_PER_NODE=4 \ +python -m swift.cli.main rlhf \ + --rlhf_type grpo \ + --advantage_estimator grpo \ + --model /root/autodl-tmp/model/Qwen3_grpo \ + --train_type lora \ + --dataset /root/autodl-tmp/sft_data/train/dataset_rlhf_new.json \ + --use_vllm true \ + --vllm_mode server \ + --vllm_gpu_memory_utilization 0.55 \ + --vllm_tensor_parallel_size 2 \ + --vllm_mm_processor_cache_gb 0 \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --learning_rate 1e-6 \ + --save_total_limit 50 \ + --save_steps 10 \ + --logging_steps 5 \ + --output_dir /root/autodl-tmp/sft_data/train/grpo_result \ + --gradient_accumulation_steps 1 \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --max_completion_length 1024 \ + --external_plugins /root/autodl-tmp/sft_data/train/safety_reward_plugin1.py \ + --reward_funcs multi_label_safety_penalty format \ + --num_generations 1 \ + --sleep_level 0 \ + --temperature 0.7 \ + --top_p 0.85 \ + --deepspeed zero3 \ No newline at end of file
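
Editor's note on the `--external_plugins` / `--reward_funcs` flags used by the three launch scripts above: `_prepare_rewards` resolves each name in `--reward_funcs` through the `orms` registry in `swift.plugin`, instantiating the matching class with whatever of its constructor arguments it can find on the training args (plus the tokenizer, if the constructor accepts one). Below is a minimal sketch of what an external plugin file such as the one passed via `--external_plugins` might contain, assuming the ORM base class and `orms` registry exported by `swift.plugin`; the class name and scoring rule are hypothetical placeholders, not the contents of safety_reward_plugin1.py.

from typing import List

from swift.plugin import ORM, orms  # assumed interface: ORM base class + orms registry


class MultiLabelSafetyPenalty(ORM):
    """Hypothetical placeholder reward; the real scoring logic lives in safety_reward_plugin1.py."""

    def __call__(self, completions: List[str], **kwargs) -> List[float]:
        rewards = []
        for completion in completions:
            # Placeholder rule: reward completions that carry an explicit label tag.
            rewards.append(1.0 if '<label>' in completion else -1.0)
        return rewards


# Registering under this key is what lets `--reward_funcs multi_label_safety_penalty`
# resolve the class via `orms[reward_func]` in `_prepare_rewards`.
orms['multi_label_safety_penalty'] = MultiLabelSafetyPenalty

With such a registration in place, `--reward_funcs multi_label_safety_penalty format` picks up the custom reward alongside the built-in format reward used in these scripts.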