refactor embedding infonce loss #9179

Open
Jintao-Huang wants to merge 1 commit into modelscope:main from Jintao-Huang:fix_refactor_infonce_loss

Conversation

@Jintao-Huang
Collaborator

No description provided.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the InfonceLoss class into InfoNCELoss, modularizing the implementation into helper methods for configuration parsing, distributed data gathering, and specific loss calculation paths. The review feedback highlights a critical bug in the distributed gathering logic where torch.stack should be replaced with torch.cat to maintain the expected 1D label structure. Other suggestions include improving efficiency by using F.cross_entropy instead of instantiating loss classes within loops, removing redundant squeeze(1) calls that could cause errors with single-dimension embeddings, and addressing potential hangs in the Megatron path due to varying label shapes across ranks.

Comment thread swift/loss/embedding.py
all_sentences[idx] = all_sentences[idx].detach().to(sentences.device)
sentences = torch.cat(all_sentences, dim=0)
labels = [tensor.to(sentences.device) for tensor in labels]
labels = torch.stack(labels, dim=0)
Contributor

high

The labels tensors gathered from different ranks should be concatenated using torch.cat instead of torch.stack. torch.stack adds a new dimension, resulting in a 2D tensor (e.g., [world_size, batch_size_per_gpu]), which will break the logic in _parse_multi_negative_sentences that expects a 1D tensor to identify split points. Concatenating them preserves the 1D structure required for correct indexing of the global sentences tensor.

Suggested change
- labels = torch.stack(labels, dim=0)
+ labels = torch.cat(labels, dim=0)
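
To illustrate the shape difference described above, here is a minimal single-process sketch. The per-rank label values are hypothetical stand-ins for the gathered list, not data from this PR; the second half also shows that torch.stack rejects inputs of unequal length, while torch.cat accepts them.

```python
import torch

# Hypothetical per-rank label tensors, standing in for the gathered list.
labels_per_rank = [torch.tensor([0, 3, 6]), torch.tensor([0, 2, 4])]

stacked = torch.stack(labels_per_rank, dim=0)     # adds a leading rank dimension
concatenated = torch.cat(labels_per_rank, dim=0)  # stays 1D

print(stacked.shape)       # torch.Size([2, 3])
print(concatenated.shape)  # torch.Size([6])

# If ranks contribute different numbers of labels, stack cannot be used at all.
ragged = [torch.tensor([0, 3, 6]), torch.tensor([0, 2])]
try:
    torch.stack(ragged, dim=0)
    stack_failed = False
except RuntimeError:
    stack_failed = True

ragged_cat = torch.cat(ragged, dim=0)  # works: 1D tensor with 5 elements
```
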

Comment thread swift/loss/embedding.py
Comment on lines +152 to +153
all_labels = [labels.new_empty_like(labels) for _ in range(world_size)]
dist.all_gather(all_labels, labels, group=dp_group)
Contributor

medium

In the Megatron path, labels are gathered using dist.all_gather with new_empty_like(labels), which assumes that labels have the same shape across all ranks. If the number of samples or negatives varies per rank, this will cause a hang or crash. Consider gathering the shapes of labels first (as done for sentences in lines 143-150) or using gather_object if performance allows for the labels tensor.
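
One possible way to gather the shapes first is sketched below. This is not the PR's code: the helper name all_gather_variable_length is hypothetical, it assumes labels is a 1D tensor, and it falls back to returning the input when no process group is initialized. It exchanges lengths, pads every rank's tensor to the common maximum, gathers, and trims the padding.

```python
import torch
import torch.distributed as dist


def all_gather_variable_length(labels: torch.Tensor, group=None) -> torch.Tensor:
    """Gather 1D tensors whose length may differ across ranks (hypothetical helper)."""
    if not (dist.is_available() and dist.is_initialized()):
        return labels  # single-process fallback: nothing to gather

    world_size = dist.get_world_size(group=group)

    # Step 1: exchange lengths so every rank can allocate matching buffers.
    local_len = torch.tensor([labels.numel()], device=labels.device, dtype=torch.long)
    all_lens = [torch.zeros_like(local_len) for _ in range(world_size)]
    dist.all_gather(all_lens, local_len, group=group)
    lengths = [int(n.item()) for n in all_lens]

    # Step 2: pad to the common maximum, gather, then trim the padding.
    max_len = max(lengths)
    padded = labels.new_zeros(max_len)
    padded[:labels.numel()] = labels
    buffers = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(buffers, padded, group=group)
    return torch.cat([buf[:n] for buf, n in zip(buffers, lengths)], dim=0)
```

Because every rank calls all_gather with identically shaped buffers, the collective no longer depends on the per-rank label count. The trade-off is one extra small collective for the lengths.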

Comment thread swift/loss/embedding.py
# [D] * [neg+1, D]
similarity = torch.matmul(tensor[0], tensor[1:].T) / temperature
target = torch.tensor(0, device=tensor.device)
loss += nn.CrossEntropyLoss()(similarity, target)
Contributor

medium

Instantiating nn.CrossEntropyLoss() inside a loop is inefficient. It is better to use the functional interface F.cross_entropy which is already imported.

Suggested change
- loss += nn.CrossEntropyLoss()(similarity, target)
+ loss += F.cross_entropy(similarity.unsqueeze(0), target.unsqueeze(0))
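
As a quick sanity check that the suggested form computes the same value, the sketch below compares the module call with the functional call on random logits. The logits are made up for illustration, and the unbatched call assumes a PyTorch version that accepts a 1D input with a scalar target, as the original line already does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
similarity = torch.randn(5)  # [neg+1] logits for a single query (illustrative values)
target = torch.tensor(0, device=similarity.device)  # positive sits at index 0

# Original form: a new module object is built on every call.
module_loss = nn.CrossEntropyLoss()(similarity, target)

# Suggested form: functional call on a batch of size 1.
functional_loss = F.cross_entropy(similarity.unsqueeze(0), target.unsqueeze(0))

# The functional call also accepts the unbatched tensors directly.
unbatched_loss = F.cross_entropy(similarity, target)

print(torch.allclose(module_loss, functional_loss))
print(torch.allclose(module_loss, unbatched_loss))
```
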

Comment thread swift/loss/embedding.py
temperature = config['temperature']
sentences = torch.stack(split_tensors, dim=0) # [B, neg+2, D]
# base q->d similarities (includes own positive and all in-batch documents)
queries = sentences[:, 0].squeeze(1) # [B, D]
Contributor

medium

The squeeze(1) call on sentences[:, 0] is unnecessary and potentially risky. sentences[:, 0] already has the shape [B, D]. If the embedding dimension D happens to be 1, squeeze(1) will reduce the tensor to shape [B], which will cause the subsequent torch.matmul to fail or produce incorrect results.

Suggested change
- queries = sentences[:, 0].squeeze(1) # [B, D]
+ queries = sentences[:, 0] # [B, D]
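
The edge case can be reproduced with a small sketch. The batch size and sentence count below are arbitrary illustrative values; only the embedding dimension matters.

```python
import torch

B, K = 4, 3  # illustrative: K plays the role of neg+2 sentences per sample

# Typical case: D > 1, so squeeze(1) is a no-op.
sentences = torch.randn(B, K, 8)
normal_plain = sentences[:, 0]               # [4, 8]
normal_squeezed = sentences[:, 0].squeeze(1) # still [4, 8]

# Edge case: D == 1, so squeeze(1) silently drops the embedding dimension.
sentences_d1 = torch.randn(B, K, 1)
edge_plain = sentences_d1[:, 0]               # [4, 1]
edge_squeezed = sentences_d1[:, 0].squeeze(1) # [4]

print(normal_squeezed.shape, edge_plain.shape, edge_squeezed.shape)
```
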

Comment thread swift/loss/embedding.py
# Optional d+->d (doc-doc) similarity; exclude self-positive column per row
dd_matrix = None
if config['include_dd']:
pos_docs = sentences[:, 1].squeeze(1) # [B, D]
Contributor

medium

Similar to the queries extraction, squeeze(1) on sentences[:, 1] is unnecessary and could lead to issues if the embedding dimension is 1.

Suggested change
- pos_docs = sentences[:, 1].squeeze(1) # [B, D]
+ pos_docs = sentences[:, 1] # [B, D]

Comment thread swift/loss/embedding.py
logits_parts.append(dd_vec)

logits_row = torch.cat(logits_parts, dim=-1) / temperature
loss += nn.CrossEntropyLoss()(logits_row.unsqueeze(0), target.unsqueeze(0))
Contributor

medium

Instantiating nn.CrossEntropyLoss() inside a loop is inefficient. Use F.cross_entropy instead.

Suggested change
- loss += nn.CrossEntropyLoss()(logits_row.unsqueeze(0), target.unsqueeze(0))
+ loss += F.cross_entropy(logits_row.unsqueeze(0), target.unsqueeze(0))
