Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions docs/source/Instruction/GRPO/GetStarted/GRPO.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,20 @@ IS 校正指标(需设置`rollout_importance_sampling_mode`):

设置 `report_to wandb/swanlab` 将训练动态Table推送到对应的平台

如果需要在Table中额外记录其他列,请在 `GRPOTrainer._generate_and_score_completions` 方法中,设置 metrics_to_gather 字典。
如果需要在Table中额外记录数据集的其他列,可以设置:

```bash
--log_completions_extra_columns col1 col2
```

- 会同时写入 `completions.jsonl`、wandb table、swanlab table。
- 若某些样本缺少指定列,会记录为 `None`,并对每个缺失列仅输出一次 warning。
- 值会按原始类型记录(例如 list/dict 不会自动转成字符串)。

如需更深度自定义,也可以在 `GRPOTrainer._generate_and_score_completions` 中扩展日志收集逻辑。

默认自动检测
- `image`:视觉数据集图像输入。(暂时只支持wandb)
- `image`:视觉数据集图像输入(仅wandb)。若输入包含多张图片,只会记录第一张,并会输出warning;完整多图日志后续支持。
- `solution`:数据集中的 solution 列。

## FAQ
Expand Down
15 changes: 13 additions & 2 deletions docs/source_en/Instruction/GRPO/GetStarted/GRPO.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,22 @@ If `log_completions` is set, the training dynamics will be saved in the output d

Setting `report_to wandb/swanlab` will send training dynamics table to the respective platform.

If you want to log extra columns in the Table, populate the `metrics_to_gather` dictionary inside `GRPOTrainer._generate_and_score_completions`.
If you want to log extra dataset columns in the completion table, set:

```bash
--log_completions_extra_columns col1 col2
```

- The columns are logged to `completions.jsonl`, wandb table, and swanlab table.
- If a configured column is missing in some samples, `None` is logged and a warning is emitted once per column.
- Values are kept as-is (for example, list/dict values are not stringified).

For advanced customization, you can still extend the log collection logic in
`GRPOTrainer._generate_and_score_completions`.

The trainer automatically detects and logs the following keys:

- image: image inputs for vision models(wandb only).
- image: image inputs for vision models (wandb only). If multiple images are provided, only the first image is logged and a warning is emitted; full multi-image logging will be supported later.
- solution: the solution column from the dataset.

## FAQ
Expand Down
3 changes: 3 additions & 0 deletions swift/arguments/rlhf_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ class GRPOArguments(GRPOArgumentsMixin):
be used with an experiment tracker like WandB or SwanLab (`--report_to wandb`/`swanlab`). If enabled
without a tracker, completions are saved to `completions.jsonl` in the checkpoint directory. Defaults to
False.
log_completions_extra_columns (List[str]): Extra dataset columns to include in completion tables when
`log_completions=true`. Missing values are logged as None with one warning per column. Defaults to `[]`.
num_iterations (int): The number of update steps to perform for each data sample. This corresponds to the K
value in the GRPO paper. Defaults to 1.
truncation_strategy (Literal['delete', 'left', 'right', 'split', None]): The strategy for handling input
Expand All @@ -151,6 +153,7 @@ class GRPOArguments(GRPOArgumentsMixin):
reward_funcs: List[str] = field(default_factory=list)
reward_weights: List[float] = None
log_completions: bool = False
log_completions_extra_columns: List[str] = field(default_factory=list)

# multi step
num_iterations: int = 1
Expand Down
1 change: 1 addition & 0 deletions swift/megatron/arguments/megatron_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class RLHFMegatronArgumentsMixin:

wandb_log_unique_prompts: Optional[bool] = None
log_completions: bool = False
log_completions_extra_columns: List[str] = field(default_factory=list)

rollout_importance_sampling_mode: Optional[Literal['token_truncate', 'token_mask', 'sequence_truncate',
'sequence_mask']] = None
Expand Down
36 changes: 33 additions & 3 deletions swift/megatron/trainers/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
from swift.megatron.utils import forward_step_helper, get_padding_to, set_random_seed
from swift.rewards import orms
from swift.rlhf_trainers.grpo_trainer import DataType
from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context,
from swift.rlhf_trainers.utils import (aggressive_empty_cache, collect_log_columns, load_pil_img, nanstd,
normalize_log_image, pad_logps_back_to_batch, profiling_context,
profiling_decorator, replace_assistant_response_with_ids,
set_expandable_segments)
select_log_completions_extra_columns, set_expandable_segments)
from swift.rollout import MultiTurnScheduler, multi_turns
from swift.template import Template, TemplateInputs
from swift.utils import (JsonlWriter, get_logger, get_packed_seq_params, remove_response, shutdown_event_loop_in_daemon,
Expand Down Expand Up @@ -501,6 +502,23 @@ def _rollout(self, batch) -> List[RolloutOutput]:
completions = gather_object([data.response.choices[0].message.content for data in rollout_outputs])
self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(messages))
self._logs['completion'].extend(completions)
if all('images' in data for data in batch):
images = gather_object([normalize_log_image(data['images']) for data in batch])
if 'image' not in self._logs:
self._logs['image'] = deque(maxlen=self.args.generation_batch_size)
self._logs['image'].extend(images)
Comment on lines +505 to +509
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition all('images' in data for data in batch) is too restrictive. If a batch contains a mix of samples with and without images, this condition will be false, and no images will be logged at all. It would be more robust to handle such mixed batches by logging images for the samples that have them and None for those that don't.

