Skip to content

Commit 7491220

Browse files
committed
gh-148225: Validate profiling.sampling replay input
1 parent acf5229 commit 7491220

File tree

3 files changed

+132
-40
lines changed

3 files changed

+132
-40
lines changed

Lib/profiling/sampling/cli.py

Lines changed: 91 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def __call__(self, parser, namespace, values, option_string=None):
8686
_PROCESS_KILL_TIMEOUT_SEC = 2.0
8787
_READY_MESSAGE = b"ready"
8888
_RECV_BUFFER_SIZE = 1024
89+
_BINARY_PROFILE_HEADER_SIZE = 64
90+
_BINARY_PROFILE_MAGICS = (b"HCAT", b"TACH")
8991

9092
# Format configuration
9193
FORMAT_EXTENSIONS = {
@@ -650,6 +652,89 @@ def _open_in_browser(path):
650652
print(f"Warning: Could not open browser: {e}", file=sys.stderr)
651653

652654

655+
def _validate_replay_input_file(filename):
    """Exit with an error unless *filename* starts like a binary profile.

    The first ``_BINARY_PROFILE_HEADER_SIZE`` bytes of the file are read and
    accepted only when the header is complete and its first four bytes are
    one of ``_BINARY_PROFILE_MAGICS``.

    Returns ``None`` when the header is acceptable.  If the file cannot be
    read, or the header check fails, the process is terminated through
    ``sys.exit`` with an explanatory message.
    """
    try:
        with open(filename, "rb") as profile:
            prefix = profile.read(_BINARY_PROFILE_HEADER_SIZE)
    except OSError as exc:
        sys.exit(f"Error: Could not read input file {filename}: {exc}")

    # A short read means the file is smaller than a complete header.
    header_complete = len(prefix) >= _BINARY_PROFILE_HEADER_SIZE
    if header_complete and prefix[:4] in _BINARY_PROFILE_MAGICS:
        return

    sys.exit(
        "Error: Input file is not a binary sampling profile. "
        "The replay command only accepts files created with --binary; "
        "the default --pstats output written by -o cannot be replayed."
    )
672+
673+
674+
def _replay_with_reader(args, reader):
    """Run every sample from *reader* through a collector and emit output.

    Prints a short summary of the recording (sample count, sample interval,
    compression), replays the samples into a collector created for
    ``args.format`` while drawing a progress bar, and then either exports
    the result or prints statistics:

    * ``pstats`` with ``args.outfile`` set: export to that file.
    * ``pstats`` without an output file: print stats, defaulting to the
      ``"nsamples"`` sort and a limit of 15 when those options are unset.
    * any other format: export to ``args.outfile`` or a generated filename,
      then open it in a browser for the HTML formats when ``args.browser``
      is set.
    """
    metadata = reader.get_info()
    sample_interval = metadata['sample_interval_us']

    print(f"Replaying {metadata['sample_count']} samples from {args.input_file}")
    print(f" Sample interval: {sample_interval} us")
    compression = 'zstd' if metadata.get('compression_type', 0) == 1 else 'none'
    print(f" Compression: {compression}")

    collector = _create_collector(
        args.format,
        sample_interval,
        skip_idle=False,
        diff_baseline=args.diff_baseline,
    )

    def report_progress(current, total):
        # Nothing to draw when the total is not positive.
        if not (total > 0):
            return
        fraction = current / total
        width = 40
        done = int(width * fraction)
        bar = '█' * done + '░' * (width - done)
        print(
            f"\r [{bar}] {fraction*100:5.1f}% ({current:,}/{total:,})",
            end="",
            flush=True,
        )

    replayed = reader.replay_samples(collector, report_progress)
    # The progress bar is printed without a newline; finish that line.
    print()

    if args.format != "pstats":
        destination = (
            args.outfile
            or _generate_output_filename(args.format, os.getpid())
        )
        collector.export(destination)

        # Auto-open browser for HTML output if --browser flag is set
        html_formats = ('flamegraph', 'diff_flamegraph', 'heatmap')
        if args.format in html_formats and getattr(args, 'browser', False):
            _open_in_browser(destination)
    elif args.outfile:
        collector.export(args.outfile)
    else:
        sort_choice = "nsamples" if args.sort is None else args.sort
        limit = 15 if args.limit is None else args.limit
        sort_mode = _sort_to_mode(sort_choice)
        collector.print_stats(
            sort_mode,
            limit,
            not args.no_summary,
            PROFILING_MODE_WALL,
        )

    print(f"Replayed {replayed} samples")
737+
653738
def _handle_output(collector, args, pid, mode):
654739
"""Handle output for the collector based on format and arguments.
655740
@@ -1201,47 +1286,13 @@ def _handle_replay(args):
12011286
if not os.path.exists(args.input_file):
12021287
sys.exit(f"Error: Input file not found: {args.input_file}")
12031288

1204-
with BinaryReader(args.input_file) as reader:
1205-
info = reader.get_info()
1206-
interval = info['sample_interval_us']
1289+
_validate_replay_input_file(args.input_file)
12071290

1208-
print(f"Replaying {info['sample_count']} samples from {args.input_file}")
1209-
print(f" Sample interval: {interval} us")
1210-
print(f" Compression: {'zstd' if info.get('compression_type', 0) == 1 else 'none'}")
1211-
1212-
collector = _create_collector(
1213-
args.format, interval, skip_idle=False,
1214-
diff_baseline=args.diff_baseline
1215-
)
1216-
1217-
def progress_callback(current, total):
1218-
if total > 0:
1219-
pct = current / total
1220-
bar_width = 40
1221-
filled = int(bar_width * pct)
1222-
bar = '█' * filled + '░' * (bar_width - filled)
1223-
print(f"\r [{bar}] {pct*100:5.1f}% ({current:,}/{total:,})", end="", flush=True)
1224-
1225-
count = reader.replay_samples(collector, progress_callback)
1226-
print()
1227-
1228-
if args.format == "pstats":
1229-
if args.outfile:
1230-
collector.export(args.outfile)
1231-
else:
1232-
sort_choice = args.sort if args.sort is not None else "nsamples"
1233-
limit = args.limit if args.limit is not None else 15
1234-
sort_mode = _sort_to_mode(sort_choice)
1235-
collector.print_stats(sort_mode, limit, not args.no_summary, PROFILING_MODE_WALL)
1236-
else:
1237-
filename = args.outfile or _generate_output_filename(args.format, os.getpid())
1238-
collector.export(filename)
1239-
1240-
# Auto-open browser for HTML output if --browser flag is set
1241-
if args.format in ('flamegraph', 'diff_flamegraph', 'heatmap') and getattr(args, 'browser', False):
1242-
_open_in_browser(filename)
1243-
1244-
print(f"Replayed {count} samples")
1291+
try:
1292+
with BinaryReader(args.input_file) as reader:
1293+
_replay_with_reader(args, reader)
1294+
except (OSError, ValueError) as exc:
1295+
sys.exit(f"Error: {exc}")
12451296

12461297

12471298
if __name__ == "__main__":

Lib/test/test_profiling/test_sampling_profiler/test_cli.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Tests for sampling profiler CLI argument parsing and functionality."""
22

33
import io
4+
import os
45
import subprocess
56
import sys
7+
import tempfile
68
import unittest
79
from unittest import mock
810

@@ -722,3 +724,39 @@ def test_cli_attach_nonexistent_pid(self):
722724
main()
723725

724726
self.assertIn(fake_pid, str(cm.exception))
727+
728+
def test_cli_replay_rejects_non_binary_profile(self):
    """``replay`` must refuse input that is not a binary sampling profile.

    Two malformed inputs are exercised, one per rejection condition in the
    CLI's header validation: a file shorter than the 64-byte header, and a
    full-length file whose first four bytes are not a known magic.  Both
    must exit with the same explanatory message.
    """
    cases = (
        ("truncated header", b"not a binary sampling profile"),
        # Exactly header-sized, so only the magic check can reject it.
        ("unknown magic", b"NOPE" + (b"\0" * 60)),
    )
    for label, payload in cases:
        with (
            self.subTest(label),
            tempfile.TemporaryDirectory() as tempdir,
        ):
            profile = os.path.join(tempdir, "output.prof")
            with open(profile, "wb") as file:
                file.write(payload)

            argv = ["profiling.sampling.cli", "replay", profile]
            with (
                mock.patch("sys.argv", argv),
                self.assertRaises(SystemExit) as cm,
            ):
                main()

            error = str(cm.exception)
            self.assertIn("not a binary sampling profile", error)
            self.assertIn("--binary", error)
            self.assertIn("--pstats", error)
743+
def test_cli_replay_reader_errors_exit_cleanly(self):
    """A ``ValueError`` from ``BinaryReader`` becomes a clean ``SystemExit``.

    The file written here has a recognised magic and a full 64-byte header,
    so it passes the CLI's up-front header check; ``BinaryReader`` is
    patched to raise, and the exit message must be the error text prefixed
    with ``"Error: "``.
    """
    with tempfile.TemporaryDirectory() as workdir:
        profile_path = os.path.join(workdir, "output.bin")
        with open(profile_path, "wb") as stream:
            stream.write(b"HCAT" + (b"\0" * 60))

        argv = ["profiling.sampling.cli", "replay", profile_path]
        failing_reader = mock.patch(
            "profiling.sampling.cli.BinaryReader",
            side_effect=ValueError("Unsupported format version 2"),
        )
        with (
            mock.patch("sys.argv", argv),
            failing_reader,
            self.assertRaises(SystemExit) as caught,
        ):
            main()

        self.assertEqual(
            str(caught.exception),
            "Error: Unsupported format version 2",
        )
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
The :mod:`profiling.sampling` ``replay`` command now rejects non-binary
2+
profile files with a clear error explaining that replay only accepts files
3+
created with ``--binary``.

0 commit comments

Comments
 (0)