Why are we not multiplying the LM head FLOPs per iteration by `checkpoint_activations_factor`?
```python
flops_per_iteration += (6 * batch_size * seq_len * num_layers * (hidden_size**2)) * (vocab_size / (num_layers * hidden_size))
```
AFAIK the factor of 4 means 1 forward + 2 backward + 1 extra forward, where the extra forward is the recomputation needed when activations are checkpointed. Don't we also need all 4 of those passes for the LM head? cc @RaymondLi0 @NouamaneTazi
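To make the question concrete, here is a minimal sketch (not Megatron's actual code) contrasting the two accountings. The names mirror the snippet above; `checkpoint_activations_factor = 4` follows the convention in the question (1 forward + 2 backward + 1 recomputation forward), and whether the LM head's forward is actually recomputed under activation checkpointing is exactly what is being asked:

```python
def lm_head_flops(batch_size, seq_len, hidden_size, vocab_size,
                  checkpoint_activations_factor=4, apply_ckpt_factor=False):
    """FLOPs for the vocab projection (one [B*s, h] x [h, V] matmul per pass)."""
    # One forward pass: 2 FLOPs per multiply-accumulate.
    forward = 2 * batch_size * seq_len * hidden_size * vocab_size
    if apply_ckpt_factor:
        # Proposed accounting: 1 fwd + 2 bwd + 1 recompute fwd.
        return checkpoint_activations_factor * forward
    # Current accounting: the snippet's 6*B*s*L*h^2 * (V/(L*h)) simplifies to
    # 6*B*s*h*V, i.e. 3x a single forward (1 fwd + 2 bwd, no recompute).
    return 3 * forward


current = lm_head_flops(4, 2048, 4096, 50257)
proposed = lm_head_flops(4, 2048, 4096, 50257, apply_ckpt_factor=True)
```

So the discrepancy, if the LM head forward is recomputed, is a 4/3 undercount of its FLOPs.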
(`Megatron-LM/megatron/utils.py`, line 253 in `bd0aaba`)