Suggested change
if all('images' in data for data in batch):
images = gather_object([normalize_log_image(data['images']) for data in batch])
if 'image' not in self._logs:
self._logs['image'] = deque(maxlen=self.args.generation_batch_size)
self._logs['image'].extend(images)
if any('images' in data for data in batch):
images = gather_object([normalize_log_image(data.get('images')) for data in batch])
if 'image' not in self._logs:
self._logs['image'] = deque(maxlen=self.args.generation_batch_size)
self._logs['image'].extend(images)


if self.log_completions_extra_columns:
extra_columns = select_log_completions_extra_columns(self.log_completions_extra_columns)
extra_metrics = collect_log_columns(
batch,
extra_columns,
warned_columns=self._missing_log_columns_warned,
)
for key, value in extra_metrics.items():
if key not in self._logs:
self._logs[key] = deque(maxlen=self.args.generation_batch_size)
self._logs[key].extend(gather_object(value))

return rollout_outputs

Expand Down Expand Up @@ -1390,11 +1408,21 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
for k, v in self._logs['rewards'].items()},
'advantages': list(self._logs['advantages']),
}
if self._logs.get('image'):
table['image'] = list(self._logs['image'])
for key, value in self._logs.items():
if key not in table and key != 'rewards':
table[key] = list(value)
self.jsonl_writer.append(table)
args = self.args
if 'wandb' in args.report_to:
import wandb
df = pd.DataFrame(table)
wandb_table = table.copy()
if self._logs.get('image'):
wandb_table['image'] = [
wandb.Image(load_pil_img(img)) if img is not None else None for img in self._logs['image']
]
df = pd.DataFrame(wandb_table)
if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=['prompt'])
wandb.log({'completions': wandb.Table(dataframe=df)})
Expand Down Expand Up @@ -1658,6 +1686,8 @@ def _prepare_metrics(self):

