Skip to content

Commit e14290a

Browse files
lukebaumanncopybara-github
authored andcommitted
Port the collect_profile script from JAX to PathwaysUtils
PiperOrigin-RevId: 746578297
1 parent 42640da commit e14290a

3 files changed

Lines changed: 273 additions & 10 deletions

File tree

pathwaysutils/collect_profile.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Module for collecting JAX profiles for Pathways on Cloud.
15+
16+
This is a replacement for the `collect_profile` script in JAX that works with
17+
Pathways on Cloud.
18+
"""
19+
20+
import argparse
21+
import logging
22+
23+
from pathwaysutils import profiling
24+
25+
_logger = logging.getLogger(__name__)
26+
27+
28+
_DESCRIPTION = """
29+
To profile running JAX programs, you first need to start the profiler server
30+
in the program of interest. You can do this via
31+
`jax.profiler.start_server(<port>)`. Once the program is running and the
32+
profiler server has started, you can run `collect_profile` to trace the execution
33+
for a provided duration. The trace file will be dumped into a GCS bucket
34+
(determined by `--log_dir`).
35+
"""
36+
parser = argparse.ArgumentParser(description=_DESCRIPTION)
37+
parser.add_argument(
38+
"--log_dir",
39+
required=True,
40+
help="GCS path to store log files.",
41+
type=str,
42+
)
43+
parser.add_argument("port", help="Port to collect trace", type=int)
44+
parser.add_argument(
45+
"duration_ms", help="Duration to collect trace in milliseconds", type=int
46+
)
47+
parser.add_argument(
48+
"--host",
49+
default="127.0.0.1",
50+
help=(
51+
"Host to collect trace. This host IP/DNS address should be accessible"
52+
" from where this API is being called. Defaults to 127.0.0.1"
53+
),
54+
type=str,
55+
)
56+
57+
58+
def main(args):
59+
if profiling.collect_profile(
60+
args.port, args.duration_ms, args.host, args.log_dir
61+
):
62+
_logger.info("Dumped profiling information in: %s", args.log_dir)
63+
else:
64+
_logger.error("Failed to collect profiling information.")
65+
66+
if __name__ == "__main__":
67+
main(parser.parse_args())

pathwaysutils/profiling.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,28 @@
1515

1616
import dataclasses
1717
import logging
18+
import os
19+
import pathlib
20+
import tempfile
1821
import threading
1922
import time
23+
import urllib.parse
2024

2125
import fastapi
2226
import jax
2327
from jax import numpy as jnp
2428
from pathwaysutils import plugin_executable
29+
import requests
2530
import uvicorn
2631

27-
logger = logging.getLogger(__name__)
32+
33+
_logger = logging.getLogger(__name__)
2834

2935

3036
class _ProfileState:
37+
executable: plugin_executable.PluginExecutable | None = None
38+
lock: threading.Lock
39+
3140
def __init__(self):
3241
self.executable = None
3342
self.lock = threading.Lock()
@@ -88,7 +97,7 @@ def stop_trace():
8897
_original_stop_trace()
8998

9099

91-
_profiler_thread = None
100+
_profiler_thread: threading.Thread | None = None
92101

93102

94103
def start_server(port: int):
@@ -102,7 +111,7 @@ def start_server(port: int):
102111
port : The port to start the server on.
103112
"""
104113
def server_loop(port: int):
105-
logger.debug("Starting JAX profiler server on port %s", port)
114+
_logger.debug("Starting JAX profiler server on port %s", port)
106115
app = fastapi.FastAPI()
107116

108117
@dataclasses.dataclass
@@ -112,14 +121,14 @@ class ProfilingConfig:
112121

113122
@app.post("/profiling")
114123
async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable
115-
logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
116-
logger.debug("Writing profiling data to %s", pc.repository_path)
124+
_logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
125+
_logger.debug("Writing profiling data to %s", pc.repository_path)
117126
jax.profiler.start_trace(pc.repository_path)
118127
time.sleep(pc.duration_ms / 1e3)
119128
jax.profiler.stop_trace()
120129
return {"response": "profiling completed"}
121130

122-
uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
131+
uvicorn.run(app, host="0.0.0.0", port=port, log_level="debug")
123132

124133
global _profiler_thread
125134
if _profiler_thread is not None:
@@ -138,6 +147,44 @@ def stop_server():
138147
raise ValueError("No active profiler server.")
139148

140149

