-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathperf.py
More file actions
1465 lines (1333 loc) · 42.7 KB
/
perf.py
File metadata and controls
1465 lines (1333 loc) · 42.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Benchmark FFPA speedups and render plot plus Markdown tables.

Usage::

    CUDA_VISIBLE_DEVICES=0 python examples/perf.py
    CUDA_VISIBLE_DEVICES=0 python examples/perf.py --no-bwd --fwd-backend triton --tune fast
    CUDA_VISIBLE_DEVICES=0 python examples/perf.py --no-fwd --bwd-backend triton --tune max
    CUDA_VISIBLE_DEVICES=0 python examples/perf.py --fwd-backend triton --bwd-backend triton --tune fast
    CUDA_VISIBLE_DEVICES=0 python examples/perf.py --fwd-backend cutedsl --bwd-backend cutedsl

The cutedsl backend supports SM80/SM89 via the Split-D path and SM90 via the
Hopper path. Selecting ``cutedsl`` on either ``--fwd-backend`` or
``--bwd-backend`` auto-pairs the other side (which must be left at its default)
and restricts tasks to the cutedsl-compatible subset
(self-attn, cross-attn, gqa, causal, non-aligned).
"""
from __future__ import annotations
import argparse
from pathlib import Path
import re
import sys
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import torch
# Directory containing this script; it is prepended to sys.path so the sibling
# helper modules imported below (_attn_fwd, _attn_bwd, _attn_flops) resolve
# when the script is run directly.
EXAMPLES_DIR = Path(__file__).resolve().parent
if str(EXAMPLES_DIR) not in sys.path:
    # Only insert once to avoid duplicate sys.path entries.
    sys.path.insert(0, str(EXAMPLES_DIR))
from _attn_fwd import run_forward_examples
from _attn_bwd import run_backward_examples
from _attn_flops import format_tflops_short
def _parse_grad_kv_dtype(arg: str) -> torch.dtype | None:
"""Parse the CLI grad-kv-dtype option.
:param arg: CLI value, ``"none"``, ``"fp16"``, or ``"fp32"``.
:return: ``None``, ``torch.float16``, or ``torch.float32``.
"""
if arg == "none":
return None
if arg == "fp16":
return torch.float16
if arg == "fp32":
return torch.float32
raise ValueError(
f"Unsupported grad-kv-dtype={arg!r}; choose 'none', 'fp16', or 'fp32'."
)
# Keep the exact legacy plotting style from tools/plot.py: 300-dpi figures,
# DejaVu Sans as the sans-serif font, and a plain hyphen for negative tick
# labels instead of the Unicode minus sign.
plt.rcParams.update({
    "figure.dpi": 300,
    "font.sans-serif": ["DejaVu Sans"],
    "axes.unicode_minus": False,
})
# Every benchmark case as (case_name, x-axis label), in display/sort order.
PLOT_CASES: list[tuple[str, str]] = [
    ("self-attn", "self-attn(F/B)"),
    ("cross-attn", "cross-attn(F/B)"),
    ("decode-attn", "decode(Nq=1,F/B)"),
    ("gqa", "gqa(F/B)"),
    ("causal", "causal(F/B)"),
    ("attn-mask", "attn-mask(F/B)"),
    ("dropout", "dropout(F/B)"),
    ("non-aligned", "non-aligned(F/B)"),
]
# The speedup bar chart shows every case except "non-aligned".
SPEEDUP_PLOT_CASES: list[tuple[str, str]] = [
    case for case in PLOT_CASES if case[0] != "non-aligned"
]
# The TFLOPS chart is limited to four cases, kept in PLOT_CASES order.
TFLOPS_PLOT_CASES: list[tuple[str, str]] = [
    case for case in PLOT_CASES
    if case[0] in ("self-attn", "cross-attn", "gqa", "causal")
]
CASE_LABELS = {case_name: label for case_name, label in PLOT_CASES}
# Case names accepted by --tasks, in PLOT_CASES order.
VALID_TASKS = tuple(CASE_LABELS)
DTYPE_ORDER = ["fp16", "bf16"]
# Output naming.
DEFAULT_OUTPUT_STEM = "ffpa_speedup"
DEFAULT_OUTPUT_DIR = Path(".tmp")
FALLBACK_DEVICE_NAME = "NVIDIA Geforce RTX 5090"
# CuTeDSL backend configuration.
CUTEDSL_BACKEND = "cutedsl"
CUTEDSL_COMPAT_TASKS = frozenset(
    ("self-attn", "cross-attn", "gqa", "causal", "non-aligned")
)
CUTEDSL_DTYPES: tuple[torch.dtype, ...] = (torch.float16, torch.bfloat16)
CUTEDSL_OUTPUT_STEM = "ffpa_speedup_cutedsl"
CUTEDSL_SECTION_LABEL = "CuTeDSL"
# Bar colors for the TFLOPS chart.
TFLOPS_FWD_SDPA_COLOR = "#b0b0b0"
TFLOPS_FWD_FFPA_COLOR = "#2171b5"
TFLOPS_BWD_SDPA_COLOR = "#f5a623"
TFLOPS_BWD_FFPA_COLOR = "#fd493c"
# Hard-coded FFPA/SDPA speedups used by --show-fallback, keyed as
# direction -> case_name -> dtype.
FALLBACK_SPEEDUPS: dict[str, dict[str, dict[str, float]]] = {
    "forward": {
        "self-attn": {"fp16": 2.06, "bf16": 2.08},
        "cross-attn": {"fp16": 1.86, "bf16": 1.87},
        "decode-attn": {"fp16": 2.86, "bf16": 2.85},
        "gqa": {"fp16": 2.06, "bf16": 2.09},
        "causal": {"fp16": 1.96, "bf16": 1.99},
        "attn-mask": {"fp16": 1.70, "bf16": 1.74},
        "dropout": {"fp16": 1.79, "bf16": 1.82},
        "non-aligned": {"fp16": 1.96, "bf16": 1.98},
    },
    "backward": {
        "self-attn": {"fp16": 2.34, "bf16": 2.49},
        "cross-attn": {"fp16": 2.57, "bf16": 2.51},
        "decode-attn": {"fp16": 2.97, "bf16": 3.11},
        "gqa": {"fp16": 2.32, "bf16": 2.46},
        "causal": {"fp16": 2.22, "bf16": 2.56},
        "attn-mask": {"fp16": 2.06, "bf16": 2.17},
        "dropout": {"fp16": 2.27, "bf16": 2.41},
        "non-aligned": {"fp16": 2.37, "bf16": 2.67},
    },
}
# Alias for one structured benchmark/fallback result row.
RESULT_ROW = dict[str, Any]
def _parse_args() -> argparse.Namespace:
    """Parse CLI arguments and normalize the legacy global TMA/WS flags.

    Both benchmark directions run by default; ``--no-fwd``/``--no-bwd``
    disable them. The parsed namespace is post-processed by
    :func:`_resolve_directional_cli_flags` before it is returned.

    :return: Parsed CLI namespace.
    :raises SystemExit: On invalid arguments (argparse), or when
        ``--enable-persist-dkdv`` is used without backward TMA.
    """
    parser = argparse.ArgumentParser(
        description=
        "Benchmark FFPA forward/backward cases and generate plot plus Markdown tables."
    )
    # Direction toggles: both directions are enabled unless disabled here.
    parser.add_argument(
        "--no-forward",
        "--no-fwd",
        dest="forward",
        action="store_false",
        help="Disable forward benchmark cases.",
    )
    parser.add_argument(
        "--no-backward",
        "--no-bwd",
        dest="backward",
        action="store_false",
        help="Disable backward benchmark cases.",
    )
    parser.set_defaults(forward=True, backward=True)
    parser.add_argument(
        "--show-fallback",
        action="store_true",
        help=
        "Render the legacy hard-coded fallback plot and Markdown table instead of running real benchmarks.",
    )
    # Backend selection. There is no positive --forward/--backward flag, so
    # the help text refers to the --no-fwd/--no-bwd switches instead.
    parser.add_argument(
        "--forward-backend",
        "--fwd-backend",
        choices=["cuda", "triton", "cutedsl"],
        default="triton",
        help="Forward backend used unless --no-fwd is given.",
    )
    parser.add_argument(
        "--backward-backend",
        "--bwd-backend",
        choices=["sdpa", "triton", "cutedsl"],
        default="triton",
        help="Backward backend used unless --no-bwd is given.",
    )
    # No default: args.tune is None when autotune is not requested.
    parser.add_argument(
        "--tune",
        choices=["fast", "max"],
        help="Enable Triton autotune with the selected search mode."
    )
    parser.add_argument(
        "--tasks",
        nargs="*",
        default=None,
        help=(
            "Benchmark cases to run, separated by commas or whitespace, for example self-attn,cross-attn. "
            "Defaults to full; valid cases: " + ",".join(VALID_TASKS)
        ),
    )
    # Problem shape and timing parameters.
    parser.add_argument(
        "--B", type=int, default=1, help="Batch size used by benchmark mode."
    )
    parser.add_argument(
        "--H", type=int, default=32, help="Base head count used by benchmark mode."
    )
    parser.add_argument(
        "--N",
        type=int,
        default=8192,
        help="Base sequence length used by benchmark mode."
    )
    parser.add_argument(
        "--D", type=int, default=512, help="Head dimension used by benchmark mode."
    )
    parser.add_argument(
        "--warmup", type=int, default=2, help="Warmup iterations used for timing."
    )
    parser.add_argument(
        "--iters",
        type=int,
        default=10,
        help="Measured iterations used for timing."
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="RNG seed used by benchmark mode."
    )
    parser.add_argument(
        "--dtype",
        choices=["fp16", "bf16", "both"],
        default="both",
        help="Activation dtype to benchmark. 'both' (default) runs fp16+bf16; "
        "fp16/bf16 narrows to that single dtype.",
    )
    parser.add_argument(
        "--norm",
        action="store_true",
        help="Enable pre-attention LayerNorm on q/k/v for both FFPA and SDPA paths.",
    )
    # Legacy global flags; expanded into the directional flags below by
    # _resolve_directional_cli_flags.
    parser.add_argument(
        "--enable-tma",
        "--tma",
        action="store_true",
        help="Compatibility alias for --enable-fwd-tma --enable-bwd-tma.",
    )
    parser.add_argument(
        "--enable-ws",
        "--ws",
        action="store_true",
        help="Compatibility alias for --enable-fwd-ws --enable-bwd-ws.",
    )
    parser.add_argument(
        "--enable-fwd-tma",
        "--fwd-tma",
        action="store_true",
        help=
        "Enable experimental SM90+ TMA forward path (silently falls back on unsupported devices).",
    )
    parser.add_argument(
        "--enable-bwd-tma",
        "--bwd-tma",
        action="store_true",
        help=
        "Enable experimental SM90+ TMA backward path (silently falls back on unsupported devices).",
    )
    parser.add_argument(
        "--enable-fwd-ws",
        "--fwd-ws",
        action="store_true",
        help=
        "Force warp-specialized configs for the experimental SM90+ TMA forward path.",
    )
    parser.add_argument(
        "--enable-bwd-ws",
        "--bwd-ws",
        action="store_true",
        help=
        "Force warp-specialized configs for the experimental SM90+ TMA backward path.",
    )
    parser.add_argument(
        "--enable-persist-dkdv",
        "--persist-dkdv",
        action="store_true",
        help=
        "Enable persistent dK/dV fp32 accumulation in the SM90+ TMA backward path (requires --bwd-tma).",
    )
    parser.add_argument(
        "--enable-bwd-split-launch",
        "--bwd-split-launch",
        "--bwd-split",
        action="store_true",
        help=
        "Enable separate backward launches for dK/dV and dQ on generic Triton or SM90+ TMA paths.",
    )
    parser.add_argument(
        "--grad-kv-storage-dtype",
        "--grad-kv-dtype",
        choices=["none", "fp16", "fp32"],
        default="none",
        help=
        "Optional Triton backward dK/dV storage dtype forwarded to the example runners.",
    )
    # Output options.
    parser.add_argument(
        "--show-allclose",
        action="store_true",
        help="Include the allclose column in the generated Markdown tables.",
    )
    parser.add_argument(
        "--save-path",
        type=Path,
        default=None,
        help=
        "Optional output directory used to save the generated PNG and Markdown artifacts. Defaults to ./.tmp.",
    )
    return _resolve_directional_cli_flags(parser.parse_args())
def _resolve_directional_cli_flags(
args: argparse.Namespace
) -> argparse.Namespace:
"""Resolve legacy global TMA/WS flags into directional benchmark flags."""
if args.enable_tma:
args.enable_fwd_tma = True
args.enable_bwd_tma = True
if args.enable_ws:
args.enable_fwd_ws = True
args.enable_bwd_ws = True
if args.enable_persist_dkdv and not args.enable_bwd_tma:
raise SystemExit("--enable-persist-dkdv requires --enable-bwd-tma")
return args
def _parse_tasks_arg(tasks_arg: list[str] | None) -> set[str] | None:
"""Parse the optional benchmark task filter.
:param tasks_arg: Raw ``--tasks`` values.
:return: Selected case names, or ``None`` for the full benchmark suite.
:raises SystemExit: If an unknown case name is requested.
"""
if tasks_arg is None:
return None
normalized = " ".join(tasks_arg).strip()
if normalized == "" or normalized.lower() in {"full", "all", "none"}:
return None
tasks = {task for task in re.split(r"[\s,]+", normalized) if task}
if not tasks:
return None
invalid = sorted(tasks.difference(VALID_TASKS))
if invalid:
valid = ",".join(VALID_TASKS)
raise SystemExit(
f"Unknown --tasks value(s): {','.join(invalid)}. Valid cases: {valid}, or full."
)
return tasks
def _active_plot_cases(tasks: set[str] | None,
*,
kind: str = "speedup") -> list[tuple[str, str]]:
"""Return plot cases filtered by an optional task set.
:param tasks: Optional selected case names.
:param kind: ``"speedup"`` (bar chart, omits non-aligned), ``"tflops"`` (TFLOPS
chart), or ``"all"`` (full case list used by the Markdown sort order).
:return: Ordered plot case list.
"""
if kind == "speedup":
source = SPEEDUP_PLOT_CASES
elif kind == "tflops":
source = TFLOPS_PLOT_CASES
elif kind == "all":
source = PLOT_CASES
else:
raise ValueError(
f"Unknown plot kind={kind!r}; choose 'speedup', 'tflops', or 'all'."
)
if tasks is None:
return list(source)
return [(case_name, label) for case_name, label in source
if case_name in tasks]
def _device_name() -> str:
"""Return the active CUDA device name when available.
:return: CUDA device name or a fallback label.
"""
if torch.cuda.is_available():
return torch.cuda.get_device_name(torch.cuda.current_device())
return "CUDA Unavailable"
def _display_device_name(device_name: str) -> str:
"""Rename H20Z โ H200 for cutedsl plot titles only; filenames stay raw."""
return re.sub(r"H20Z", "H200", device_name, flags=re.IGNORECASE)
def _require_cutedsl_device() -> int:
    """Abort early unless the current CUDA device can run the CuTeDSL backend.

    The ``ffpa_attn.cutedsl`` helpers are imported first, so a missing
    package surfaces as ``ImportError`` before any device checks run.

    :return: ``cutedsl_max_supported_head_dim`` for the current device.
    :raises SystemExit: If CUDA is unavailable, or the device is rejected by
        ``cutedsl_forward_available``.
    """
    from ffpa_attn.cutedsl import (
        cutedsl_forward_available,
        cutedsl_max_supported_head_dim,
    )
    if not torch.cuda.is_available():
        raise SystemExit(
            "CUDA is required: the CuTeDSL backend only runs on SM80/SM89/SM90 GPUs."
        )
    device = torch.device("cuda", torch.cuda.current_device())
    if cutedsl_forward_available(device):
        return cutedsl_max_supported_head_dim(device)
    major, minor = torch.cuda.get_device_capability(device)
    device_label = torch.cuda.get_device_name(device)
    raise SystemExit(
        f"CuTeDSL backend requires SM80/SM89/SM90. Detected device "
        f"'{device_label}' with compute capability {major}.{minor}."
    )
def _resolve_cutedsl_backends(args: argparse.Namespace) -> bool:
    """Pair the forward/backward backends when either one is cutedsl.

    If exactly one side is cutedsl and the other is still ``"triton"`` (the
    argparse default), the other side is promoted to cutedsl in place. Any
    other peer value (e.g. ``--backward-backend sdpa``) is rejected rather
    than silently overridden.

    NOTE: an explicitly passed ``triton`` cannot be told apart from the
    default, so it is promoted as well.

    :param args: Parsed CLI namespace; ``forward_backend``/``backward_backend``
        may be rewritten.
    :return: ``True`` when the run uses cutedsl, else ``False``.
    :raises SystemExit: If cutedsl is mixed with a non-default peer backend.
    """
    fwd = args.forward_backend
    bwd = args.backward_backend
    fwd_is_cutedsl = fwd == CUTEDSL_BACKEND
    bwd_is_cutedsl = bwd == CUTEDSL_BACKEND
    if not (fwd_is_cutedsl or bwd_is_cutedsl):
        return False
    if fwd_is_cutedsl and not bwd_is_cutedsl:
        if bwd != "triton":
            raise SystemExit(
                f"--forward-backend cutedsl requires --backward-backend cutedsl; got {bwd!r}."
            )
        args.backward_backend = CUTEDSL_BACKEND
    elif bwd_is_cutedsl and not fwd_is_cutedsl:
        if fwd != "triton":
            raise SystemExit(
                f"--backward-backend cutedsl requires --forward-backend cutedsl; got {fwd!r}."
            )
        args.forward_backend = CUTEDSL_BACKEND
    return True
def _slugify_device_name(device_name: str) -> str:
"""Convert a device name into a filesystem-friendly slug.
:param device_name: Human-readable device name.
:return: Lowercase slug safe for filenames.
"""
slug = re.sub(r"[^0-9A-Za-z]+", "-", device_name.strip().lower())
slug = re.sub(r"-+", "-", slug).strip("-")
return slug or "unknown-device"
def _output_stem(
    device_name: str,
    B: int,
    H: int,
    N: int,
    D: int,
    *,
    cutedsl: bool = False
) -> Path:
    """Compose the extension-less stem shared by the PNG and Markdown outputs.

    The stem has the form ``<prefix>_<device-slug>_B<B>_H<H>_N<N>_D<D>``.

    :param device_name: Device name used in the run.
    :param B: Batch size.
    :param H: Head count.
    :param N: Sequence length.
    :param D: Head dimension.
    :param cutedsl: Use ``CUTEDSL_OUTPUT_STEM`` instead of
        ``DEFAULT_OUTPUT_STEM`` as the prefix so cutedsl artifacts do not
        overwrite the standard ones.
    :return: Relative output stem without extension.
    """
    prefix = DEFAULT_OUTPUT_STEM
    if cutedsl:
        prefix = CUTEDSL_OUTPUT_STEM
    pieces = (prefix, _slugify_device_name(device_name), f"B{B}_H{H}_N{N}_D{D}")
    return Path("_".join(pieces))
def _resolve_output_stem(
    save_path: Path | None,
    device_name: str,
    B: int,
    H: int,
    N: int,
    D: int,
    *,
    cutedsl: bool = False,
) -> Path:
    """Place the output stem inside the target directory, creating it if needed.

    Side effect: the target directory (``save_path`` or ``DEFAULT_OUTPUT_DIR``)
    is created with ``mkdir(parents=True, exist_ok=True)``.

    :param save_path: Optional output directory; ``None`` selects
        ``DEFAULT_OUTPUT_DIR``.
    :param device_name: Device name used in the run.
    :param B: Batch size.
    :param H: Head count.
    :param N: Sequence length.
    :param D: Head dimension.
    :param cutedsl: Forwarded to :func:`_output_stem` for prefix selection.
    :return: ``<directory>/<stem>`` without extension.
    """
    stem = _output_stem(device_name, B, H, N, D, cutedsl=cutedsl)
    target_dir = save_path if save_path is not None else DEFAULT_OUTPUT_DIR
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.joinpath(stem.name)
def _case_shape(case_name: str, sequence_length: int) -> tuple[int, int]:
"""Return the benchmark shape shown in Markdown for one case.
:param case_name: Canonical case name.
:param sequence_length: Base sequence length.
:return: ``(Nq, Nkv)`` pair.
"""
if case_name == "cross-attn":
return 1024, sequence_length
if case_name == "decode-attn":
return 1, sequence_length
if case_name == "attn-mask":
mask_n = max(sequence_length, 512)
return mask_n, mask_n
if case_name == "non-aligned":
non_aligned_n = sequence_length - 1 if sequence_length > 1 else sequence_length
return non_aligned_n, non_aligned_n
return sequence_length, sequence_length
def _mode_suffix(has_forward: bool, has_backward: bool) -> str:
"""Build the title suffix that matches the legacy title style.
:param has_forward: Whether forward data is present.
:param has_backward: Whether backward data is present.
:return: Title mode suffix.
"""
if has_forward and has_backward:
return "FWD & BWD"
if has_forward:
return "FWD"
return "BWD"
def _forward_section_label(
backend: str, tune_mode: str | None, fallback: bool
) -> str:
"""Describe the forward data source for Markdown headings.
:param backend: Forward backend.
:param tune_mode: Triton autotune mode.
:param fallback: Whether fallback hard-coded data is used.
:return: Human-readable section label.
"""
if fallback:
return "Fallback hard-coded data"
if backend == "cuda":
return "Legacy CUDA"
if backend == CUTEDSL_BACKEND:
return CUTEDSL_SECTION_LABEL
if tune_mode is not None:
return f"Triton w/ autotune ({tune_mode})"
return "Triton"
def _backward_section_label(
backend: str, tune_mode: str | None, fallback: bool
) -> str:
"""Describe the backward data source for Markdown headings.
:param backend: Backward backend.
:param tune_mode: Triton autotune mode.
:param fallback: Whether fallback hard-coded data is used.
:return: Human-readable section label.
"""
if fallback:
return "Fallback hard-coded data"
if backend == "sdpa":
return "SDPA backward"
if backend == CUTEDSL_BACKEND:
return CUTEDSL_SECTION_LABEL
if tune_mode is not None:
return f"Triton w/ autotune ({tune_mode})"
return "Triton"
def _decorate_rows(direction: str, rows: list[dict[str,
Any]]) -> list[RESULT_ROW]:
"""Attach the direction field to benchmark results.
:param direction: ``forward`` or ``backward``.
:param rows: Raw rows returned by the example helper.
:return: Decorated result rows.
"""
return [{"direction": direction, **row} for row in rows]
def _build_fallback_rows(
    B: int,
    H: int,
    N: int,
    D: int,
    tasks: set[str] | None = None,
) -> tuple[list[RESULT_ROW], list[RESULT_ROW]]:
    """Expand ``FALLBACK_SPEEDUPS`` into structured result rows.

    One row is produced per (direction, selected case, dtype), ordered by
    ``PLOT_CASES`` and then ``DTYPE_ORDER``. Timing and allclose fields are
    ``None`` because the fallback data only carries speedups.

    :param B: Batch size shown in metadata.
    :param H: Base head count shown in metadata (used for both Hq and Hkv).
    :param N: Base sequence length shown in metadata.
    :param D: Head dimension shown in metadata.
    :param tasks: Optional case-name filter; ``None`` keeps every case.
    :return: ``(forward_rows, backward_rows)``.
    """
    selected_cases = [
        case_name for case_name, _ in PLOT_CASES
        if tasks is None or case_name in tasks
    ]
    rows_by_direction: dict[str, list[RESULT_ROW]] = {
        "forward": [],
        "backward": [],
    }
    for direction, bucket in rows_by_direction.items():
        is_forward = direction == "forward"
        for case_name in selected_cases:
            nq, nkv = _case_shape(case_name, N)
            for dtype in DTYPE_ORDER:
                bucket.append({
                    "direction": direction,
                    "case_name": case_name,
                    "dtype": dtype,
                    "B": B,
                    "Hq": H,
                    "Hkv": H,
                    "Nq": nq,
                    "Nkv": nkv,
                    "D": D,
                    "allclose": None,
                    "ffpa_ms": None,
                    "sdpa_ms": None,
                    "speedup": FALLBACK_SPEEDUPS[direction][case_name][dtype],
                    # Only the direction this row belongs to is "hard-coded".
                    "forward_backend": "hard-coded" if is_forward else None,
                    "backward_backend": None if is_forward else "hard-coded",
                    "dropout_p": 0.1 if case_name == "dropout" else 0.0,
                    "causal": case_name == "causal",
                })
    return rows_by_direction["forward"], rows_by_direction["backward"]
def _aggregate_speedups(
rows: list[RESULT_ROW],
direction: str,
plot_cases: list[tuple[str, str]] | None = None,
) -> list[float] | None:
"""Aggregate per-dtype rows into the bar heights used by the plot.
Prefers the bf16 value when present so the bar chart compares against a
single canonical dtype; falls back to the max of remaining dtypes when bf16
is absent (e.g. ``--dtype fp16``).
:param rows: Structured rows for one or both directions.
:param direction: ``forward`` or ``backward``.
:param plot_cases: Ordered cases to aggregate for plotting.
:return: Aggregated bar heights, or ``None`` when the direction is absent.
"""
active_plot_cases = PLOT_CASES if plot_cases is None else plot_cases
active_case_names = {case_name for case_name, _ in active_plot_cases}
case_to_speedups: dict[str, dict[str, float]] = {
case_name: {}
for case_name, _ in active_plot_cases
}
for row in rows:
if row["direction"] != direction:
continue
if row["case_name"] not in active_case_names:
continue
case_to_speedups[row["case_name"]][row["dtype"]] = float(row["speedup"])
if not any(case_to_speedups.values()):
return None
values: list[float] = []
for case_name, _ in active_plot_cases:
dtyped = case_to_speedups[case_name]
if "bf16" in dtyped:
values.append(dtyped["bf16"])
elif dtyped:
values.append(float(np.amax(list(dtyped.values()))))
else:
values.append(float("nan"))
return values
def _aggregate_metric(
rows: list[RESULT_ROW],
direction: str,
metric_key: str,
plot_cases: list[tuple[str, str]] | None = None,
) -> list[float] | None:
"""Aggregate one numeric metric per case for plotting.
Prefers the bf16 value when present (single canonical dtype on the bar
chart); falls back to the max of remaining dtypes when bf16 is absent
(e.g. ``--dtype fp16``).
:param rows: Structured rows for one or both directions.
:param direction: ``forward`` or ``backward``.
:param metric_key: Row field containing the numeric metric.
:param plot_cases: Ordered cases to aggregate for plotting.
:return: Aggregated bar heights, or ``None`` when the direction is absent.
"""
active_plot_cases = PLOT_CASES if plot_cases is None else plot_cases
active_case_names = {case_name for case_name, _ in active_plot_cases}
case_to_values: dict[str, dict[str, float]] = {
case_name: {}
for case_name, _ in active_plot_cases
}
for row in rows:
if row["direction"] != direction:
continue
if row["case_name"] not in active_case_names:
continue
value = row.get(metric_key)
if value is None:
continue
case_to_values[row["case_name"]][row["dtype"]] = float(value)
if not any(case_to_values.values()):
return None
values: list[float] = []
for case_name, _ in active_plot_cases:
dtyped = case_to_values[case_name]
if "bf16" in dtyped:
values.append(dtyped["bf16"])
elif dtyped:
values.append(float(np.amax(list(dtyped.values()))))
else:
values.append(float("nan"))
return values
def plot_speedup(
    forward_rows: list[RESULT_ROW],
    backward_rows: list[RESULT_ROW],
    *,
    device_name: str,
    B: int,
    H: int,
    N: int,
    D: int,
    output_path: Path,
    plot_cases: list[tuple[str, str]] | None = None,
    cutedsl: bool = False,
) -> Path:
    """Render the speedup bar chart while preserving the legacy look.

    Each case gets a constant 1.0x SDPA baseline bar. Next to it go the FFPA
    forward and/or backward speedup bars for whichever directions have data.

    :param forward_rows: Forward result rows.
    :param backward_rows: Backward result rows.
    :param device_name: Device name shown in the title.
    :param B: Batch size shown in the title.
    :param H: Head count shown in the title.
    :param N: Sequence length shown in the title.
    :param D: Head dimension shown in the title.
    :param output_path: Output PNG path.
    :param plot_cases: Ordered cases to include in the plot; defaults to
        ``PLOT_CASES``.
    :param cutedsl: When ``True``, swap the title prefix to "FFPA CuTeDSL vs
        SDPA Speedup" and apply the H20Z -> H200 display rename.
    :return: Saved PNG path.
    :raises ValueError: If neither direction yields data for ``plot_cases``.
    """
    active_plot_cases = PLOT_CASES if plot_cases is None else plot_cases
    fwd_speedups = _aggregate_speedups(forward_rows, "forward", active_plot_cases)
    bwd_speedups = _aggregate_speedups(
        backward_rows, "backward", active_plot_cases
    )
    has_forward = fwd_speedups is not None
    has_backward = bwd_speedups is not None
    if not has_forward and not has_backward:
        raise ValueError("No forward or backward rows were provided for plotting.")
    attn_types = [label for _, label in active_plot_cases]
    sdpa_speedups = [1.0] * len(attn_types)
    fig, ax = plt.subplots(figsize=(32, 12))
    # Always release the figure, even if drawing or saving raises, so pyplot
    # does not keep accumulating open figures.
    try:
        width = 0.20
        x = np.arange(len(attn_types))

        def _autolabel(rects) -> None:
            # Annotate each finite bar with its "N.Nx" value. Bars below 1.0
            # use a larger offset and top-aligned text.
            for rect in rects:
                h = rect.get_height()
                if not np.isfinite(h):
                    continue
                offset = 8 if h >= 1 else 20
                va_pos = "bottom" if h >= 1 else "top"
                ax.annotate(
                    f"{h:.1f}x",
                    xy=(rect.get_x() + rect.get_width() / 2, h),
                    xytext=(0, offset),
                    textcoords="offset points",
                    ha="center",
                    va=va_pos,
                    fontsize=19,
                    fontweight="bold",
                )

        # Three bars per group when both directions exist, otherwise two.
        if has_forward and has_backward:
            x_sdpa = x - width
            x_fwd = x
            x_bwd = x + width
        elif has_forward:
            x_sdpa = x - width / 2
            x_fwd = x + width / 2
            x_bwd = None
        else:
            x_sdpa = x - width / 2
            x_fwd = None
            x_bwd = x + width / 2
        rect_sdpa = ax.bar(
            x_sdpa,
            sdpa_speedups,
            width,
            label="SDPA Baseline",
            color="#b0b0b0",
            edgecolor="white",
            linewidth=1,
        )
        _autolabel(rect_sdpa)
        # Finite bar heights drive the y-axis limit; the 1.0 baseline is
        # always included.
        finite_values = [1.0]
        if x_fwd is not None and fwd_speedups is not None:
            rect_fwd = ax.bar(
                x_fwd,
                fwd_speedups,
                width,
                label="FFPA Forward (FWD)",
                color="#2171b5",
                edgecolor="white",
                linewidth=1,
            )
            _autolabel(rect_fwd)
            finite_values.extend(
                value for value in fwd_speedups if np.isfinite(value)
            )
        if x_bwd is not None and bwd_speedups is not None:
            rect_bwd = ax.bar(
                x_bwd,
                bwd_speedups,
                width,
                label="FFPA Backward (BWD)",
                color="#fd493c",
                edgecolor="white",
                linewidth=1,
            )
            _autolabel(rect_bwd)
            finite_values.extend(
                value for value in bwd_speedups if np.isfinite(value)
            )
        ax.axhline(y=1, color="#555555", linestyle="--", linewidth=2)
        ax.set_ylabel("Speedup Ratio (FFPA / SDPA)", fontsize=18)
        title_prefix = "FFPA CuTeDSL vs SDPA Speedup" if cutedsl else "FFPA vs SDPA Speedup"
        title_device = _display_device_name(device_name) if cutedsl else device_name
        fig.suptitle(
            f"{title_prefix} ({_mode_suffix(has_forward, has_backward)}) | {title_device} | B={B}, N={N}, H={H}, D={D}",
            fontsize=22,
            fontweight="bold",
            y=0.958,
        )
        ax.set_xticks(x)
        ax.set_xticklabels(
            attn_types, rotation=0, ha="center", fontsize=22, fontweight="bold"
        )
        ax.tick_params(axis="y", labelsize=16)
        ymax = max(finite_values) if finite_values else 1.0
        # 17% headroom above the tallest bar leaves room for the labels.
        ax.set_ylim(0, ymax * 1.17 if ymax > 0 else 1.0)
        ax.legend(
            fontsize=20,
            loc="upper center",
            bbox_to_anchor=(0.5, 0.972),
            ncol=3 if has_forward and has_backward else 2,
            columnspacing=1.5,
            handletextpad=0.6,
            frameon=True,
        )
        ax.grid(axis="y", alpha=0.9)
        fig.tight_layout(rect=(0, 0, 1, 0.955))
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    return output_path
def plot_tflops(
forward_rows: list[RESULT_ROW],
backward_rows: list[RESULT_ROW],
*,
device_name: str,
B: int,
H: int,
N: int,
D: int,
output_path: Path,
plot_cases: list[tuple[str, str]] | None = None,
cutedsl: bool = False,
) -> Path | None:
"""Render the TFLOPS comparison bar chart.
:param forward_rows: Forward result rows.
:param backward_rows: Backward result rows.
:param device_name: Device name shown in the title.
:param B: Batch size shown in the title.
:param H: Head count shown in the title.
:param N: Sequence length shown in the title.
:param D: Head dimension shown in the title.
:param output_path: Output PNG path.
:param plot_cases: Ordered cases to include in the plot.
:param cutedsl: When ``True``, swap the title prefix to "FFPA CuTeDSL vs
SDPA TFLOPS" and apply the H20Z โ H200 display rename.
:return: Saved PNG path, or ``None`` when no TFLOPS data is available.
"""
active_plot_cases = TFLOPS_PLOT_CASES if plot_cases is None else plot_cases
fwd_ffpa_tflops = _aggregate_metric(
forward_rows, "forward", "ffpa_tflops", plot_cases=active_plot_cases
)
fwd_sdpa_tflops = _aggregate_metric(
forward_rows, "forward", "sdpa_tflops", plot_cases=active_plot_cases
)
bwd_ffpa_tflops = _aggregate_metric(
backward_rows, "backward", "ffpa_tflops", plot_cases=active_plot_cases
)
bwd_sdpa_tflops = _aggregate_metric(
backward_rows, "backward", "sdpa_tflops", plot_cases=active_plot_cases
)
has_forward = fwd_ffpa_tflops is not None and fwd_sdpa_tflops is not None
has_backward = bwd_ffpa_tflops is not None and bwd_sdpa_tflops is not None
if not has_forward and not has_backward:
return None
attn_types = [label for _, label in active_plot_cases]
x = np.arange(len(attn_types))
fig, ax = plt.subplots(figsize=(16, 12))
def _autolabel(rects) -> None:
for rect in rects:
h = rect.get_height()
if not np.isfinite(h):
continue
ax.annotate(
format_tflops_short(float(h)),
xy=(rect.get_x() + rect.get_width() / 2, h),
xytext=(0, 8),
textcoords="offset points",
ha="center",
va="bottom",
fontsize=19,
fontweight="bold",
)
finite_values: list[float] = []
if has_forward and has_backward:
width = 0.18
x_fwd_sdpa = x - 1.5 * width
x_fwd_ffpa = x - 0.5 * width
x_bwd_sdpa = x + 0.5 * width
x_bwd_ffpa = x + 1.5 * width