Skip to content

Commit dcad753

Browse files
lukebaumann and copybara-github
authored and committed
Port the collect_profile script from JAX to PathwaysUtils
PiperOrigin-RevId: 746578297
1 parent 42640da commit dcad753

3 files changed

Lines changed: 244 additions & 9 deletions

File tree

pathwaysutils/collect_profile.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Module for collecting JAX profiles for Pathways on Cloud.
15+
16+
This is a replacement for the `collect_profile` script in JAX that works with
17+
Pathways on Cloud.
18+
"""
19+
20+
import argparse
21+
import logging
22+
23+
from pathwaysutils import profiling
24+
25+
_logger = logging.getLogger(__name__)


# Text shown by `--help`. It describes the two-step workflow: start the
# profiler server inside the program, then run this script against it.
_DESCRIPTION = """
To profile running JAX programs, you first need to start the profiler server
in the program of interest. You can do this via
`jax.profiler.start_server(<port>)`. Once the program is running and the
profiler server has started, you can run `collect_profile` to trace the execution
for a provided duration. The trace file will be dumped into a directory
(determined by `--log_dir`).
"""

# Command-line interface:
#   collect_profile [--log_dir DIR] [--host HOST] port duration_ms
parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument(
    "--log_dir",
    default=None,
    help=(
        "Directory to store log files. "
        "Uses a temporary directory if none provided."
    ),
    type=str,
)
parser.add_argument("port", help="Port to collect trace", type=int)
parser.add_argument(
    "duration_ms", help="Duration to collect trace in milliseconds", type=int
)
parser.add_argument(
    "--host",
    default="127.0.0.1",
    help="Host to collect trace. Defaults to 127.0.0.1",
    type=str,
)
57+
58+
59+
def main(args):
  """Requests a profile using the values from the parsed command line.

  Args:
    args: Parsed argument namespace providing `port`, `duration_ms`, `host`
      and `log_dir`.
  """
  profiling.collect_profile(
      port=args.port,
      duration_ms=args.duration_ms,
      host=args.host,
      log_dir=args.log_dir,
  )
63+
64+
65+
if __name__ == "__main__":
  # Parse the command line only when run as a script, never on import.
  cli_args = parser.parse_args()
  main(cli_args)

pathwaysutils/profiling.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,28 @@
1515

1616
import dataclasses
1717
import logging
18+
import os
19+
import pathlib
20+
import tempfile
1821
import threading
1922
import time
23+
import urllib.parse
2024

2125
import fastapi
2226
import jax
2327
from jax import numpy as jnp
2428
from pathwaysutils import plugin_executable
29+
import requests
2530
import uvicorn
2631

27-
logger = logging.getLogger(__name__)
32+
33+
_logger = logging.getLogger(__name__)
2834

2935

3036
class _ProfileState:
37+
executable: plugin_executable.PluginExecutable | None = None
38+
lock: threading.Lock
39+
3140
def __init__(self):
3241
self.executable = None
3342
self.lock = threading.Lock()
@@ -88,7 +97,7 @@ def stop_trace():
8897
_original_stop_trace()
8998

9099

91-
_profiler_thread = None
100+
_profiler_thread: threading.Thread | None = None
92101

93102

94103
def start_server(port: int):
@@ -102,7 +111,7 @@ def start_server(port: int):
102111
port : The port to start the server on.
103112
"""
104113
def server_loop(port: int):
105-
logger.debug("Starting JAX profiler server on port %s", port)
114+
_logger.debug("Starting JAX profiler server on port %s", port)
106115
app = fastapi.FastAPI()
107116

108117
@dataclasses.dataclass
@@ -112,8 +121,8 @@ class ProfilingConfig:
112121

113122
@app.post("/profiling")
114123
async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable
115-
logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
116-
logger.debug("Writing profiling data to %s", pc.repository_path)
124+
_logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
125+
_logger.debug("Writing profiling data to %s", pc.repository_path)
117126
jax.profiler.start_trace(pc.repository_path)
118127
time.sleep(pc.duration_ms / 1e3)
119128
jax.profiler.stop_trace()
@@ -138,6 +147,39 @@ def stop_server():
138147
raise ValueError("No active profiler server.")
139148

140149

150+
def collect_profile(
    port: int,
    duration_ms: int,
    host: str,
    log_dir: os.PathLike[str] | str | None,
) -> None:
  """Collects a JAX profile and saves it to the specified directory.

  Sends an HTTP POST to `http://<host>:<port>/profiling` with a JSON body
  holding the trace duration and the destination directory.

  Args:
    port: The port on which the JAX profiler server is running.
    duration_ms: The duration in milliseconds for which to collect the profile.
    host: The host on which the JAX profiler server is running.
    log_dir: The directory to save the profile data. A `gs://` string is
      accepted. If None, a temporary directory is created with
      `tempfile.mkdtemp()` and used.

  Raises:
    requests.HTTPError: If the server responds with a 4xx or 5xx status.
  """
  if log_dir is None:
    log_dir = tempfile.mkdtemp()

  # pathlib collapses the "//" of a URL scheme, so the "gs://" prefix is taken
  # off before normalizing the path and put back afterwards.
  gcs_scheme = "gs://"
  prefix = ""
  if isinstance(log_dir, str) and log_dir.startswith(gcs_scheme):
    prefix = gcs_scheme
    log_dir = log_dir.removeprefix(gcs_scheme)
  repository_path = prefix + str(pathlib.Path(log_dir))

  payload = {
      "duration_ms": duration_ms,
      "repository_path": repository_path,
  }
  address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling")
  response = requests.post(address, json=payload)
  # Without this check a rejected or failed request would still be reported
  # below as a successful dump.
  response.raise_for_status()
  _logger.info("Dumped profiling information in: %s", repository_path)
