diff --git a/swift/arguments/base_args/base_args.py b/swift/arguments/base_args/base_args.py index 7c0273080b..b41e7cca9b 100644 --- a/swift/arguments/base_args/base_args.py +++ b/swift/arguments/base_args/base_args.py @@ -303,6 +303,7 @@ def get_template(self, processor: Optional[Processor] = None, **kwargs) -> Templ template_type = self.template template_kwargs['template_type'] = template_type template = get_template(processor, **template_kwargs) + template.loss_type = getattr(self, 'loss_type', None) return template def get_model_processor(self, diff --git a/swift/megatron/trainers/reranker_trainer.py b/swift/megatron/trainers/reranker_trainer.py index 6607186960..162191194b 100644 --- a/swift/megatron/trainers/reranker_trainer.py +++ b/swift/megatron/trainers/reranker_trainer.py @@ -36,14 +36,14 @@ def _get_listwise_reranker_preds(logits, labels): labels = torch.tensor([0] * (len(positive_indices) - 1), device=preds.device) return preds, labels - def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params=None): + def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, group_sizes=None, packed_seq_params=None): training = self.unwrapped_models[0].training logits = self.get_last_tokens(output_tensor, packed_seq_params) loss = self._loss_func(ModelOutputs(logits=logits), labels) args = self.args logits_detach = logits.detach().squeeze(-1) if not training: - self.eval_metrics.update(logits_detach, labels) + self.eval_metrics.update(logits_detach, labels, group_sizes) if args.loss_type == 'listwise_reranker': preds, labels = self._get_listwise_reranker_preds(logits_detach, labels) else: @@ -64,7 +64,8 @@ def forward_step(self, data_iterator, model): vp_stage = model.module.module.vp_stage data = self.get_batch(data_iterator, vp_stage) labels = data.pop('labels', None) + group_sizes = data.pop('group_sizes', None) output_tensor = model(**data) packed_seq_params = data.get('packed_seq_params') - loss_func = partial(self.loss_func, labels=labels, packed_seq_params=packed_seq_params) + loss_func = partial(self.loss_func, labels=labels, group_sizes=group_sizes, packed_seq_params=packed_seq_params) return output_tensor, loss_func diff --git a/swift/metrics/reranker.py b/swift/metrics/reranker.py index 19d02753d4..ad05622f48 100644 --- a/swift/metrics/reranker.py +++ b/swift/metrics/reranker.py @@ -17,20 +17,82 @@ def __init__(self, *args, **kwargs): Metric.__init__(self) self.add_state('logits', default_factory=list) self.add_state('labels', default_factory=list) + self.add_state('group_sizes', default_factory=list) - def update(self, logits, labels): + def update(self, logits, labels, group_sizes=None): self.logits.append(logits.cpu().numpy()) self.labels.append(labels.cpu().numpy()) + if group_sizes is not None: + self.group_sizes.append(group_sizes.cpu().numpy()) def compute(self): predictions = np.concatenate(self.logits) labels = np.concatenate(self.labels) - return self._calculate_metrics(predictions, labels) + group_sizes = np.concatenate(self.group_sizes) if self.group_sizes else None + return self._calculate_metrics(predictions, labels, group_sizes) def compute_metrics(self, eval_prediction: EvalPrediction) -> Dict[str, float]: - return self._calculate_metrics(eval_prediction.predictions, eval_prediction.label_ids) + label_ids = eval_prediction.label_ids + group_sizes = None + if isinstance(label_ids, (tuple, list)): + labels = label_ids[0] + if len(label_ids) > 1: + group_sizes = label_ids[1] + else: + labels = label_ids + return self._calculate_metrics(eval_prediction.predictions, labels, group_sizes) + + @staticmethod + def _split_query_groups(logits, labels, group_sizes=None): + if group_sizes is not None: + group_sizes = np.array(group_sizes).astype(int).flatten() + total_size = int(group_sizes.sum()) + if total_size == len(labels): + query_groups = [] + start = 0 + for group_size in group_sizes: + if group_size <= 0: + continue + end = start + group_size + query_groups.append((logits[start:end], labels[start:end])) + start = end + return query_groups + logger.warning('The sum of group_sizes does not match the number of labels. Falling back to label-based ' + 'query boundary inference.') - def _calculate_metrics(self, logits, labels): + positive_indices = np.where(labels == 1)[0] + if len(positive_indices) == 0: + return [] + + query_groups = [] + for i, pos_idx in enumerate(positive_indices): + group_start = pos_idx + if i + 1 < len(positive_indices): + group_end = positive_indices[i + 1] + else: + group_end = len(labels) + query_groups.append((logits[group_start:group_end], labels[group_start:group_end])) + return query_groups + + @staticmethod + def _calculate_classification_metrics(logits, labels): + preds = (logits > 0).astype(int) + labels = labels.astype(int) + tp = np.sum((preds == 1) & (labels == 1)) + fp = np.sum((preds == 1) & (labels == 0)) + fn = np.sum((preds == 0) & (labels == 1)) + precision = tp / (tp + fp) if tp + fp > 0 else 0.0 + recall = tp / (tp + fn) if tp + fn > 0 else 0.0 + f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0 + acc = np.mean(preds == labels) if len(labels) > 0 else 0.0 + return { + 'acc': acc, + 'precision': precision, + 'recall': recall, + 'f1': f1, + } + + def _calculate_metrics(self, logits, labels, group_sizes=None): """ Calculate MRR and NDCG metrics for reranker. @@ -58,50 +120,36 @@ def _calculate_metrics(self, logits, labels): logits = np.array(logits).flatten() labels = np.array(labels).flatten() - # Step 1: Find all positive sample indices (query boundaries) - positive_indices = np.where(labels == 1)[0] - - if len(positive_indices) == 0: - return {'mrr': 0.0, 'ndcg': 0.0} - - # Step 2: Split into groups (queries) - query_groups = [] - for i, pos_idx in enumerate(positive_indices): - # Each group starts at a positive index - group_start = pos_idx - - # Group ends at the next positive index or end of data - if i + 1 < len(positive_indices): - group_end = positive_indices[i + 1] - else: - group_end = len(labels) + metrics = {} + if getattr(self.args, 'loss_type', None) == 'pointwise_reranker': + metrics.update(self._calculate_classification_metrics(logits, labels)) - # Extract this query's data - query_logits = logits[group_start:group_end] - query_labels = labels[group_start:group_end] - - query_groups.append((query_logits, query_labels)) + query_groups = self._split_query_groups(logits, labels, group_sizes) + metrics['query_count'] = float(len(query_groups)) # Step 3: Calculate metrics for each query independently mrr_scores = [] ndcg_scores = [] + negative_only_query_count = 0 + skipped_query_count = 0 for query_idx, (query_logits, query_labels) in enumerate(query_groups): - # Skip groups that are too small (need at least 1 positive + 1 negative) if len(query_logits) < 2: logger.info(f'Query {query_idx}: Skipped (too small: {len(query_logits)} items)') + skipped_query_count += 1 continue - # Verify that the first sample is positive (data format validation) - if query_labels[0] != 1: - logger.info(f'Query {query_idx}: Skipped (first sample not positive)') + if np.sum(query_labels == 1) == 0: + negative_only_query_count += 1 + skipped_query_count += 1 continue # Step 3a: Calculate ranking within this query ranking = np.argsort(-query_logits) # Sort by logits descending - # Step 3b: Find position of positive document (should be at index 0 in query) - pos_rank = np.where(ranking == 0)[0][0] + 1 # +1 for 1-based ranking + # Step 3b: Find the rank of the highest-ranked positive document. + positive_mask = query_labels[ranking] == 1 + pos_rank = np.where(positive_mask)[0][0] + 1 # +1 for 1-based ranking # Step 3c: Calculate MRR for this query mrr = 1.0 / pos_rank @@ -133,14 +181,19 @@ def calculate_ndcg_single_query(relevance_scores, ranking): ndcg_scores.append(ndcg) # Step 4: Calculate mean metrics across all valid queries + metrics['ranking_query_count'] = float(len(mrr_scores)) + metrics['negative_only_query_count'] = float(negative_only_query_count) + metrics['skipped_query_count'] = float(skipped_query_count) if len(mrr_scores) == 0: logger.warning('No valid queries found for metric calculation') - return {'mrr': 0.0, 'ndcg': 0.0} + metrics.update({'mrr': 0.0, 'ndcg': 0.0}) + return metrics mean_mrr = np.mean(mrr_scores) mean_ndcg = np.mean(ndcg_scores) - return { + metrics.update({ 'mrr': mean_mrr, 'ndcg': mean_ndcg, - } + }) + return metrics diff --git a/swift/template/base.py b/swift/template/base.py index 29fa65000d..78bbfabed9 100644 --- a/swift/template/base.py +++ b/swift/template/base.py @@ -1655,7 +1655,9 @@ def _reranker_data_collator(self, if self.is_training: max_positive_samples = int(os.environ.get('MAX_POSITIVE_SAMPLES', 1)) max_negative_samples = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7)) + pointwise_negative_only = getattr(self, 'loss_type', None) == 'pointwise_reranker' labels_list = [] + group_sizes = [] if pointwise_negative_only else None new_batch = [] for b in batch: labels = b.pop('labels', None) @@ -1663,22 +1665,40 @@ def _reranker_data_collator(self, negative_num = len(labels) - positive_num max_positive = min(positive_num, max_positive_samples) max_negative = min(negative_num, max_negative_samples) + if pointwise_negative_only and positive_num == 0: + # Pointwise BCE can train on all-negative samples, so keep them instead of dropping the row. + sampled_negative_indices = random.sample(range(negative_num), max_negative) + for j in sampled_negative_indices: + new_batch.append( + {key: b[key][j] + for key in b.keys() if isinstance(b[key], list) and b[key][j] is not None}) + labels_list.append(0) + if sampled_negative_indices and group_sizes is not None: + group_sizes.append(len(sampled_negative_indices)) + continue for i in random.sample(range(positive_num), max_positive): + group_size = 1 new_batch.append( {key: b[key][i] for key in b.keys() if isinstance(b[key], list) and b[key][i] is not None}) labels_list.append(1) - for j in random.sample(range(negative_num), max_negative): + sampled_negative_indices = random.sample(range(negative_num), max_negative) + for j in sampled_negative_indices: new_batch.append({ key: b[key][j + positive_num] for key in b.keys() if isinstance(b[key], list) and b[key][j + positive_num] is not None }) labels_list.append(0) + group_size += 1 + if group_sizes is not None: + group_sizes.append(group_size) num_samples = len(new_batch) res = self._data_collator(new_batch, padding_to=padding_to) res['num_samples'] = num_samples if labels_list: res['labels'] = torch.tensor(labels_list, dtype=torch.long) + if group_sizes: + res['group_sizes'] = torch.tensor(group_sizes, dtype=torch.long) else: res = self._data_collator(batch, padding_to=padding_to) return res diff --git a/swift/trainers/reranker_trainer.py b/swift/trainers/reranker_trainer.py index 06fe19b486..04ded32099 100644 --- a/swift/trainers/reranker_trainer.py +++ b/swift/trainers/reranker_trainer.py @@ -8,17 +8,28 @@ logger = get_logger() +def gather_for_reranker_metrics(input_data, use_gather_object=False): + if isinstance(input_data, tuple): + return tuple(gather_for_reranker_metrics(data, use_gather_object=use_gather_object) for data in input_data) + if isinstance(input_data, list): + return [gather_for_reranker_metrics(data, use_gather_object=use_gather_object) for data in input_data] + return gather_for_unpadded_tensors(input_data, use_gather_object=use_gather_object) + + class RerankerTrainer(Trainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.gather_function = gather_for_unpadded_tensors + self.gather_function = gather_for_reranker_metrics + if getattr(self.args, 'loss_type', None) == 'pointwise_reranker' and 'group_sizes' not in self.label_names: + self.label_names.append('group_sizes') def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # Check if we have a custom loss function if self.compute_loss_func is not None: # Get labels and compute outputs labels = inputs.pop('labels', None) + group_sizes = inputs.pop('group_sizes', None) outputs = model(**inputs) if self.task_type == 'generative_reranker': logits = outputs.logits @@ -46,5 +57,5 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N def evaluation_loop(self, *args, **kwargs): output = super().evaluation_loop(*args, **kwargs) - self.gather_function = gather_for_unpadded_tensors + self.gather_function = gather_for_reranker_metrics return output diff --git a/tests/train/test_reranker_collator.py b/tests/train/test_reranker_collator.py new file mode 100644 index 0000000000..1948c6fa70 --- /dev/null +++ b/tests/train/test_reranker_collator.py @@ -0,0 +1,74 @@ +import torch + +from swift.template.base import Template + + +def _build_template(loss_type): + template = Template.__new__(Template) + template.is_training = True + template.loss_type = loss_type + template._data_collator = lambda batch, padding_to=None: {'encoded_batch': batch} + return template + + +def test_pointwise_reranker_collator_supports_negative_only(): + template = _build_template('pointwise_reranker') + batch = [{ + 'input_ids': [[101], [102]], + 'attention_mask': [[1], [1]], + 'labels': [0, 0], + }] + + res = Template._reranker_data_collator(template, batch) + + assert res['num_samples'] == 2 + assert torch.equal(res['labels'], torch.tensor([0, 0], dtype=torch.long)) + assert torch.equal(res['group_sizes'], torch.tensor([2], dtype=torch.long)) + assert len(res['encoded_batch']) == 2 + + +def test_pointwise_reranker_collator_supports_positive_only(): + template = _build_template('pointwise_reranker') + batch = [{ + 'input_ids': [[201], [202]], + 'attention_mask': [[1], [1]], + 'labels': [1, 1], + }] + + res = Template._reranker_data_collator(template, batch) + + assert res['num_samples'] == 2 + assert torch.equal(res['labels'], torch.tensor([1, 1], dtype=torch.long)) + assert torch.equal(res['group_sizes'], torch.tensor([1, 1], dtype=torch.long)) + assert len(res['encoded_batch']) == 2 + + +def test_listwise_reranker_collator_still_skips_negative_only(): + template = _build_template('listwise_reranker') + batch = [{ + 'input_ids': [[301], [302]], + 'attention_mask': [[1], [1]], + 'labels': [0, 0], + }] + + res = Template._reranker_data_collator(template, batch) + + assert res['num_samples'] == 0 + assert 'labels' not in res + assert 'group_sizes' not in res + assert res['encoded_batch'] == [] + + +def test_reranker_collator_does_not_emit_group_sizes_without_custom_loss(): + template = _build_template(None) + batch = [{ + 'input_ids': [[401], [402]], + 'attention_mask': [[1], [1]], + 'labels': [1, 0], + }] + + res = Template._reranker_data_collator(template, batch) + + assert res['num_samples'] == 2 + assert torch.equal(res['labels'], torch.tensor([1, 0], dtype=torch.long)) + assert 'group_sizes' not in res diff --git a/tests/train/test_reranker_metrics.py b/tests/train/test_reranker_metrics.py new file mode 100644 index 0000000000..10d4b6dd7a --- /dev/null +++ b/tests/train/test_reranker_metrics.py @@ -0,0 +1,45 @@ +from types import SimpleNamespace + +from swift.metrics.reranker import RerankerMetrics + + +def _build_metrics(loss_type): + metrics = RerankerMetrics.__new__(RerankerMetrics) + metrics.args = SimpleNamespace(loss_type=loss_type) + metrics.trainer = None + return metrics + + +def test_pointwise_reranker_metrics_support_negative_only_queries(): + metrics = _build_metrics('pointwise_reranker') + + result = metrics._calculate_metrics( + logits=[-2.0, -1.0, 3.0, -0.5], + labels=[0, 0, 1, 0], + group_sizes=[2, 2], + ) + + assert result['acc'] == 1.0 + assert result['precision'] == 1.0 + assert result['recall'] == 1.0 + assert result['f1'] == 1.0 + assert result['query_count'] == 2.0 + assert result['ranking_query_count'] == 1.0 + assert result['negative_only_query_count'] == 1.0 + assert result['mrr'] == 1.0 + assert result['ndcg'] == 1.0 + + +def test_listwise_reranker_metrics_preserve_group_boundaries(): + metrics = _build_metrics('listwise_reranker') + + result = metrics._calculate_metrics( + logits=[1.0, 0.0, -1.0, 1.2, 1.0, 0.5], + labels=[1, 0, 0, 0, 1, 0], + group_sizes=[3, 3], + ) + + assert result['query_count'] == 2.0 + assert result['ranking_query_count'] == 2.0 + assert result['negative_only_query_count'] == 0.0 + assert result['mrr'] == 0.75 diff --git a/tests/train/test_reranker_trainer.py b/tests/train/test_reranker_trainer.py new file mode 100644 index 0000000000..d23bd3bc2f --- /dev/null +++ b/tests/train/test_reranker_trainer.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace + +import torch + +from swift.trainers.reranker_trainer import RerankerTrainer, gather_for_reranker_metrics +from swift.trainers.trainer import Trainer + + +def _fake_trainer_init(self, *args, **kwargs): + self.args = kwargs.get('args', SimpleNamespace(loss_type=None)) + self.label_names = ['labels'] + self.gather_function = None + + +def test_pointwise_reranker_adds_group_sizes_to_label_names(monkeypatch): + monkeypatch.setattr(Trainer, '__init__', _fake_trainer_init) + + trainer = RerankerTrainer(args=SimpleNamespace(loss_type='pointwise_reranker')) + + assert trainer.label_names == ['labels', 'group_sizes'] + assert trainer.gather_function is gather_for_reranker_metrics + + +def test_listwise_reranker_keeps_default_label_names(monkeypatch): + monkeypatch.setattr(Trainer, '__init__', _fake_trainer_init) + + trainer = RerankerTrainer(args=SimpleNamespace(loss_type='listwise_reranker')) + + assert trainer.label_names == ['labels'] + assert trainer.gather_function is gather_for_reranker_metrics + + +def test_gather_for_reranker_metrics_preserves_tuple_labels(): + labels = torch.tensor([1, 0, 0, 1], dtype=torch.long) + group_sizes = torch.tensor([2, 2], dtype=torch.long) + + gathered = gather_for_reranker_metrics((labels, group_sizes)) + + assert isinstance(gathered, tuple) + assert torch.equal(gathered[0], labels) + assert torch.equal(gathered[1], group_sizes) + + +def test_evaluation_loop_preserves_metric_prefix(monkeypatch): + monkeypatch.setattr(Trainer, '__init__', _fake_trainer_init) + + captured = {} + + def _fake_evaluation_loop(self, *args, **kwargs): + captured['metric_key_prefix'] = kwargs.get('metric_key_prefix') + self.gather_function = object() + return SimpleNamespace(metrics={'test_mrr': 0.75}) + + monkeypatch.setattr(Trainer, 'evaluation_loop', _fake_evaluation_loop) + + trainer = RerankerTrainer(args=SimpleNamespace(loss_type='pointwise_reranker')) + output = trainer.evaluation_loop(metric_key_prefix='test') + + assert captured['metric_key_prefix'] == 'test' + assert output.metrics == {'test_mrr': 0.75} + assert trainer.gather_function is gather_for_reranker_metrics