Skip to content

Commit 1584abe

Browse files
lukebaumanncopybara-github
authored andcommitted
Small update to collect_profile so that it can be added as a script in pyproject.toml
PiperOrigin-RevId: 748425772
1 parent 92cc542 commit 1584abe

1 file changed

Lines changed: 32 additions & 22 deletions

File tree

pathwaysutils/collect_profile.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,35 +33,45 @@
3333
for a provided duration. The trace file will be dumped into a GCS bucket
3434
(determined by `--log_dir`).
3535
"""
36-
parser = argparse.ArgumentParser(description=_DESCRIPTION)
37-
parser.add_argument(
38-
"--log_dir",
39-
required=True,
40-
help="GCS path to store log files.",
41-
type=str,
42-
)
43-
parser.add_argument("port", help="Port to collect trace", type=int)
44-
parser.add_argument(
45-
"duration_ms", help="Duration to collect trace in milliseconds", type=int
46-
)
47-
parser.add_argument(
48-
"--host",
49-
default="127.0.0.1",
50-
help=(
51-
"Host to collect trace. This host IP/DNS address should be accessible"
52-
" from where this API is being called. Defaults to 127.0.0.1"
53-
),
54-
type=str,
55-
)
5636

5737

58-
def main(args):
38+
def _get_parser():
39+
"""Returns an argument parser for the collect_profile script."""
40+
parser = argparse.ArgumentParser(description=_DESCRIPTION)
41+
parser.add_argument(
42+
"--log_dir",
43+
required=True,
44+
help="GCS path to store log files.",
45+
type=str,
46+
)
47+
parser.add_argument("port", help="Port to collect trace", type=int)
48+
parser.add_argument(
49+
"duration_ms", help="Duration to collect trace in milliseconds", type=int
50+
)
51+
parser.add_argument(
52+
"--host",
53+
default="127.0.0.1",
54+
help=(
55+
"Host to collect trace. This host IP/DNS address should be accessible"
56+
" from where this API is being called. Defaults to 127.0.0.1"
57+
),
58+
type=str,
59+
)
60+
61+
return parser
62+
63+
64+
def main():
65+
parser = _get_parser()
66+
args = parser.parse_args()
67+
5968
if profiling.collect_profile(
6069
args.port, args.duration_ms, args.host, args.log_dir
6170
):
6271
_logger.info("Dumped profiling information in: %s", args.log_dir)
6372
else:
6473
_logger.error("Failed to collect profiling information.")
6574

75+
6676
if __name__ == "__main__":
67-
main(parser.parse_args())
77+
main()

0 commit comments

Comments
 (0)