Skip to content

Commit b979eed

Browse files
committed
MFU: use padded_vocab_size for mfu_padded_pct LM-head FLOPs
For configs with padded_vocab_size set (ESM-2: 33→64 for FP8/tensor-core friendliness), the LM-head matmul physically runs at the padded width and the logits are sliced back afterward. Count the padded width in the hardware-view metrics (mfu_padded_pct, tflops_per_gpu_padded) while continuing to count the raw vocab_size in the useful-work metrics (mfu_pct, tflops_per_gpu). For configs without padded_vocab_size (llama3, opengenome2, codonfm), the two values coincide and nothing changes. Addresses review feedback from @trvachov on PR #1548. Signed-off-by: Gagan Kaushik <gkaushik@nvidia.com>
1 parent 423eab7 commit b979eed

4 files changed

Lines changed: 40 additions & 8 deletions

File tree

bionemo-recipes/recipes/codonfm_native_te/perf_logger.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _detect_peak_tflops_bf16():
6666
return None, name
6767

6868

69-
def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
69+
def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int:
7070
"""Per-token FLOPs for everything EXCEPT the S² attention term.
7171
7272
Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the
@@ -81,6 +81,10 @@ def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
8181
kv_dim = n_kv * head_dim
8282
ffn = model_config_dict["intermediate_size"]
8383
vocab = model_config_dict.get("vocab_size", 0)
84+
if use_padded_vocab:
85+
# LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for
86+
# FP8/tensor-core friendliness); logits are sliced back post-matmul.
87+
vocab = model_config_dict.get("padded_vocab_size") or vocab
8488
num_layers = model_config_dict["num_hidden_layers"]
8589
model_type = model_config_dict.get("model_type", "")
8690
num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2
@@ -192,11 +196,15 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi
192196
# reflects each rank's share under DP and sequence packing.
193197
self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
194198
self._non_attn_per_token_flops = 0
199+
self._non_attn_per_token_flops_padded = 0
195200
self._attn_flop_coeff = 0
196201
self._cp_size = int(args.get("cp_size", 1))
197202
self._peak_tflops: float | None = None
198203
if self._log_mfu:
199204
self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict)
205+
self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops(
206+
model_config_dict, use_padded_vocab=True
207+
)
200208
self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict)
201209
self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
202210
if dist_config.local_rank == 0:
@@ -348,7 +356,7 @@ def log_step(
348356
flops_unpadded = non_attn_unpadded + attn_flops_unpadded
349357
tflops_unpadded = flops_unpadded / step_time / 1e12
350358

351-
non_attn_padded = self._non_attn_per_token_flops * self.num_tokens
359+
non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens
352360
attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size
353361
flops_padded = non_attn_padded + attn_flops_padded
354362
tflops_padded = flops_padded / step_time / 1e12

bionemo-recipes/recipes/esm2_native_te/perf_logger.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _detect_peak_tflops_bf16():
6565
return None, name
6666

6767

68-
def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
68+
def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int:
6969
"""Per-token FLOPs for everything EXCEPT the S² attention term.
7070
7171
Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the
@@ -80,6 +80,10 @@ def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
8080
kv_dim = n_kv * head_dim
8181
ffn = model_config_dict["intermediate_size"]
8282
vocab = model_config_dict.get("vocab_size", 0)
83+
if use_padded_vocab:
84+
# LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for
85+
# FP8/tensor-core friendliness); logits are sliced back post-matmul.
86+
vocab = model_config_dict.get("padded_vocab_size") or vocab
8387
num_layers = model_config_dict["num_hidden_layers"]
8488
model_type = model_config_dict.get("model_type", "")
8589
num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2
@@ -195,11 +199,15 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi
195199
# already reflects each rank's share under DP/CP and sequence packing.
196200
self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
197201
self._non_attn_per_token_flops = 0
202+
self._non_attn_per_token_flops_padded = 0
198203
self._attn_flop_coeff = 0
199204
self._cp_size = int(args.get("cp_size", 1))
200205
self._peak_tflops: float | None = None
201206
if self._log_mfu:
202207
self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict)
208+
self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops(
209+
model_config_dict, use_padded_vocab=True
210+
)
203211
self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict)
204212
self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
205213
if dist_config.local_rank == 0:
@@ -357,7 +365,7 @@ def log_step(
357365
flops_unpadded = non_attn_unpadded + attn_flops_unpadded
358366
tflops_unpadded = flops_unpadded / step_time / 1e12
359367

360-
non_attn_padded = self._non_attn_per_token_flops * self.num_tokens
368+
non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens
361369
attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size
362370
flops_padded = non_attn_padded + attn_flops_padded
363371
tflops_padded = flops_padded / step_time / 1e12

bionemo-recipes/recipes/llama3_native_te/perf_logger.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _detect_peak_tflops_bf16():
6363
return None, name
6464

6565

66-
def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
66+
def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int:
6767
"""Per-token FLOPs for everything EXCEPT the S² attention term.
6868
6969
Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the
@@ -78,6 +78,10 @@ def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
7878
kv_dim = n_kv * head_dim
7979
ffn = model_config_dict["intermediate_size"]
8080
vocab = model_config_dict.get("vocab_size", 0)
81+
if use_padded_vocab:
82+
# LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for
83+
# FP8/tensor-core friendliness); logits are sliced back post-matmul.
84+
vocab = model_config_dict.get("padded_vocab_size") or vocab
8185
num_layers = model_config_dict["num_hidden_layers"]
8286
model_type = model_config_dict.get("model_type", "")
8387
num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2
@@ -201,11 +205,15 @@ def __init__(
201205
# reflects each rank's share under DP/CP and sequence packing.
202206
self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
203207
self._non_attn_per_token_flops = 0
208+
self._non_attn_per_token_flops_padded = 0
204209
self._attn_flop_coeff = 0
205210
self._cp_size = int(args.get("cp_size", 1))
206211
self._peak_tflops: float | None = None
207212
if self._log_mfu:
208213
self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict)
214+
self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops(
215+
model_config_dict, use_padded_vocab=True
216+
)
209217
self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict)
210218
self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
211219
if dist_config.local_rank == 0:
@@ -384,7 +392,7 @@ def log_step(
384392
flops_unpadded = non_attn_unpadded + attn_flops_unpadded
385393
tflops_unpadded = flops_unpadded / step_time / 1e12
386394

387-
non_attn_padded = self._non_attn_per_token_flops * self.num_tokens
395+
non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens
388396
attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size
389397
flops_padded = non_attn_padded + attn_flops_padded
390398
tflops_padded = flops_padded / step_time / 1e12

bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _detect_peak_tflops_bf16():
7171
return None, name
7272

7373

74-
def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
74+
def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int:
7575
"""Per-token FLOPs for everything EXCEPT the S² attention term.
7676
7777
Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the
@@ -86,6 +86,10 @@ def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
8686
kv_dim = n_kv * head_dim
8787
ffn = model_config_dict["intermediate_size"]
8888
vocab = model_config_dict.get("vocab_size", 0)
89+
if use_padded_vocab:
90+
# LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for
91+
# FP8/tensor-core friendliness); logits are sliced back post-matmul.
92+
vocab = model_config_dict.get("padded_vocab_size") or vocab
8993
num_layers = model_config_dict["num_hidden_layers"]
9094
model_type = model_config_dict.get("model_type", "")
9195
num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2
@@ -197,11 +201,15 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi
197201
# reflects each rank's share under DP/CP and sequence packing.
198202
self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
199203
self._non_attn_per_token_flops = 0
204+
self._non_attn_per_token_flops_padded = 0
200205
self._attn_flop_coeff = 0
201206
self._cp_size = int(args.get("cp_size", 1))
202207
self._peak_tflops: float | None = None
203208
if self._log_mfu:
204209
self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict)
210+
self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops(
211+
model_config_dict, use_padded_vocab=True
212+
)
205213
self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict)
206214
self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
207215
if dist_config.local_rank == 0:
@@ -373,7 +381,7 @@ def log_step(
373381
flops_unpadded = non_attn_unpadded + attn_flops_unpadded
374382
tflops_unpadded = flops_unpadded / step_time / 1e12
375383

376-
non_attn_padded = self._non_attn_per_token_flops * self.num_tokens
384+
non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens
377385
attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size
378386
flops_padded = non_attn_padded + attn_flops_padded
379387
tflops_padded = flops_padded / step_time / 1e12

0 commit comments

Comments
 (0)