diff --git a/docs/source/Instruction/GRPO/GetStarted/GRPO.md b/docs/source/Instruction/GRPO/GetStarted/GRPO.md index c3a97eced3..f48355da22 100644 --- a/docs/source/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source/Instruction/GRPO/GetStarted/GRPO.md @@ -281,10 +281,20 @@ IS 校正指标(需设置`rollout_importance_sampling_mode`): 设置 `report_to wandb/swanlab` 将训练动态Table推送到对应的平台 -如果需要在Table中额外记录其他列,请在 `GRPOTrainer._generate_and_score_completions` 方法中,设置 metrics_to_gather 字典。 +如果需要在Table中额外记录数据集的其他列,可以设置: + +```bash +--log_completions_extra_columns col1 col2 +``` + +- 会同时写入 `completions.jsonl`、wandb table、swanlab table。 +- 若某些样本缺少指定列,缺失值会记录为 `None`,并且每个缺失列只输出一次 warning。 +- 值会按原始类型记录(例如 list/dict 不会自动转成字符串)。 + +如需更深度自定义,也可以在 `GRPOTrainer._generate_and_score_completions` 中扩展日志收集逻辑。 默认自动检测 -- `image`:视觉数据集图像输入。(暂时只支持wandb) +- `image`:视觉数据集图像输入(仅wandb)。若输入包含多张图片,只会记录第一张,并会输出warning;完整多图日志后续支持。 - `solution`:数据集中的 solution 列。 ## FAQ diff --git a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md index 0500efc1c1..2b88b2fcf9 100644 --- a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md @@ -279,11 +279,22 @@ If `log_completions` is set, the training dynamics will be saved in the output d Setting `report_to wandb/swanlab` will send training dynamics table to the respective platform. -If you want to log extra columns in the Table, populate the `metrics_to_gather` dictionary inside `GRPOTrainer._generate_and_score_completions`. +If you want to log extra dataset columns in the completion table, set: + +```bash +--log_completions_extra_columns col1 col2 +``` + +- The columns are logged to `completions.jsonl`, wandb table, and swanlab table. +- If a configured column is missing in some samples, `None` is logged for the missing values and a warning is emitted once per column. +- Values are kept as-is (for example, list/dict values are not stringified). 
+ +For advanced customization, you can still extend the log collection logic in +`GRPOTrainer._generate_and_score_completions`. The trainer automatically detects and logs the following keys: -- image: image inputs for vision models(wandb only). +- image: image inputs for vision models (wandb only). If multiple images are provided, only the first image is logged and a warning is emitted; full multi-image logging will be supported later. - solution: the solution column from the dataset. ## FAQ diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index 8fa8967ae6..229e998111 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -137,6 +137,8 @@ class GRPOArguments(GRPOArgumentsMixin): be used with an experiment tracker like WandB or SwanLab (`--report_to wandb`/`swanlab`). If enabled without a tracker, completions are saved to `completions.jsonl` in the checkpoint directory. Defaults to False. + log_completions_extra_columns (List[str]): Extra dataset columns to include in completion tables when + `log_completions=true`. Missing values are logged as None with one warning per column. Defaults to `[]`. num_iterations (int): The number of update steps to perform for each data sample. This corresponds to the K value in the GRPO paper. Defaults to 1. 
truncation_strategy (Literal['delete', 'left', 'right', 'split', None]): The strategy for handling input @@ -151,6 +153,7 @@ class GRPOArguments(GRPOArgumentsMixin): reward_funcs: List[str] = field(default_factory=list) reward_weights: List[float] = None log_completions: bool = False + log_completions_extra_columns: List[str] = field(default_factory=list) # multi step num_iterations: int = 1 diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 5d6b0f5de0..98626a9807 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -137,6 +137,7 @@ class RLHFMegatronArgumentsMixin: wandb_log_unique_prompts: Optional[bool] = None log_completions: bool = False + log_completions_extra_columns: List[str] = field(default_factory=list) rollout_importance_sampling_mode: Optional[Literal['token_truncate', 'token_mask', 'sequence_truncate', 'sequence_mask']] = None diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 987d25934d..dbb3046981 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -26,9 +26,10 @@ from swift.megatron.utils import forward_step_helper, get_padding_to, set_random_seed from swift.rewards import orms from swift.rlhf_trainers.grpo_trainer import DataType -from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context, +from swift.rlhf_trainers.utils import (aggressive_empty_cache, collect_log_columns, load_pil_img, nanstd, + normalize_log_image, pad_logps_back_to_batch, profiling_context, profiling_decorator, replace_assistant_response_with_ids, - set_expandable_segments) + select_log_completions_extra_columns, set_expandable_segments) from swift.rollout import MultiTurnScheduler, multi_turns from swift.template import Template, TemplateInputs from swift.utils import (JsonlWriter, get_logger, get_packed_seq_params, 
remove_response, shutdown_event_loop_in_daemon, @@ -501,6 +502,23 @@ def _rollout(self, batch) -> List[RolloutOutput]: completions = gather_object([data.response.choices[0].message.content for data in rollout_outputs]) self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(messages)) self._logs['completion'].extend(completions) + if all('images' in data for data in batch): + images = gather_object([normalize_log_image(data['images']) for data in batch]) + if 'image' not in self._logs: + self._logs['image'] = deque(maxlen=self.args.generation_batch_size) + self._logs['image'].extend(images) + + if self.log_completions_extra_columns: + extra_columns = select_log_completions_extra_columns(self.log_completions_extra_columns) + extra_metrics = collect_log_columns( + batch, + extra_columns, + warned_columns=self._missing_log_columns_warned, + ) + for key, value in extra_metrics.items(): + if key not in self._logs: + self._logs[key] = deque(maxlen=self.args.generation_batch_size) + self._logs[key].extend(gather_object(value)) return rollout_outputs @@ -1390,11 +1408,21 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): for k, v in self._logs['rewards'].items()}, 'advantages': list(self._logs['advantages']), } + if self._logs.get('image'): + table['image'] = list(self._logs['image']) + for key, value in self._logs.items(): + if key not in table and key != 'rewards': + table[key] = list(value) self.jsonl_writer.append(table) args = self.args if 'wandb' in args.report_to: import wandb - df = pd.DataFrame(table) + wandb_table = table.copy() + if self._logs.get('image'): + wandb_table['image'] = [ + wandb.Image(load_pil_img(img)) if img is not None else None for img in self._logs['image'] + ] + df = pd.DataFrame(wandb_table) if self.wandb_log_unique_prompts: df = df.drop_duplicates(subset=['prompt']) wandb.log({'completions': wandb.Table(dataframe=df)}) @@ -1658,6 +1686,8 @@ def _prepare_metrics(self): self.log_completions = 
args.log_completions self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.log_completions_extra_columns = list(args.log_completions_extra_columns) + self._missing_log_columns_warned = set() self.jsonl_writer = JsonlWriter(os.path.join(args.output_dir, 'completions.jsonl'), write_on_rank='last') self.init_custom_metric = False self._last_logged_step = -1 diff --git a/swift/rlhf_trainers/args_mixin.py b/swift/rlhf_trainers/args_mixin.py index b3dbaea39c..2bd72d47f5 100644 --- a/swift/rlhf_trainers/args_mixin.py +++ b/swift/rlhf_trainers/args_mixin.py @@ -292,6 +292,8 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin): log_rollout_offpolicy_metrics (bool): Whether to log rollout off-policy diagnostic metrics (KL, PPL, chi2, etc.) when `rollout_importance_sampling_mode` is not set. When `rollout_importance_sampling_mode` is set, metrics are always logged regardless of this setting. Defaults to False. + log_completions_extra_columns (List[str]): Extra dataset columns to include in completion tables when + `log_completions=true`. Missing values are logged as None with one warning per column. Defaults to []. 
""" epsilon: float = 0.2 epsilon_high: Optional[float] = None @@ -374,6 +376,7 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin): 'sequence_mask']] = None rollout_importance_sampling_threshold: float = 2.0 # Threshold for truncation/masking (C in paper) log_rollout_offpolicy_metrics: bool = False # Log off-policy metrics even when IS correction is disabled + log_completions_extra_columns: List[str] = field(default_factory=list) # Off-Policy Sequence Masking: mask out sequences that deviate too much from rollout policy # If set, compute mean(rollout_per_token_logps - per_token_logps) per sequence, # and mask sequences where this delta > threshold AND advantage < 0 diff --git a/swift/rlhf_trainers/grpo_trainer.py b/swift/rlhf_trainers/grpo_trainer.py index a4e79446ec..ade9674df1 100644 --- a/swift/rlhf_trainers/grpo_trainer.py +++ b/swift/rlhf_trainers/grpo_trainer.py @@ -52,9 +52,10 @@ start_event_loop_in_daemon, to_device, unwrap_model_for_generation) from .arguments import GRPOConfig from .rollout_mixin import DataType, RolloutTrainerMixin -from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator, - load_pil_img, make_chord_sft_dataset, nanstd, pad_logps_back_to_batch, patch_save_last_checkpoint, - profiling_context, profiling_decorator, replace_assistant_response_with_ids) +from .utils import (_ForwardRedirection, collect_log_columns, compute_chord_loss, get_even_process_data, identity_data_collator, + load_pil_img, make_chord_sft_dataset, nanstd, normalize_log_image, pad_logps_back_to_batch, + patch_save_last_checkpoint, profiling_context, profiling_decorator, replace_assistant_response_with_ids, + select_log_completions_extra_columns) try: from trl.trainer.utils import entropy_from_logits @@ -282,6 +283,21 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: if all('rollout_infos' in inp and 'num_turns' in inp['rollout_infos'] for inp in inputs): 
metrics_for_logs_to_gather['num_turns'] = [inp['rollout_infos']['num_turns'] for inp in inputs] + if all('images' in inp for inp in inputs): + metrics_for_logs_to_gather['image'] = [normalize_log_image(inp['images']) for inp in inputs] + + if self.log_completions_extra_columns: + extra_columns = select_log_completions_extra_columns( + self.log_completions_extra_columns, + occupied_columns=metrics_for_logs_to_gather.keys(), + ) + metrics_for_logs_to_gather.update( + collect_log_columns( + inputs, + extra_columns, + warned_columns=self._missing_log_columns_warned, + )) + if metrics_for_logs_to_gather: for key, value in metrics_for_logs_to_gather.items(): if key not in self._logs: @@ -2123,6 +2139,8 @@ def _prepare_metrics(self): self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} self.log_completions = args.log_completions self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.log_completions_extra_columns = list(args.log_completions_extra_columns) + self._missing_log_columns_warned = set() self.num_completions_to_print = args.num_completions_to_print self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl')) self._logs = { diff --git a/swift/rlhf_trainers/utils.py b/swift/rlhf_trainers/utils.py index ce3f099562..3f2e9e9381 100644 --- a/swift/rlhf_trainers/utils.py +++ b/swift/rlhf_trainers/utils.py @@ -22,7 +22,7 @@ from torch import nn from torch.utils.data import DataLoader, RandomSampler from types import MethodType -from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union from swift.template import Messages from swift.tuners.lora import LoraConfig @@ -701,6 +701,71 @@ def load_pil_img(img) -> Image: raise ValueError("Image dictionary must contain either 'bytes' or 'path' key.") +def normalize_log_image(img: Any) -> Any: + if img is None: + return None + if isinstance(img, (list, tuple)): + if not 
img: + return None + if len(img) > 1: + logger = get_logger() + logger.warning( + 'Multiple images detected; only the first image will be logged. ' + 'Full multi-image logging is not yet supported.' + ) + return img[0] + return img + + +def collect_log_columns(rows: List[Dict[str, Any]], + column_names: List[str], + warned_columns: Optional[Set[str]] = None) -> Dict[str, List[Any]]: + """Collect configured columns for completion-table logging. + + Missing values are filled with ``None``. For missing columns, a warning is emitted + only once per column when ``warned_columns`` is provided. + """ + if not rows or not column_names: + return {} + + warned_columns = warned_columns if warned_columns is not None else set() + logger = get_logger() + collected: Dict[str, List[Any]] = {} + + for column in dict.fromkeys(column_names): + values = [] + has_missing = False + for row in rows: + if column in row: + values.append(row[column]) + else: + values.append(None) + has_missing = True + + if has_missing and column not in warned_columns: + logger.warning( + 'Column `%s` from `log_completions_extra_columns` is missing in some samples; ' + 'None will be logged for missing values.', column) + warned_columns.add(column) + collected[column] = values + + return collected + + +def select_log_completions_extra_columns(column_names: List[str], + occupied_columns: Optional[Iterable[str]] = None) -> List[str]: + """Return deduplicated extra log columns not occupied in the current logging pass. + + This helper intentionally does not filter by historical ``self._logs`` keys. + Extra columns (e.g. metadata_log/refs_log) must be collected every step to + stay aligned with prompt/completion rows. 
+ """ + if not column_names: + return [] + occupied = set(occupied_columns or []) + return [col for col in dict.fromkeys(column_names) if col not in occupied] + + def replace_assistant_response_with_ids(messages: 'Messages', completion_ids: List[Union[int, List[int]]], loss_mask: Optional[List[List[int]]] = None) -> 'Messages': # noqa diff --git a/tests/rlhf_trainers/test_collect_log_columns.py b/tests/rlhf_trainers/test_collect_log_columns.py new file mode 100644 index 0000000000..086afe5dfc --- /dev/null +++ b/tests/rlhf_trainers/test_collect_log_columns.py @@ -0,0 +1,70 @@ +import logging + +from swift.rlhf_trainers import utils as rlhf_utils + + +def test_collect_log_columns_empty_config(): + rows = [{'a': 1}] + assert rlhf_utils.collect_log_columns(rows, []) == {} + + +def test_collect_log_columns_empty_rows(): + assert rlhf_utils.collect_log_columns([], ['a']) == {} + + +def test_collect_log_columns_all_present(): + rows = [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'y'}] + result = rlhf_utils.collect_log_columns(rows, ['a', 'b']) + assert result == {'a': [1, 2], 'b': ['x', 'y']} + + +def test_collect_log_columns_missing_warns_once_per_column(): + rows = [{'a': 1}, {'b': 2}] + warned = set() + logger = rlhf_utils.get_logger() + records = [] + + class ListHandler(logging.Handler): + def emit(self, record): + records.append(record) + + handler = ListHandler() + handler.setLevel(logging.WARNING) + logger.addHandler(handler) + try: + result1 = rlhf_utils.collect_log_columns(rows, ['a', 'b'], warned_columns=warned) + result2 = rlhf_utils.collect_log_columns(rows, ['a', 'b'], warned_columns=warned) + finally: + logger.removeHandler(handler) + + assert result1 == {'a': [1, None], 'b': [None, 2]} + assert result2 == {'a': [1, None], 'b': [None, 2]} + assert len([r for r in records if 'log_completions_extra_columns' in r.getMessage()]) == 2 + + +def test_collect_log_columns_keeps_complex_types(): + d = {'k': 'v'} + values = [1, 2, 3] + rows = [{'meta': d, 'trace': values}] + 
result = rlhf_utils.collect_log_columns(rows, ['meta', 'trace']) + assert result['meta'][0] is d + assert result['trace'][0] is values + + +def test_select_log_completions_extra_columns_empty(): + assert rlhf_utils.select_log_completions_extra_columns([], occupied_columns=['a']) == [] + + +def test_select_log_completions_extra_columns_dedup_and_exclude_occupied(): + columns = ['metadata_log', 'refs_log', 'metadata_log', 'refs_log'] + result = rlhf_utils.select_log_completions_extra_columns(columns, occupied_columns=['refs_log']) + assert result == ['metadata_log'] + + +def test_select_log_completions_extra_columns_keeps_historical_log_keys(): + columns = ['metadata_log'] + historical_log_keys = {'prompt', 'completion', 'metadata_log'} + # Current-pass occupied columns do not include metadata_log, so it must stay selectable. + result = rlhf_utils.select_log_completions_extra_columns(columns, occupied_columns=[]) + assert result == ['metadata_log'] + assert 'metadata_log' in historical_log_keys diff --git a/tests/rlhf_trainers/test_normalize_log_image.py b/tests/rlhf_trainers/test_normalize_log_image.py new file mode 100644 index 0000000000..df82a1d9da --- /dev/null +++ b/tests/rlhf_trainers/test_normalize_log_image.py @@ -0,0 +1,47 @@ +import logging + +from swift.rlhf_trainers import utils as rlhf_utils + + +def test_normalize_log_image_none(): + assert rlhf_utils.normalize_log_image(None) is None + + +def test_normalize_log_image_empty_list(): + assert rlhf_utils.normalize_log_image([]) is None + + +def test_normalize_log_image_single_list(): + img = {'path': 'a.png'} + assert rlhf_utils.normalize_log_image([img]) == img + + +def test_normalize_log_image_multi_list_warns(): + logger = rlhf_utils.get_logger() + records = [] + + class ListHandler(logging.Handler): + def emit(self, record): + records.append(record) + + handler = ListHandler() + handler.setLevel(logging.WARNING) + logger.addHandler(handler) + try: + img1 = {'path': 'a.png'} + img2 = {'path': 'b.png'} 
+ assert rlhf_utils.normalize_log_image([img1, img2]) == img1 + finally: + logger.removeHandler(handler) + + assert any('Multiple images detected' in record.getMessage() for record in records) + + +def test_normalize_log_image_dict(): + img = {'path': 'a.png'} + assert rlhf_utils.normalize_log_image(img) == img + + +def test_normalize_log_image_string(): + img = 'a.png' + assert rlhf_utils.normalize_log_image(img) == img