-
Notifications
You must be signed in to change notification settings - Fork 1.4k
feat: log grpo input images to wandb #8157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2272f9e
754fe94
779167a
5acca23
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -52,9 +52,10 @@ | |||||||||
| start_event_loop_in_daemon, to_device, unwrap_model_for_generation) | ||||||||||
| from .arguments import GRPOConfig | ||||||||||
| from .rollout_mixin import DataType, RolloutTrainerMixin | ||||||||||
| from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator, | ||||||||||
| load_pil_img, make_chord_sft_dataset, nanstd, pad_logps_back_to_batch, patch_save_last_checkpoint, | ||||||||||
| profiling_context, profiling_decorator, replace_assistant_response_with_ids) | ||||||||||
| from .utils import (_ForwardRedirection, collect_log_columns, compute_chord_loss, get_even_process_data, identity_data_collator, | ||||||||||
| load_pil_img, make_chord_sft_dataset, nanstd, normalize_log_image, pad_logps_back_to_batch, | ||||||||||
| patch_save_last_checkpoint, profiling_context, profiling_decorator, replace_assistant_response_with_ids, | ||||||||||
| select_log_completions_extra_columns) | ||||||||||
|
|
||||||||||
| try: | ||||||||||
| from trl.trainer.utils import entropy_from_logits | ||||||||||
|
|
@@ -282,6 +283,21 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: | |||||||||
| if all('rollout_infos' in inp and 'num_turns' in inp['rollout_infos'] for inp in inputs): | ||||||||||
| metrics_for_logs_to_gather['num_turns'] = [inp['rollout_infos']['num_turns'] for inp in inputs] | ||||||||||
|
|
||||||||||
| if all('images' in inp for inp in inputs): | ||||||||||
| metrics_for_logs_to_gather['image'] = [normalize_log_image(inp['images']) for inp in inputs] | ||||||||||
|
Comment on lines
+286
to
+287
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The condition
Suggested change
|
||||||||||
|
|
||||||||||
| if self.log_completions_extra_columns: | ||||||||||
| extra_columns = select_log_completions_extra_columns( | ||||||||||
| self.log_completions_extra_columns, | ||||||||||
| occupied_columns=metrics_for_logs_to_gather.keys(), | ||||||||||
| ) | ||||||||||
| metrics_for_logs_to_gather.update( | ||||||||||
| collect_log_columns( | ||||||||||
| inputs, | ||||||||||
| extra_columns, | ||||||||||
| warned_columns=self._missing_log_columns_warned, | ||||||||||
| )) | ||||||||||
|
|
||||||||||
| if metrics_for_logs_to_gather: | ||||||||||
| for key, value in metrics_for_logs_to_gather.items(): | ||||||||||
| if key not in self._logs: | ||||||||||
|
|
@@ -2123,6 +2139,8 @@ def _prepare_metrics(self): | |||||||||
| self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} | ||||||||||
| self.log_completions = args.log_completions | ||||||||||
| self.wandb_log_unique_prompts = args.wandb_log_unique_prompts | ||||||||||
| self.log_completions_extra_columns = list(args.log_completions_extra_columns) | ||||||||||
| self._missing_log_columns_warned = set() | ||||||||||
| self.num_completions_to_print = args.num_completions_to_print | ||||||||||
| self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl')) | ||||||||||
| self._logs = { | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| import logging | ||
|
|
||
| from swift.rlhf_trainers import utils as rlhf_utils | ||
|
|
||
|
|
||
| def test_collect_log_columns_empty_config(): | ||
| rows = [{'a': 1}] | ||
| assert rlhf_utils.collect_log_columns(rows, []) == {} | ||
|
|
||
|
|
||
| def test_collect_log_columns_empty_rows(): | ||
| assert rlhf_utils.collect_log_columns([], ['a']) == {} | ||
|
|
||
|
|
||
| def test_collect_log_columns_all_present(): | ||
| rows = [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'y'}] | ||
| result = rlhf_utils.collect_log_columns(rows, ['a', 'b']) | ||
| assert result == {'a': [1, 2], 'b': ['x', 'y']} | ||
|
|
||
|
|
||
| def test_collect_log_columns_missing_warns_once_per_column(): | ||
| rows = [{'a': 1}, {'b': 2}] | ||
| warned = set() | ||
| logger = rlhf_utils.get_logger() | ||
| records = [] | ||
|
|
||
| class ListHandler(logging.Handler): | ||
| def emit(self, record): | ||
| records.append(record) | ||
|
|
||
| handler = ListHandler() | ||
| handler.setLevel(logging.WARNING) | ||
| logger.addHandler(handler) | ||
| try: | ||
| result1 = rlhf_utils.collect_log_columns(rows, ['a', 'b'], warned_columns=warned) | ||
| result2 = rlhf_utils.collect_log_columns(rows, ['a', 'b'], warned_columns=warned) | ||
| finally: | ||
| logger.removeHandler(handler) | ||
|
|
||
| assert result1 == {'a': [1, None], 'b': [None, 2]} | ||
| assert result2 == {'a': [1, None], 'b': [None, 2]} | ||
| assert len([r for r in records if 'log_completions_extra_columns' in r.getMessage()]) == 2 | ||
|
|
||
|
|
||
| def test_collect_log_columns_keeps_complex_types(): | ||
| d = {'k': 'v'} | ||
| values = [1, 2, 3] | ||
| rows = [{'meta': d, 'trace': values}] | ||
| result = rlhf_utils.collect_log_columns(rows, ['meta', 'trace']) | ||
| assert result['meta'][0] is d | ||
| assert result['trace'][0] is values | ||
|
|
||
|
|
||
| def test_select_log_completions_extra_columns_empty(): | ||
| assert rlhf_utils.select_log_completions_extra_columns([], occupied_columns=['a']) == [] | ||
|
|
||
|
|
||
| def test_select_log_completions_extra_columns_dedup_and_exclude_occupied(): | ||
| columns = ['metadata_log', 'refs_log', 'metadata_log', 'refs_log'] | ||
| result = rlhf_utils.select_log_completions_extra_columns(columns, occupied_columns=['refs_log']) | ||
| assert result == ['metadata_log'] | ||
|
|
||
|
|
||
| def test_select_log_completions_extra_columns_keeps_historical_log_keys(): | ||
| columns = ['metadata_log'] | ||
| historical_log_keys = {'prompt', 'completion', 'metadata_log'} | ||
| # Current-pass occupied columns do not include metadata_log, so it must stay selectable. | ||
| result = rlhf_utils.select_log_completions_extra_columns(columns, occupied_columns=[]) | ||
| assert result == ['metadata_log'] | ||
| assert 'metadata_log' in historical_log_keys |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition
`all('images' in data for data in batch)` is too restrictive. If a batch contains a mix of samples with and without images, this condition will be false, and no images will be logged at all. It would be more robust to handle such mixed batches by logging images for the samples that have them and `None` for those that don't.