File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -48,24 +48,16 @@ def cross_entropy(
4848 n , v = logits .shape
4949 losses = torch .zeros ([n ], dtype = logits .dtype , device = logits .device )
5050
51- # Flatten logits once at the beginning
52- logits_flat = logits .view (- 1 )
53-
5451 for tile_n in hl .tile (n ):
5552 # Get data for this tile
5653 labels_tile = labels [tile_n ] # [tile_size]
57- base_indices_tile = tile_n .index * v # [tile_size]
58-
59- # Compute the actual flat indices by adding the label offset
60- flat_indices = base_indices_tile + labels_tile
61-
62- # Load the logits at the target indices
63- logits_at_target = hl .load (logits_flat , [flat_indices ])
6454
6555 # Compute log_softmax for numerical stability
6656 # Load the full rows for this tile
6757 logits_rows = logits [tile_n , :] # [tile_size, V]
6858
59+ logits_at_target = logits_rows .gather (1 , labels_tile .unsqueeze (1 )).squeeze (1 )
60+
6961 # Compute log-sum-exp
7062 max_logits = torch .amax (logits_rows , dim = - 1 , keepdim = True )
7163 shifted = logits_rows - max_logits
You can’t perform that action at this time.
0 commit comments