class InfoNCELoss(BaseLoss):
    """InfoNCE contrastive loss over anchor / positive / negative embeddings.

    Behaviour is configured through environment variables, read on every call:

    * ``INFONCE_TEMPERATURE`` (float, default ``0.1``): logits are divided by it.
    * ``INFONCE_USE_BATCH`` (bool, default ``True``): use documents of the other
      samples (and of other data-parallel ranks) as additional negatives.
    * ``INFONCE_HARD_NEGATIVES`` (int, optional): forwarded to
      ``_parse_multi_negative_sentences``; when set, samples are treated as
      having equal size and are processed in batched form.
    * ``INFONCE_MASK_FAKE_NEGATIVE`` (bool, default ``False``): replace logits
      larger than ``positive_score + INFONCE_FAKE_NEG_MARGIN`` with ``-inf``.
    * ``INFONCE_FAKE_NEG_MARGIN`` (float, default ``0.1``).
    * ``INFONCE_INCLUDE_QQ`` / ``INFONCE_INCLUDE_DD`` (bool, default ``False``):
      append query-query / positive-doc-to-doc similarities to the logits.
    """

    def _parse_config(self):
        """Read the InfoNCE settings from the environment.

        Returns:
            dict with keys ``temperature``, ``use_batch``, ``hard_negatives``,
            ``mask_fake_negative``, ``fake_neg_margin``, ``include_qq`` and
            ``include_dd``.
        """
        hard_negatives = os.environ.get('INFONCE_HARD_NEGATIVES', None)
        if hard_negatives is not None:
            hard_negatives = int(hard_negatives)
        return {
            'temperature': float(os.environ.get('INFONCE_TEMPERATURE', '0.1')),
            # calculate CE across the batch: every other sample acts as a negative
            'use_batch': strtobool(os.environ.get('INFONCE_USE_BATCH', 'True')),
            # how many negative prompts are kept in one sample
            'hard_negatives': hard_negatives,
            # mask out fake negatives
            'mask_fake_negative': strtobool(os.environ.get('INFONCE_MASK_FAKE_NEGATIVE', 'False')),
            'fake_neg_margin': float(os.environ.get('INFONCE_FAKE_NEG_MARGIN', '0.1')),
            # enhanced components to align with Qwen3-Embedding denominator; controlled individually
            # defaults set to False for backward compatibility
            'include_qq': strtobool(os.environ.get('INFONCE_INCLUDE_QQ', 'False')),
            'include_dd': strtobool(os.environ.get('INFONCE_INCLUDE_DD', 'False')),
        }

    def _gather_distributed(self, sentences, labels, rank, world_size):
        """Gather sentences and labels across ranks for cross-batch negatives.

        Args:
            sentences: Local embeddings; concatenated along dim 0 after gathering.
            labels: Local label tensor (assumed 1-D, one entry per non-anchor
                sentence -- TODO confirm against the data collator).
            rank: Index of the local entry in the gathered list.
            world_size: Number of participating ranks (used by the Megatron path).

        Returns:
            Tuple ``(sentences, labels, rank)``: the concatenated embeddings, the
            stacked labels, and the rank actually used (replaced by
            ``sequence_parallel.dp_rank`` when a sequence-parallel dp group is set).
        """
        if getattr(sequence_parallel, 'dp_group', None) is not None:
            all_sentences = sequence_parallel._gather_object_dp(sentences.unsqueeze(0))
            labels = sequence_parallel._gather_object_dp(labels)
            rank = sequence_parallel.dp_rank
        elif self.is_megatron:
            from megatron.core import mpu
            dp_group = mpu.get_data_parallel_group()

            def _all_gather_uneven(tensor):
                # Exchange shapes first: per-rank tensors may differ in size.
                shapes = [tensor.new_empty((tensor.dim(), ), dtype=torch.long) for _ in range(world_size)]
                dist.all_gather(shapes, tensor.new_tensor(tensor.shape, dtype=torch.long), group=dp_group)
                buffers = [tensor.new_empty(shape.tolist()) for shape in shapes]
                dist.all_gather(buffers, tensor, group=dp_group)
                return buffers

            all_sentences = _all_gather_uneven(sentences)
            # Flatten to per-element entries so the shared stacking below yields
            # the same flat layout as the ``gather_object`` path.
            labels = [label for chunk in _all_gather_uneven(labels) for label in chunk]
        else:
            # gather all the sentences and labels across the gpus when calculate loss across all batches of all gpus
            all_sentences = gather_object(sentences.unsqueeze(0))
            labels = gather_object(labels)

        device = sentences.device
        # Keep the local tensor (with its autograd graph) at position ``rank``;
        # entries from other ranks are detached, we don't calculate grad from them.
        all_sentences[rank] = sentences
        all_sentences = [
            gathered if idx == rank else gathered.detach().to(device) for idx, gathered in enumerate(all_sentences)
        ]
        sentences = torch.cat(all_sentences, dim=0)
        labels = torch.stack([tensor.to(device) for tensor in labels], dim=0)
        return sentences, labels, rank

    def _compute_local_loss(self, split_tensors, can_batched, temperature):
        """Compute the loss using only within-sample negatives (no cross-batch).

        Args:
            split_tensors: One tensor per sample; row 0 is the anchor, row 1 the
                positive and the remaining rows the negatives.
            can_batched: True when all samples have the same number of rows.
            temperature: Divisor applied to the similarities.

        Returns:
            Cross-entropy loss averaged over the samples.
        """
        criterion = nn.CrossEntropyLoss()
        if can_batched:
            # negative numbers are equal
            sentences = torch.stack(split_tensors, dim=0)  # [B, neg+2, D]
            # [B, 1, D] * [B, D, neg+1] -> [B, 1, neg+1]
            similarity = torch.matmul(sentences[:, 0:1], sentences[:, 1:].transpose(1, 2)) / temperature
            # The positive one is the first element
            labels = torch.zeros(len(split_tensors), dtype=torch.int64, device=sentences.device)
            return criterion(similarity.squeeze(1), labels)
        # the negative numbers may be different, use for loop
        loss = 0
        for tensor in split_tensors:
            # [D] * [neg+1, D]
            similarity = torch.matmul(tensor[0], tensor[1:].T) / temperature
            target = torch.tensor(0, device=tensor.device)
            loss += criterion(similarity, target)
        # average over the samples
        return loss / len(split_tensors)

    @staticmethod
    def _mask_fake_negatives(logits, threshold):
        """Return a copy of ``logits`` with entries above ``threshold`` set to -inf."""
        return logits.masked_fill(logits > threshold, float('-inf'))

    def _compute_cross_batch_loss_batched(self, split_tensors, config):
        """Compute the cross-batch loss when all samples have equally many negatives.

        Args:
            split_tensors: Equally shaped per-sample tensors (anchor, positive,
                negatives along dim 0).
            config: Dict produced by ``_parse_config``.

        Returns:
            Cross-entropy loss over the concatenated q->d [, q->q] [, d+->d] logits.
        """
        sentences = torch.stack(split_tensors, dim=0)  # [B, neg+2, D]
        # base q->d similarities (includes own positive and all in-batch documents)
        queries = sentences[:, 0]  # [B, D]
        docs_all = sentences[:, 1:].reshape(-1, sentences.size(2))  # [B*(neg+1), D]
        block_size = sentences.size(1) - 1  # neg + 1

        qd_matrix = torch.matmul(queries, docs_all.T)  # [B, B*(neg+1)]
        # target indices: start of each group's document block (its positive)
        labels = torch.arange(0, sentences.size(0) * block_size, block_size, device=sentences.device)

        # Optional q->q similarity; exclude self via -inf on the diagonal
        qq_matrix = None
        if config['include_qq']:
            qq_matrix = torch.matmul(queries, queries.T).clone()  # [B, B]
            qq_matrix.fill_diagonal_(float('-inf'))

        # Optional d+->d (doc-doc) similarity; exclude self-positive column per row
        dd_matrix = None
        if config['include_dd']:
            pos_docs = sentences[:, 1]  # [B, D]
            dd_matrix = torch.matmul(pos_docs, docs_all.T)  # [B, B*(neg+1)]
            if block_size > 0:
                row_idx = torch.arange(dd_matrix.size(0), device=dd_matrix.device)
                dd_matrix[row_idx, row_idx * block_size] = float('-inf')

        if config['mask_fake_negative']:
            # thresholds derived from the (detached) positive q->d score of each row
            row_idx = torch.arange(qd_matrix.size(0), device=qd_matrix.device)
            thresholds = qd_matrix[row_idx, labels].view(-1, 1).detach() + config['fake_neg_margin']
            qd_matrix = self._mask_fake_negatives(qd_matrix, thresholds)
            if qq_matrix is not None:
                # diagonal already masked unconditionally at construction time
                qq_matrix = self._mask_fake_negatives(qq_matrix, thresholds)
            # align with Qwen3-Embedding, no threshold masking for d-d

        # q->d must stay first: ``labels`` index into its columns.
        components = [part for part in (qd_matrix, qq_matrix, dd_matrix) if part is not None]
        similarity_matrix = torch.cat(components, dim=1) / config['temperature']
        return nn.CrossEntropyLoss()(similarity_matrix, labels)

    def _compute_cross_batch_loss_unbatched(self, split_tensors, config):
        """Compute the cross-batch loss when samples have varying numbers of negatives.

        Args:
            split_tensors: Per-sample tensors (anchor, positive, negatives along
                dim 0); the number of rows may differ between samples.
            config: Dict produced by ``_parse_config``.

        Returns:
            Cross-entropy loss averaged over the samples.
        """
        temperature = config['temperature']
        mask_fake_negative = config['mask_fake_negative']
        # Concatenate all documents (positive + negatives) across samples
        all_docs = torch.cat([t[1:] for t in split_tensors], dim=0)  # [total_docs, D]

        queries_all = None
        if config['include_qq']:
            queries_all = torch.stack([t[0] for t in split_tensors], dim=0)  # [B, D]

        criterion = nn.CrossEntropyLoss()
        loss = 0
        offset = 0  # position of the current sample's positive in all_docs
        for idx, tensor in enumerate(split_tensors):
            query = tensor[0]  # [D]
            target = torch.tensor(offset, device=tensor.device)

            # q->d similarity
            qd_vec = torch.matmul(query, all_docs.T)  # [total_docs]
            threshold = None
            if mask_fake_negative:
                # threshold from the (detached) positive score, taken before masking
                threshold = qd_vec[target].detach() + config['fake_neg_margin']
                qd_vec = self._mask_fake_negatives(qd_vec, threshold)
            logits_parts = [qd_vec]

            # Optional q->q
            if queries_all is not None:
                qq_vec = torch.matmul(query, queries_all.T).clone()  # [B]
                qq_vec[idx] = float('-inf')  # exclude self
                if mask_fake_negative:
                    qq_vec = self._mask_fake_negatives(qq_vec, threshold)
                logits_parts.append(qq_vec)

            # Optional d+->d (no threshold masking for d-d)
            if config['include_dd']:
                dd_vec = torch.matmul(tensor[1], all_docs.T)  # [total_docs]
                dd_vec[offset] = float('-inf')  # mask self-positive
                logits_parts.append(dd_vec)

            logits_row = torch.cat(logits_parts, dim=-1) / temperature
            loss += criterion(logits_row.unsqueeze(0), target.unsqueeze(0))
            # the next sample's positive follows this sample's neg+1 documents
            offset += tensor.size(0) - 1

        return loss / len(split_tensors)

    def __call__(self, outputs, labels, **kwargs) -> torch.Tensor:
        """Compute the InfoNCE loss.

        Args:
            outputs: Mapping whose ``'last_hidden_state'`` entry holds the
                embeddings, laid out as anchor(1)+positive(1)+negatives(n) per sample.
            labels: Labels used by ``_parse_multi_negative_sentences`` to split the
                embeddings into samples.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The scalar loss tensor.
        """
        config = self._parse_config()

        if self.is_megatron:
            from megatron.core import mpu
            rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        else:
            rank, _, world_size, _ = get_dist_setting()

        # repeat of anchor(1)+positive(1)+negatives(n)
        sentences = outputs['last_hidden_state']

        if world_size > 1 and config['use_batch']:
            sentences, labels, rank = self._gather_distributed(sentences, labels, rank, world_size)

        # split tensors into single sample
        # for example: batch_size=2 with tensor anchor(1)+positive(1)+negatives(3) + anchor(1)+positive(1)+negatives(2)
        # labels will be [1,0,0,0,1,0,0], meaning 1 positive, 3 negatives, 1 positive, 2 negatives
        split_tensors = _parse_multi_negative_sentences(sentences, labels, config['hard_negatives'])

        # Samples can be stacked when every sample has the same number of rows
        can_batched = (config['hard_negatives'] is not None or len({s.shape[0] for s in split_tensors}) == 1)

        if not config['use_batch']:
            # only calculate loss inside one sample
            return self._compute_local_loss(split_tensors, can_batched, config['temperature'])
        if can_batched:
            return self._compute_cross_batch_loss_batched(split_tensors, config)
        return self._compute_cross_batch_loss_unbatched(split_tensors, config)


# Backward-compatible alias for code that still imports the pre-rename class name.
InfonceLoss = InfoNCELoss
from .causal_lm import CustomCrossEntropyLoss -from .embedding import ContrastiveLoss, CosineSimilarityLoss, InfonceLoss, OnlineContrastiveLoss +from .embedding import ContrastiveLoss, CosineSimilarityLoss, InfoNCELoss, OnlineContrastiveLoss from .reranker import ListwiseRerankerLoss, PointwiseRerankerLoss loss_map = { @@ -9,7 +9,7 @@ 'cosine_similarity': CosineSimilarityLoss, 'contrastive': ContrastiveLoss, 'online_contrastive': OnlineContrastiveLoss, - 'infonce': InfonceLoss, + 'infonce': InfoNCELoss, # # reranker 'pointwise_reranker': PointwiseRerankerLoss, 'listwise_reranker': ListwiseRerankerLoss,