self.log_completions = args.log_completions
self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
self.log_completions_extra_columns = list(args.log_completions_extra_columns)
self._missing_log_columns_warned = set()
self.jsonl_writer = JsonlWriter(os.path.join(args.output_dir, 'completions.jsonl'), write_on_rank='last')
self.init_custom_metric = False
self._last_logged_step = -1
Expand Down
3 changes: 3 additions & 0 deletions swift/rlhf_trainers/args_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
log_rollout_offpolicy_metrics (bool): Whether to log rollout off-policy diagnostic metrics (KL, PPL, chi2, etc.)
when `rollout_importance_sampling_mode` is not set. When `rollout_importance_sampling_mode` is set,
metrics are always logged regardless of this setting. Defaults to False.
log_completions_extra_columns (List[str]): Extra dataset columns to include in completion tables when
`log_completions=true`. Missing values are logged as None with one warning per column. Defaults to [].
"""
epsilon: float = 0.2
epsilon_high: Optional[float] = None
Expand Down Expand Up @@ -374,6 +376,7 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
'sequence_mask']] = None
rollout_importance_sampling_threshold: float = 2.0 # Threshold for truncation/masking (C in paper)
log_rollout_offpolicy_metrics: bool = False # Log off-policy metrics even when IS correction is disabled
log_completions_extra_columns: List[str] = field(default_factory=list)
# Off-Policy Sequence Masking: mask out sequences that deviate too much from rollout policy
# If set, compute mean(rollout_per_token_logps - per_token_logps) per sequence,
# and mask sequences where this delta > threshold AND advantage < 0
Expand Down
24 changes: 21 additions & 3 deletions swift/rlhf_trainers/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@
start_event_loop_in_daemon, to_device, unwrap_model_for_generation)
from .arguments import GRPOConfig
from .rollout_mixin import DataType, RolloutTrainerMixin
from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator,
load_pil_img, make_chord_sft_dataset, nanstd, pad_logps_back_to_batch, patch_save_last_checkpoint,
profiling_context, profiling_decorator, replace_assistant_response_with_ids)
from .utils import (_ForwardRedirection, collect_log_columns, compute_chord_loss, get_even_process_data, identity_data_collator,
load_pil_img, make_chord_sft_dataset, nanstd, normalize_log_image, pad_logps_back_to_batch,
patch_save_last_checkpoint, profiling_context, profiling_decorator, replace_assistant_response_with_ids,
select_log_completions_extra_columns)

try:
from trl.trainer.utils import entropy_from_logits
Expand Down Expand Up @@ -282,6 +283,21 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType:
if all('rollout_infos' in inp and 'num_turns' in inp['rollout_infos'] for inp in inputs):
metrics_for_logs_to_gather['num_turns'] = [inp['rollout_infos']['num_turns'] for inp in inputs]

if all('images' in inp for inp in inputs):
metrics_for_logs_to_gather['image'] = [normalize_log_image(inp['images']) for inp in inputs]
Comment on lines +286 to +287
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition all('images' in inp for inp in inputs) is too restrictive. If a batch contains a mix of samples with and without images, this condition will be false, and no images will be logged at all. It would be more robust to handle such mixed batches by logging images for the samples that have them and None for those that don't.

Suggested change
if all('images' in inp for inp in inputs):
metrics_for_logs_to_gather['image'] = [normalize_log_image(inp['images']) for inp in inputs]
if any('images' in inp for inp in inputs):
metrics_for_logs_to_gather['image'] = [normalize_log_image(inp.get('images')) for inp in inputs]


if self.log_completions_extra_columns:
extra_columns = select_log_completions_extra_columns(
self.log_completions_extra_columns,
occupied_columns=metrics_for_logs_to_gather.keys(),
)
metrics_for_logs_to_gather.update(
collect_log_columns(
inputs,
extra_columns,
warned_columns=self._missing_log_columns_warned,
))

if metrics_for_logs_to_gather:
for key, value in metrics_for_logs_to_gather.items():
if key not in self._logs:
Expand Down Expand Up @@ -2123,6 +2139,8 @@ def _prepare_metrics(self):
self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
self.log_completions = args.log_completions
self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
self.log_completions_extra_columns = list(args.log_completions_extra_columns)
self._missing_log_columns_warned = set()
self.num_completions_to_print = args.num_completions_to_print
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl'))
self._logs = {
Expand Down
67 changes: 66 additions & 1 deletion swift/rlhf_trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch import nn
from torch.utils.data import DataLoader, RandomSampler
from types import MethodType
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union

from swift.template import Messages
from swift.tuners.lora import LoraConfig
Expand Down Expand Up @@ -701,6 +701,71 @@ def load_pil_img(img) -> Image:
raise ValueError("Image dictionary must contain either 'bytes' or 'path' key.")


def normalize_log_image(img: Any) -> Any:
    """Reduce an image field to a single loggable value.

    ``None`` passes through unchanged. A list/tuple collapses to its first
    element (an empty sequence becomes ``None``); when extra images are
    dropped, a warning is emitted. Any other value is returned as-is.
    """
    if img is None:
        return None
    if not isinstance(img, (list, tuple)):
        return img
    if not img:
        return None
    if len(img) > 1:
        # Multi-image logging is not supported yet; keep only the first one.
        get_logger().warning(
            'Multiple images detected; only the first image will be logged. '
            'Full multi-image logging is not yet supported.'
        )
    return img[0]


def collect_log_columns(rows: List[Dict[str, Any]],
                        column_names: List[str],
                        warned_columns: Optional[Set[str]] = None) -> Dict[str, List[Any]]:
    """Gather the requested columns from ``rows`` for completion-table logging.

    A row lacking a column contributes ``None`` for that column. When
    ``warned_columns`` is supplied, each missing column is warned about at
    most once across calls; otherwise a fresh (per-call) set is used.
    """
    if not rows or not column_names:
        return {}

    if warned_columns is None:
        warned_columns = set()
    logger = get_logger()
    # Sentinel distinguishes a genuinely absent key from a stored None value.
    _missing = object()

    result: Dict[str, List[Any]] = {}
    # dict.fromkeys dedups column names while preserving configured order.
    for name in dict.fromkeys(column_names):
        raw = [row.get(name, _missing) for row in rows]
        if any(v is _missing for v in raw) and name not in warned_columns:
            logger.warning(
                'Column `%s` from `log_completions_extra_columns` is missing in some samples; '
                'None will be logged for missing values.', name)
            warned_columns.add(name)
        result[name] = [None if v is _missing else v for v in raw]

    return result


def select_log_completions_extra_columns(column_names: List[str],
                                         occupied_columns: Optional[Iterable[str]] = None) -> List[str]:
    """Return deduplicated extra log columns not occupied in the current pass.

    Historical ``self._logs`` keys are intentionally NOT filtered out here:
    extra columns (e.g. metadata_log/refs_log) must be collected on every step
    so they stay aligned with prompt/completion rows.
    """
    if not column_names:
        return []
    # Seed with occupied names so they are skipped; reuse the set for dedup.
    seen = set(occupied_columns or [])
    selected = []
    for name in column_names:
        if name not in seen:
            seen.add(name)
            selected.append(name)
    return selected


def replace_assistant_response_with_ids(messages: 'Messages',
completion_ids: List[Union[int, List[int]]],
loss_mask: Optional[List[List[int]]] = None) -> 'Messages': # noqa
Expand Down
70 changes: 70 additions & 0 deletions tests/rlhf_trainers/test_collect_log_columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging

from swift.rlhf_trainers import utils as rlhf_utils


def test_collect_log_columns_empty_config():
    """An empty column list yields an empty mapping even when rows exist."""
    sample_rows = [{'a': 1}]
    collected = rlhf_utils.collect_log_columns(sample_rows, [])
    assert collected == {}


def test_collect_log_columns_empty_rows():
    """No rows means nothing to collect, regardless of configured columns."""
    collected = rlhf_utils.collect_log_columns([], ['a'])
    assert collected == {}


def test_collect_log_columns_all_present():
    """When every row has every column, values are collected in row order."""
    sample_rows = [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'y'}]
    collected = rlhf_utils.collect_log_columns(sample_rows, ['a', 'b'])
    assert collected == {'a': [1, 2], 'b': ['x', 'y']}


def test_collect_log_columns_missing_warns_once_per_column():
    """Missing columns log None and warn exactly once per column across calls."""
    # Each row is missing one of the two configured columns.
    rows = [{'a': 1}, {'b': 2}]
    warned = set()
    logger = rlhf_utils.get_logger()
    records = []

    # Capture every record emitted through the swift logger for inspection.
    class ListHandler(logging.Handler):
        def emit(self, record):
            records.append(record)

    handler = ListHandler()
    handler.setLevel(logging.WARNING)
    logger.addHandler(handler)
    try:
        # The second call shares `warned`, so it must not warn again.
        result1 = rlhf_utils.collect_log_columns(rows, ['a', 'b'], warned_columns=warned)
        result2 = rlhf_utils.collect_log_columns(rows, ['a', 'b'], warned_columns=warned)
    finally:
        logger.removeHandler(handler)

    assert result1 == {'a': [1, None], 'b': [None, 2]}
    assert result2 == {'a': [1, None], 'b': [None, 2]}
    # Exactly one warning per missing column ('a' and 'b'); none from call two.
    assert len([r for r in records if 'log_completions_extra_columns' in r.getMessage()]) == 2


def test_collect_log_columns_keeps_complex_types():
    """list/dict values are logged by reference, not stringified or copied."""
    meta_dict = {'k': 'v'}
    trace_list = [1, 2, 3]
    collected = rlhf_utils.collect_log_columns(
        [{'meta': meta_dict, 'trace': trace_list}], ['meta', 'trace'])
    assert collected['meta'][0] is meta_dict
    assert collected['trace'][0] is trace_list


def test_select_log_completions_extra_columns_empty():
    """An empty configuration selects nothing, even with occupied columns."""
    selected = rlhf_utils.select_log_completions_extra_columns([], occupied_columns=['a'])
    assert selected == []


def test_select_log_completions_extra_columns_dedup_and_exclude_occupied():
    """Duplicates collapse to one entry and occupied columns are dropped."""
    configured = ['metadata_log', 'refs_log', 'metadata_log', 'refs_log']
    selected = rlhf_utils.select_log_completions_extra_columns(configured, occupied_columns=['refs_log'])
    assert selected == ['metadata_log']


def test_select_log_completions_extra_columns_keeps_historical_log_keys():
    """A column already logged on previous steps must remain selectable.

    ``select_log_completions_extra_columns`` filters only against the
    *current-pass* occupied columns and intentionally ignores historical
    ``self._logs`` keys, so extra columns are re-collected every step and stay
    aligned with prompt/completion rows.
    """
    columns = ['metadata_log']
    # Simulates a tracker state where `metadata_log` was already logged on an
    # earlier step (e.g. self._logs keys {'prompt', 'completion',
    # 'metadata_log'}); the current pass occupies nothing, so the column must
    # still be selected. The original test asserted membership in a local
    # literal set, which was vacuously true and tested nothing.
    result = rlhf_utils.select_log_completions_extra_columns(columns, occupied_columns=[])
    assert result == ['metadata_log']
Loading