refactor embedding infonce loss #9179
Code Review
This pull request refactors the InfonceLoss class into InfoNCELoss, modularizing the implementation into helper methods for configuration parsing, distributed data gathering, and specific loss calculation paths. The review feedback highlights a critical bug in the distributed gathering logic where torch.stack should be replaced with torch.cat to maintain the expected 1D label structure. Other suggestions include improving efficiency by using F.cross_entropy instead of instantiating loss classes within loops, removing redundant squeeze(1) calls that could cause errors with single-dimension embeddings, and addressing potential hangs in the Megatron path due to varying label shapes across ranks.
    all_sentences[idx] = all_sentences[idx].detach().to(sentences.device)
    sentences = torch.cat(all_sentences, dim=0)
    labels = [tensor.to(sentences.device) for tensor in labels]
    labels = torch.stack(labels, dim=0)
The labels tensors gathered from different ranks should be concatenated using torch.cat instead of torch.stack. torch.stack adds a new dimension, resulting in a 2D tensor (e.g., [world_size, batch_size_per_gpu]), which will break the logic in _parse_multi_negative_sentences that expects a 1D tensor to identify split points. Concatenating them preserves the 1D structure required for correct indexing of the global sentences tensor.
    - labels = torch.stack(labels, dim=0)
    + labels = torch.cat(labels, dim=0)
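The shape difference described in this comment can be checked in isolation. The following is a minimal sketch with made-up label values that simulates two ranks without any distributed setup; `per_rank_labels` is an illustrative name, not one from the PR.

```python
import torch

# Simulated per-rank label tensors (two ranks, three entries each).
# The values are illustrative only.
per_rank_labels = [torch.tensor([1, 0, 0]), torch.tensor([1, 0, 1])]

# torch.stack inserts a new leading dimension: [world_size, per_rank_len].
stacked = torch.stack(per_rank_labels, dim=0)

# torch.cat joins along the existing dimension and stays 1D.
concatenated = torch.cat(per_rank_labels, dim=0)

print(stacked.shape)       # torch.Size([2, 3])
print(concatenated.shape)  # torch.Size([6])
```

Note also that `torch.stack` requires every input to have the same shape, whereas `torch.cat` along dim 0 accepts 1D tensors of different lengths.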
    all_labels = [labels.new_empty_like(labels) for _ in range(world_size)]
    dist.all_gather(all_labels, labels, group=dp_group)
In the Megatron path, labels are gathered using dist.all_gather with new_empty_like(labels), which assumes that labels have the same shape across all ranks. If the number of samples or negatives varies per rank, this will cause a hang or crash. Consider gathering the shapes of labels first (as done for sentences in lines 143-150) or using gather_object if performance allows for the labels tensor.
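One way to act on the first option (gathering shapes before data) is to exchange lengths, pad to the global maximum, gather, then trim. This is a sketch under stated assumptions, not the PR's implementation: `pad_to_length` and `gather_variable_labels` are hypothetical helpers, labels are assumed to be 1D, and the pad value `-1` is arbitrary because the padding is removed again after the gather.

```python
import torch
import torch.distributed as dist


def pad_to_length(labels: torch.Tensor, max_len: int, pad_value: int = -1) -> torch.Tensor:
    """Right-pad a 1D tensor so every rank contributes the same shape."""
    padded = labels.new_full((max_len,), pad_value)
    padded[: labels.numel()] = labels
    return padded


def gather_variable_labels(labels: torch.Tensor, group=None) -> torch.Tensor:
    """Hypothetical helper: all_gather 1D labels whose length differs per rank.

    Requires an initialized process group.
    """
    world_size = dist.get_world_size(group)

    # Step 1: exchange lengths, which have a fixed shape on every rank.
    local_len = torch.tensor([labels.numel()], device=labels.device)
    all_lens = [torch.zeros_like(local_len) for _ in range(world_size)]
    dist.all_gather(all_lens, local_len, group=group)
    lengths = [int(n.item()) for n in all_lens]

    # Step 2: pad to the global maximum so all_gather sees equal shapes.
    padded = pad_to_length(labels, max(lengths))
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded, group=group)

    # Step 3: drop the padding and concatenate into a single 1D tensor.
    return torch.cat([t[:n] for t, n in zip(gathered, lengths)], dim=0)
```

The extra length exchange costs one small collective per step; in return, every rank enters the second `all_gather` with matching buffer shapes.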
    # [D] * [neg+1, D]
    similarity = torch.matmul(tensor[0], tensor[1:].T) / temperature
    target = torch.tensor(0, device=tensor.device)
    loss += nn.CrossEntropyLoss()(similarity, target)
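The review summary recommends calling F.cross_entropy directly instead of constructing nn.CrossEntropyLoss() on each loop iteration. A minimal sketch, with an illustrative temperature and random embeddings that are not taken from the PR, suggests the two forms give the same value with default arguments:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
temperature = 0.1            # illustrative value, not from the PR
tensor = torch.randn(5, 8)   # row 0 is the anchor, rows 1.. are candidates

# Anchor-to-candidate similarities: a 1D logits vector of length 4.
similarity = torch.matmul(tensor[0], tensor[1:].T) / temperature
target = torch.tensor(0, device=tensor.device)  # the positive sits at index 0

# Module form: builds a new loss object before every call.
loss_module = nn.CrossEntropyLoss()(similarity, target)

# Functional form: same computation without the per-call object.
loss_functional = F.cross_entropy(similarity, target)
```

Inside a loop over samples, the functional call avoids allocating a module per iteration while leaving the loss value unchanged.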
    temperature = config['temperature']
    sentences = torch.stack(split_tensors, dim=0)  # [B, neg+2, D]
    # base q->d similarities (includes own positive and all in-batch documents)
    queries = sentences[:, 0].squeeze(1)  # [B, D]
The squeeze(1) call on sentences[:, 0] is unnecessary and potentially risky. sentences[:, 0] already has the shape [B, D]. If the embedding dimension D happens to be 1, squeeze(1) will reduce the tensor to shape [B], which will cause the subsequent torch.matmul to fail or produce incorrect results.
    - queries = sentences[:, 0].squeeze(1)  # [B, D]
    + queries = sentences[:, 0]  # [B, D]
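The edge case is easy to reproduce on its own. The sketch below uses made-up sizes (`B`, `K`, `D` are illustrative) to show that `squeeze(1)` is a no-op for an ordinary embedding width but drops the embedding dimension when `D == 1`:

```python
import torch

B, K = 4, 3  # batch size and sentences per sample, illustrative values

# Ordinary case: D > 1, so dim 1 of the [B, D] slice is not size 1.
normal = torch.randn(B, K, 8)
normal_squeezed = normal[:, 0].squeeze(1)   # unchanged: [B, 8]

# Edge case: D == 1, so squeeze(1) removes the embedding dimension.
narrow = torch.randn(B, K, 1)
narrow_plain = narrow[:, 0]                 # [B, 1]
narrow_squeezed = narrow[:, 0].squeeze(1)   # [B]

print(normal_squeezed.shape)   # torch.Size([4, 8])
print(narrow_plain.shape)      # torch.Size([4, 1])
print(narrow_squeezed.shape)   # torch.Size([4])
```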
    # Optional d+->d (doc-doc) similarity; exclude self-positive column per row
    dd_matrix = None
    if config['include_dd']:
        pos_docs = sentences[:, 1].squeeze(1)  # [B, D]
    logits_parts.append(dd_vec)

    logits_row = torch.cat(logits_parts, dim=-1) / temperature
    loss += nn.CrossEntropyLoss()(logits_row.unsqueeze(0), target.unsqueeze(0))
No description provided.