150+
def collect_profile(
151+
port: int,
152+
duration_ms: int,
153+
host: str,
154+
log_dir: str,
155+
) -> bool:
156+
"""Collects a JAX profile and saves it to the specified directory.
157+
158+
Args:
159+
port: The port on which the JAX profiler server is running.
160+
duration_ms: The duration in milliseconds for which to collect the profile.
161+
host: The host on which the JAX profiler server is running.
162+
log_dir: The GCS path to save the profile data.
163+
164+
Returns:
165+
True if the profile was collected successfully, False otherwise.
166+
167+
Raises:
168+
ValueError: If the log_dir is not a GCS path.
169+
"""
170+
if not log_dir.startswith("gs://"):
171+
raise ValueError("log_dir must be a GCS path.")
172+
173+
json = {
174+
"duration_ms": duration_ms,
175+
"repository_path": log_dir,
176+
}
177+
address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling")
178+
try:
179+
response = requests.post(address, json=json)
180+
response.raise_for_status()
181+
except requests.exceptions.RequestException as e:
182+
_logger.error("Failed to collect profiling data: %s", e)
183+
return False
184+
185+
return True
186+
187+
141188
def monkey_patch_jax():
142189
"""Monkey patches JAX with Pathways versions of functions.
143190
@@ -158,25 +205,27 @@ def start_trace_patch(
158205
create_perfetto_link: bool = False, # pylint: disable=unused-argument
159206
create_perfetto_trace: bool = False, # pylint: disable=unused-argument
160207
) -> None:
161-
logger.debug("jax.profile.start_trace patched with pathways' start_trace")
208+
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
162209
return start_trace(log_dir)
163210

164211
jax.profiler.start_trace = start_trace_patch
165212

166213
def stop_trace_patch() -> None:
167-
logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
214+
_logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
168215
return stop_trace()
169216

170217
jax.profiler.stop_trace = stop_trace_patch
171218

172219
def start_server_patch(port: int):
173-
logger.debug("jax.profile.start_server patched with pathways' start_server")
220+
_logger.debug(
221+
"jax.profile.start_server patched with pathways' start_server"
222+
)
174223
return start_server(port)
175224

176225
jax.profiler.start_server = start_server_patch
177226

178227
def stop_server_patch():
179-
logger.debug("jax.profile.stop_server patched with pathways' stop_server")
228+
_logger.debug("jax.profile.stop_server patched with pathways' stop_server")
180229
return stop_server()
181230

182231
jax.profiler.stop_server = stop_server_patch
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import mock
16+
17+
from pathwaysutils import profiling
18+
import requests
19+
20+
from absl.testing import absltest
21+
from absl.testing import parameterized
22+
23+
24+
class ProfilingTest(parameterized.TestCase):
25+
"""Tests for Pathways on Cloud profiling."""
26+
27+
def setUp(self):
28+
super().setUp()
29+
self.mock_post = self.enter_context(
30+
mock.patch.object(requests, "post", autospec=True)
31+
)
32+
33+
@parameterized.parameters(8000, 1234)
34+
def test_collect_profile_port(self, port):
35+
profiling.collect_profile(
36+
port=port,
37+
duration_ms=1000,
38+
host="127.0.0.1",
39+
log_dir="gs://test_bucket/test_dir",
40+
)
41+
42+
self.mock_post.assert_called_once_with(
43+
f"http://127.0.0.1:{port}/profiling",
44+
json={
45+
"duration_ms": 1000,
46+
"repository_path": "gs://test_bucket/test_dir",
47+
},
48+
)
49+
50+
@parameterized.parameters(1000, 1234)
51+
def test_collect_profile_duration_ms(self, duration_ms):
52+
profiling.collect_profile(
53+
port=8000,
54+
duration_ms=duration_ms,
55+
host="127.0.0.1",
56+
log_dir="gs://test_bucket/test_dir",
57+
)
58+
59+
self.mock_post.assert_called_once_with(
60+
"http://127.0.0.1:8000/profiling",
61+
json={
62+
"duration_ms": duration_ms,
63+
"repository_path": "gs://test_bucket/test_dir",
64+
},
65+
)
66+
67+
@parameterized.parameters("127.0.0.1", "localhost", "192.168.1.1")
68+
def test_collect_profile_host(self, host):
69+
profiling.collect_profile(
70+
port=8000,
71+
duration_ms=1000,
72+
host=host,
73+
log_dir="gs://test_bucket/test_dir",
74+
)
75+
76+
self.mock_post.assert_called_once_with(
77+
f"http://{host}:8000/profiling",
78+
json={
79+
"duration_ms": 1000,
80+
"repository_path": "gs://test_bucket/test_dir",
81+
},
82+
)
83+
84+
@parameterized.parameters(
85+
"gs://test_bucket/test_log_dir",
86+
"gs://test_bucket2",
87+
"gs://test_bucket3/test/log/dir",
88+
)
89+
def test_collect_profile_log_dir(self, log_dir):
90+
profiling.collect_profile(
91+
port=8000, duration_ms=1000, host="127.0.0.1", log_dir=log_dir
92+
)
93+
94+
self.mock_post.assert_called_once_with(
95+
"http://127.0.0.1:8000/profiling",
96+
json={
97+
"duration_ms": 1000,
98+
"repository_path": log_dir,
99+
},
100+
)
101+
102+
@parameterized.parameters("/logs/test_log_dir", "relative_path/my_log_dir")
103+
def test_collect_profile_log_dir_error(self, log_dir):
104+
with self.assertRaises(ValueError):
105+
profiling.collect_profile(
106+
port=8000, duration_ms=1000, host="127.0.0.1", log_dir=log_dir
107+
)
108+
109+
@parameterized.parameters(
110+
requests.exceptions.ConnectionError,
111+
requests.exceptions.Timeout,
112+
requests.exceptions.TooManyRedirects,
113+
requests.exceptions.RequestException,
114+
requests.exceptions.HTTPError,
115+
)
116+
def test_collect_profile_request_error(self, exception_type):
117+
self.mock_post.side_effect = exception_type
118+
119+
result = profiling.collect_profile(
120+
port=8000,
121+
duration_ms=1000,
122+
host="127.0.0.1",
123+
log_dir="gs://test_bucket/test_dir",
124+
)
125+
126+
self.assertFalse(result)
127+
self.mock_post.assert_called_once()
128+
129+
def test_collect_profile_success(self):
130+
mock_response = mock.Mock()
131+
mock_response.raise_for_status.return_value = None
132+
self.mock_post.return_value = mock_response
133+
134+
result = profiling.collect_profile(
135+
port=8000,
136+
duration_ms=1000,
137+
host="127.0.0.1",
138+
log_dir="gs://test_bucket/test_dir",
139+
)
140+
141+
self.assertTrue(result)
142+
self.mock_post.assert_called_once()
143+
mock_response.raise_for_status.assert_called_once()
144+
145+
146+
if __name__ == "__main__":
147+
absltest.main()

0 commit comments

Comments
 (0)