Skip to content

Commit 909c1d7

Browse files
committed
perf_logger: remove legacy _compute_per_token_flops back-compat shim
The single-term per-token FLOP function was kept around after the two-term split (non_attn + coeff·Σ(Lᵢ²)) was introduced, solely so old unit tests and the startup log line could still reference it. Nothing in the per-step runtime path was using it — log_step had already switched to the split formula.

Remove the function, the self._per_token_flops attribute, and the test scaffolding that existed only to pin the split against the legacy reference:

- _compute_per_token_flops(cfg, seq_len) deleted from all 4 recipes
- self._per_token_flops init + assignment + log-line mention removed
- logger.info startup format no longer includes "per-token FLOPs=%.3e"
- TestComputePerTokenFlops test class deleted (llama3 only)
- test_algebraic_identity deleted (its whole purpose was to pin non_attn + coeff·S against the legacy function)
- test_bshd_no_op simplified to test_bshd_shape_synthesis — keeps the Σ(Lᵢ²)=B·S² shape-synthesis check, drops the legacy comparison
- Docstrings no longer reference the old function

The tests that exercise actual formula correctness (test_thd_multi_doc_uses_squared_sum, test_cp_size_divides_attention_only, test_bshd_cp_correction, test_include_padding_*, etc.) all stay — they verify behavior directly without going through the legacy function.

Net: -336 / +40 lines. Pure dead-code removal; no runtime change other than the startup log line, which no longer prints the legacy per-token FLOPs value.

Signed-off-by: Gagan Kaushik <gkaushik@nvidia.com>
1 parent f8f84cb commit 909c1d7

8 files changed

Lines changed: 40 additions & 336 deletions

File tree

bionemo-recipes/recipes/codonfm_native_te/perf_logger.py

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -66,50 +66,13 @@ def _detect_peak_tflops_bf16():
6666
return None, name
6767

6868

69-
def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int:
70-
"""Training FLOPs per token for a transformer (forward + backward = 3x forward).
71-
72-
First-principles matmul count: Q/K/V/O projections (GQA-aware), attention
73-
logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform
74-
BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via
75-
model_type), and LM head.
76-
77-
Kept for back-compat. For accurate per-step accounting use
78-
``_compute_non_attn_per_token_flops`` (applied to the total token count)
79-
together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from
80-
cu_seq_lens), since a packed THD batch of total length S containing docs
81-
L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S².
82-
"""
83-
h = model_config_dict["hidden_size"]
84-
n_heads = model_config_dict["num_attention_heads"]
85-
n_kv = model_config_dict.get("num_key_value_heads", n_heads)
86-
head_dim = h // n_heads
87-
kv_dim = n_kv * head_dim
88-
ffn = model_config_dict["intermediate_size"]
89-
vocab = model_config_dict.get("vocab_size", 0)
90-
num_layers = model_config_dict["num_hidden_layers"]
91-
model_type = model_config_dict.get("model_type", "")
92-
num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2
93-
94-
per_layer = (
95-
2 * h * h # Q projection
96-
+ 4 * h * kv_dim # K + V projections (GQA-aware)
97-
+ 2 * h * h # O projection
98-
+ 4 * seq_len * h # attention logits + values (S^2 -> S per token)
99-
+ 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections)
100-
)
101-
lm_head = 2 * h * vocab if vocab > 0 else 0
102-
per_token_fwd = num_layers * per_layer + lm_head
103-
return 3 * per_token_fwd
104-
105-
10669
def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
10770
"""Per-token FLOPs for everything EXCEPT the S² attention term.
10871
10972
Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the
11073
actual total token count of the batch to get per-step non-attention FLOPs. Pairs
111-
with ``_compute_attn_flop_coeff`` so that
112-
``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``.
74+
with ``_compute_attn_flop_coeff``, which contributes the attention term as
75+
``coeff · Σ(Lᵢ²)`` from cu_seq_lens.
11376
"""
11477
h = model_config_dict["hidden_size"]
11578
n_heads = model_config_dict["num_attention_heads"]
@@ -228,26 +191,23 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi
228191
# step are derived at log time from the tracked unpadded token count, which already
229192
# reflects each rank's share under DP and sequence packing.
230193
self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
231-
self._per_token_flops = 0
232194
self._non_attn_per_token_flops = 0
233195
self._attn_flop_coeff = 0
234196
self._cp_size = int(args.get("cp_size", 1))
235197
self._peak_tflops: float | None = None
236198
if self._log_mfu:
237-
self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length)
238199
self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict)
239200
self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict)
240201
self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
241202
if dist_config.local_rank == 0:
242203
logger.info(
243-
"MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, "
244-
"non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d",
204+
"MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, "
205+
"non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d",
245206
gpu_name,
246207
f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown",
247-
float(self._per_token_flops),
248-
args.dataset.max_seq_length,
249208
float(self._non_attn_per_token_flops),
250209
float(self._attn_flop_coeff),
210+
args.dataset.max_seq_length,
251211
self._cp_size,
252212
)
253213

bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
_attn_work_from_batch,
2727
_compute_attn_flop_coeff,
2828
_compute_non_attn_per_token_flops,
29-
_compute_per_token_flops,
3029
)
3130
from transformers.modeling_outputs import MaskedLMOutput
3231

@@ -231,26 +230,14 @@ def _codon_cfg():
231230

232231

233232
class TestFlopSplitAndAttention:
234-
"""Verify the split non-attn + Σ(Lᵢ²) attention formula."""
233+
"""Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed."""
235234

236-
def test_algebraic_identity(self):
237-
"""non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S."""
238-
cfg = _codon_cfg()
239-
for s in (256, 512, 1024, 8192):
240-
lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s
241-
rhs = _compute_per_token_flops(cfg, s)
242-
assert lhs == rhs, f"S={s}: {lhs} != {rhs}"
243-
244-
def test_bshd_no_op(self):
245-
"""BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly."""
246-
cfg = _codon_cfg()
235+
def test_bshd_shape_synthesis(self):
236+
"""BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape."""
247237
b, s = 4, 512
248238
batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)}
249239
sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item()
250240
assert sigma_l_sq == b * s * s
251-
new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq
252-
legacy_flops = _compute_per_token_flops(cfg, s) * b * s
253-
assert new_flops == legacy_flops
254241

255242
def test_thd_single_doc_matches_bshd(self):
256243
"""cu_seq_lens_q=[0, S] reproduces BSHD's Σ(Lᵢ²)=S²."""

bionemo-recipes/recipes/esm2_native_te/perf_logger.py

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -65,50 +65,13 @@ def _detect_peak_tflops_bf16():
6565
return None, name
6666

6767

68-
def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int:
69-
"""Training FLOPs per token for a transformer (forward + backward = 3x forward).
70-
71-
First-principles matmul count: Q/K/V/O projections (GQA-aware), attention
72-
logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform
73-
BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via
74-
model_type), and LM head.
75-
76-
Kept for back-compat. For accurate per-step accounting use
77-
``_compute_non_attn_per_token_flops`` (applied to the total token count)
78-
together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from
79-
cu_seq_lens), since a packed THD batch of total length S containing docs
80-
L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S².
81-
"""
82-
h = model_config_dict["hidden_size"]
83-
n_heads = model_config_dict["num_attention_heads"]
84-
n_kv = model_config_dict.get("num_key_value_heads", n_heads)
85-
head_dim = h // n_heads
86-
kv_dim = n_kv * head_dim
87-
ffn = model_config_dict["intermediate_size"]
88-
vocab = model_config_dict.get("vocab_size", 0)
89-
num_layers = model_config_dict["num_hidden_layers"]
90-
model_type = model_config_dict.get("model_type", "")
91-
num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2
92-
93-
per_layer = (
94-
2 * h * h # Q projection
95-
+ 4 * h * kv_dim # K + V projections (GQA-aware)
96-
+ 2 * h * h # O projection
97-
+ 4 * seq_len * h # attention logits + values (S^2 -> S per token)
98-
+ 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections)
99-
)
100-
lm_head = 2 * h * vocab if vocab > 0 else 0
101-
per_token_fwd = num_layers * per_layer + lm_head
102-
return 3 * per_token_fwd
103-
104-
10568
def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
10669
"""Per-token FLOPs for everything EXCEPT the S² attention term.
10770
10871
Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the
10972
actual total token count of the batch to get per-step non-attention FLOPs. Pairs
110-
with ``_compute_attn_flop_coeff`` so that
111-
``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``.
73+
with ``_compute_attn_flop_coeff``, which contributes the attention term as
74+
``coeff · Σ(Lᵢ²)`` from cu_seq_lens.
11275
"""
11376
h = model_config_dict["hidden_size"]
11477
n_heads = model_config_dict["num_attention_heads"]
@@ -231,26 +194,23 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi
231194
# step are derived at log time from the accumulated token count + Σ(Lᵢ²), which
232195
# already reflects each rank's share under DP/CP and sequence packing.
233196
self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
234-
self._per_token_flops = 0
235197
self._non_attn_per_token_flops = 0
236198
self._attn_flop_coeff = 0
237199
self._cp_size = int(args.get("cp_size", 1))
238200
self._peak_tflops: float | None = None
239201
if self._log_mfu:
240-
self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length)
241202
self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict)
242203
self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict)
243204
self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
244205
if dist_config.local_rank == 0:
245206
logger.info(
246-
"MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, "
247-
"non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d",
207+
"MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, "
208+
"non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d",
248209
gpu_name,
249210
f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown",
250-
float(self._per_token_flops),
251-
args.dataset.max_seq_length,
252211
float(self._non_attn_per_token_flops),
253212
float(self._attn_flop_coeff),
213+
args.dataset.max_seq_length,
254214
self._cp_size,
255215
)
256216

bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
_attn_work_from_batch,
3636
_compute_attn_flop_coeff,
3737
_compute_non_attn_per_token_flops,
38-
_compute_per_token_flops,
3938
)
4039

4140

@@ -96,26 +95,14 @@ def _esm_cfg():
9695

9796

9897
class TestFlopSplitAndAttention:
99-
"""Verify the split non-attn + Σ(Lᵢ²) attention formula for ESM-2."""
98+
"""Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed for ESM-2."""
10099

101-
def test_algebraic_identity(self):
102-
"""non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S."""
103-
cfg = _esm_cfg()
104-
for s in (256, 1024, 8192, 131072):
105-
lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s
106-
rhs = _compute_per_token_flops(cfg, s)
107-
assert lhs == rhs, f"S={s}: {lhs} != {rhs}"
108-
109-
def test_bshd_no_op(self):
110-
"""BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly."""
111-
cfg = _esm_cfg()
100+
def test_bshd_shape_synthesis(self):
101+
"""BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape."""
112102
b, s = 2, 512
113103
batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)}
114104
sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item()
115105
assert sigma_l_sq == b * s * s
116-
new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq
117-
legacy_flops = _compute_per_token_flops(cfg, s) * b * s
118-
assert new_flops == legacy_flops
119106

120107
def test_thd_single_doc_matches_bshd(self):
121108
"""cu_seq_lens_q=[0, S] reproduces BSHD's Σ(Lᵢ²)=S²."""

bionemo-recipes/recipes/llama3_native_te/perf_logger.py

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -63,50 +63,13 @@ def _detect_peak_tflops_bf16():
6363
return None, name
6464

6565

66-
def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int:
67-
"""Training FLOPs per token for a transformer (forward + backward = 3x forward).
68-
69-
First-principles matmul count: Q/K/V/O projections (GQA-aware), attention
70-
logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform
71-
BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via
72-
model_type), and LM head.
73-
74-
Kept for back-compat. For accurate per-step accounting use
75-
``_compute_non_attn_per_token_flops`` (applied to the total token count)
76-
together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from
77-
cu_seq_lens), since a packed THD batch of total length S containing docs
78-
L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S².
79-
"""
80-
h = model_config_dict["hidden_size"]
81-
n_heads = model_config_dict["num_attention_heads"]
82-
n_kv = model_config_dict.get("num_key_value_heads", n_heads)
83-
head_dim = h // n_heads
84-
kv_dim = n_kv * head_dim
85-
ffn = model_config_dict["intermediate_size"]
86-
vocab = model_config_dict.get("vocab_size", 0)
87-
num_layers = model_config_dict["num_hidden_layers"]
88-
model_type = model_config_dict.get("model_type", "")
89-
num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2
90-
91-
per_layer = (
92-
2 * h * h # Q projection
93-
+ 4 * h * kv_dim # K + V projections (GQA-aware)
94-
+ 2 * h * h # O projection
95-
+ 4 * seq_len * h # attention logits + values (S^2 -> S per token)
96-
+ 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections)
97-
)
98-
lm_head = 2 * h * vocab if vocab > 0 else 0
99-
per_token_fwd = num_layers * per_layer + lm_head
100-
return 3 * per_token_fwd
101-
102-
10366
def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int:
10467
"""Per-token FLOPs for everything EXCEPT the S² attention term.
10568
10669
Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the
10770
actual total token count of the batch to get per-step non-attention FLOPs. Pairs
108-
with ``_compute_attn_flop_coeff`` so that
109-
``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``.
71+
with ``_compute_attn_flop_coeff``, which contributes the attention term as
72+
``coeff · Σ(Lᵢ²)`` from cu_seq_lens.
11073
"""
11174
h = model_config_dict["hidden_size"]
11275
n_heads = model_config_dict["num_attention_heads"]
@@ -237,26 +200,23 @@ def __init__(
237200
# step are derived at log time from the tracked unpadded token count, which already
238201
# reflects each rank's share under DP/CP and sequence packing.
239202
self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
240-
self._per_token_flops = 0
241203
self._non_attn_per_token_flops = 0
242204
self._attn_flop_coeff = 0
243205
self._cp_size = int(args.get("cp_size", 1))
244206
self._peak_tflops: float | None = None
245207
if self._log_mfu:
246-
self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length)
247208
self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict)
248209
self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict)
249210
self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
250211
if dist_config.local_rank == 0:
251212
logger.info(
252-
"MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, "
253-
"non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d",
213+
"MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, "
214+
"non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d",
254215
gpu_name,
255216
f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown",
256-
float(self._per_token_flops),
257-
args.dataset.max_seq_length,
258217
float(self._non_attn_per_token_flops),
259218
float(self._attn_flop_coeff),
219+
args.dataset.max_seq_length,
260220
self._cp_size,
261221
)
262222

0 commit comments

Comments
 (0)