Skip to content

Commit 4ea9e13

Browse files
lukebaumanncopybara-github
authored andcommitted
Small update to collect_profile so that it can be added as a script in pyproject.toml
PiperOrigin-RevId: 748425772
1 parent 92cc542 commit 4ea9e13

1 file changed

Lines changed: 33 additions & 22 deletions

File tree

pathwaysutils/collect_profile.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pathwaysutils import profiling
2424

2525
_logger = logging.getLogger(__name__)
26+
_logger.setLevel(logging.INFO)
2627

2728

2829
_DESCRIPTION = """
@@ -33,35 +34,45 @@
3334
for a provided duration. The trace file will be dumped into a GCS bucket
3435
(determined by `--log_dir`).
3536
"""
36-
parser = argparse.ArgumentParser(description=_DESCRIPTION)
37-
parser.add_argument(
38-
"--log_dir",
39-
required=True,
40-
help="GCS path to store log files.",
41-
type=str,
42-
)
43-
parser.add_argument("port", help="Port to collect trace", type=int)
44-
parser.add_argument(
45-
"duration_ms", help="Duration to collect trace in milliseconds", type=int
46-
)
47-
parser.add_argument(
48-
"--host",
49-
default="127.0.0.1",
50-
help=(
51-
"Host to collect trace. This host IP/DNS address should be accessible"
52-
" from where this API is being called. Defaults to 127.0.0.1"
53-
),
54-
type=str,
55-
)
5637

5738

58-
def main(args):
39+
def _get_parser():
40+
"""Returns an argument parser for the collect_profile script."""
41+
parser = argparse.ArgumentParser(description=_DESCRIPTION)
42+
parser.add_argument(
43+
"--log_dir",
44+
required=True,
45+
help="GCS path to store log files.",
46+
type=str,
47+
)
48+
parser.add_argument("port", help="Port to collect trace", type=int)
49+
parser.add_argument(
50+
"duration_ms", help="Duration to collect trace in milliseconds", type=int
51+
)
52+
parser.add_argument(
53+
"--host",
54+
default="127.0.0.1",
55+
help=(
56+
"Host to collect trace. This host IP/DNS address should be accessible"
57+
" from where this API is being called. Defaults to 127.0.0.1"
58+
),
59+
type=str,
60+
)
61+
62+
return parser
63+
64+
65+
def main():
66+
parser = _get_parser()
67+
args = parser.parse_args()
68+
5969
if profiling.collect_profile(
6070
args.port, args.duration_ms, args.host, args.log_dir
6171
):
6272
_logger.info("Dumped profiling information in: %s", args.log_dir)
6373
else:
6474
_logger.error("Failed to collect profiling information.")
6575

76+
6677
if __name__ == "__main__":
67-
main(parser.parse_args())
78+
main()

0 commit comments

Comments
 (0)