Skip to content

Commit c95b79f

Browse files
committed
Use torch.gather instead of generic int indexing for cross_entropy example
stack-info: PR: #2058, branch: AmesingFlank/stack/25
1 parent 584027c commit c95b79f

1 file changed

Lines changed: 2 additions & 10 deletions

File tree

examples/cross_entropy.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,16 @@ def cross_entropy(
4848
n, v = logits.shape
4949
losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
5050

51-
# Flatten logits once at the beginning
52-
logits_flat = logits.view(-1)
53-
5451
for tile_n in hl.tile(n):
5552
# Get data for this tile
5653
labels_tile = labels[tile_n] # [tile_size]
57-
base_indices_tile = tile_n.index * v # [tile_size]
58-
59-
# Compute the actual flat indices by adding the label offset
60-
flat_indices = base_indices_tile + labels_tile
61-
62-
# Load the logits at the target indices
63-
logits_at_target = hl.load(logits_flat, [flat_indices])
6454

6555
# Compute log_softmax for numerical stability
6656
# Load the full rows for this tile
6757
logits_rows = logits[tile_n, :] # [tile_size, V]
6858

59+
logits_at_target = logits_rows.gather(1, labels_tile.unsqueeze(1)).squeeze(1)
60+
6961
# Compute log-sum-exp
7062
max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
7163
shifted = logits_rows - max_logits

0 commit comments

Comments
 (0)