181+
182+
141183
def monkey_patch_jax():
142184
"""Monkey patches JAX with Pathways versions of functions.
143185
@@ -158,25 +200,27 @@ def start_trace_patch(
158200
create_perfetto_link: bool = False, # pylint: disable=unused-argument
159201
create_perfetto_trace: bool = False, # pylint: disable=unused-argument
160202
) -> None:
161-
logger.debug("jax.profile.start_trace patched with pathways' start_trace")
203+
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
162204
return start_trace(log_dir)
163205

164206
jax.profiler.start_trace = start_trace_patch
165207

166208
def stop_trace_patch() -> None:
167-
logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
209+
_logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
168210
return stop_trace()
169211

170212
jax.profiler.stop_trace = stop_trace_patch
171213

172214
def start_server_patch(port: int):
173-
logger.debug("jax.profile.start_server patched with pathways' start_server")
215+
_logger.debug(
216+
"jax.profile.start_server patched with pathways' start_server"
217+
)
174218
return start_server(port)
175219

176220
jax.profiler.start_server = start_server_patch
177221

178222
def stop_server_patch():
179-
logger.debug("jax.profile.stop_server patched with pathways' stop_server")
223+
_logger.debug("jax.profile.stop_server patched with pathways' stop_server")
180224
return stop_server()
181225

182226
jax.profiler.stop_server = stop_server_patch
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pathlib
16+
from unittest import mock
17+
18+
from pathwaysutils import profiling
19+
import requests
20+
21+
from absl.testing import absltest
22+
from absl.testing import parameterized
23+
24+
25+
class ProfilingTest(parameterized.TestCase):
  """Tests for Pathways on Cloud profiling.

  `requests.post` is replaced by a mock in `setUp`, so no HTTP request is sent.
  Each test checks the URL and JSON payload that `collect_profile` would post.
  """

  def setUp(self):
    super().setUp()
    self.mock_post = self.enter_context(
        mock.patch.object(requests, "post", autospec=True)
    )

  @parameterized.parameters(
      (8000,),
      (1234,),
  )
  def test_collect_profile_port(self, port):
    """The port appears in the request URL."""
    profiling.collect_profile(
        port=port,
        duration_ms=1000,
        host="127.0.0.1",
        log_dir="/tmp/test_log_dir",
    )

    self.mock_post.assert_called_once_with(
        f"http://127.0.0.1:{port}/profiling",
        json={
            "duration_ms": 1000,
            "repository_path": "/tmp/test_log_dir",
        },
    )

  @parameterized.parameters(
      (1000,),
      (1234,),
  )
  def test_collect_profile_duration_ms(self, duration_ms):
    """The duration is forwarded unchanged in the JSON payload."""
    profiling.collect_profile(
        port=8000,
        duration_ms=duration_ms,
        host="127.0.0.1",
        log_dir="/tmp/test_log_dir",
    )

    self.mock_post.assert_called_once_with(
        "http://127.0.0.1:8000/profiling",
        json={
            "duration_ms": duration_ms,
            "repository_path": "/tmp/test_log_dir",
        },
    )

  @parameterized.parameters(
      ("127.0.0.1",),
      ("localhost",),
      ("192.168.1.1",),
  )
  def test_collect_profile_host(self, host):
    """The host appears in the request URL."""
    profiling.collect_profile(
        port=8000, duration_ms=1000, host=host, log_dir="/tmp/test_log_dir"
    )

    self.mock_post.assert_called_once_with(
        f"http://{host}:8000/profiling",
        json={
            "duration_ms": 1000,
            "repository_path": "/tmp/test_log_dir",
        },
    )

  @parameterized.parameters(
      ("gs://test_bucket/test_log_dir",),
      ("/logs/test_log_dir",),
      ("relative_path/my_log_dir",),
      (pathlib.Path("/tmp/test_log_dir"),),
  )
  def test_collect_profile_log_dir(self, log_dir):
    """GCS, absolute, relative and PathLike directories are sent as strings."""
    profiling.collect_profile(
        port=8000, duration_ms=1000, host="127.0.0.1", log_dir=log_dir
    )

    self.mock_post.assert_called_once_with(
        "http://127.0.0.1:8000/profiling",
        json={
            "duration_ms": 1000,
            "repository_path": str(log_dir),
        },
    )

  def test_collect_profile_no_log_dir(self):
    """A temporary directory is used when no log_dir is given."""
    port = 8000
    duration_ms = 1000
    host = "127.0.0.1"
    fake_tmp_dir = "/tmp/fake_profile_dir"
    # Mocking mkdtemp keeps the test independent of TMPDIR and stops it from
    # leaving a real directory behind on every run.
    mock_mkdtemp = self.enter_context(
        mock.patch.object(
            profiling.tempfile,
            "mkdtemp",
            autospec=True,
            return_value=fake_tmp_dir,
        )
    )

    profiling.collect_profile(port, duration_ms, host, None)

    mock_mkdtemp.assert_called_once_with()
    self.mock_post.assert_called_once_with(
        f"http://{host}:{port}/profiling",
        json={
            "duration_ms": duration_ms,
            "repository_path": fake_tmp_dir,
        },
    )
122+
123+
124+
if __name__ == "__main__":
125+
absltest.main()

0 commit comments

Comments
 (0)