diff --git a/pathwaysutils/collect_profile.py b/pathwaysutils/collect_profile.py new file mode 100644 index 0000000..18ecf57 --- /dev/null +++ b/pathwaysutils/collect_profile.py @@ -0,0 +1,67 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for collecting JAX profiles for Pathways on Cloud. + +This is a replacement for the `collect_profile` script in JAX that works with +Pathways on Cloud. +""" + +import argparse +import logging + +from pathwaysutils import profiling + +_logger = logging.getLogger(__name__) + + +_DESCRIPTION = """ +To profile running JAX programs, you first need to start the profiler server +in the program of interest. You can do this via +`jax.profiler.start_server()`. Once the program is running and the +profiler server has started, you can run `collect_profile` to trace the execution +for a provided duration. The trace file will be dumped into a GCS bucket +(determined by `--log_dir`). +""" +parser = argparse.ArgumentParser(description=_DESCRIPTION) +parser.add_argument( + "--log_dir", + required=True, + help="GCS path to store log files.", + type=str, +) +parser.add_argument("port", help="Port to collect trace", type=int) +parser.add_argument( + "duration_ms", help="Duration to collect trace in milliseconds", type=int +) +parser.add_argument( + "--host", + default="127.0.0.1", + help=( + "Host to collect trace. This host IP/DNS address should be accessible" + " from where this API is being called. Defaults to 127.0.0.1" + ), + type=str, +) + + +def main(args): + if profiling.collect_profile( + args.port, args.duration_ms, args.host, args.log_dir + ): + _logger.info("Dumped profiling information in: %s", args.log_dir) + else: + _logger.error("Failed to collect profiling information.") + +if __name__ == "__main__": + main(parser.parse_args()) diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index e8ccbd7..1fc705f 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -15,19 +15,28 @@ import dataclasses import logging +import os +import pathlib +import tempfile import threading import time +import urllib.parse import fastapi import jax from jax import numpy as jnp from pathwaysutils import plugin_executable +import requests import uvicorn -logger = logging.getLogger(__name__) + +_logger = logging.getLogger(__name__) class _ProfileState: + executable: plugin_executable.PluginExecutable | None = None + lock: threading.Lock + def __init__(self): self.executable = None self.lock = threading.Lock() @@ -88,7 +97,7 @@ def stop_trace(): _original_stop_trace() -_profiler_thread = None +_profiler_thread: threading.Thread | None = None def start_server(port: int): @@ -102,7 +111,7 @@ def start_server(port: int): port : The port to start the server on. """ def server_loop(port: int): - logger.debug("Starting JAX profiler server on port %s", port) + _logger.debug("Starting JAX profiler server on port %s", port) app = fastapi.FastAPI() @dataclasses.dataclass @@ -112,14 +121,14 @@ class ProfilingConfig: @app.post("/profiling") async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable - logger.debug("Capturing profiling data for %s ms", pc.duration_ms) - logger.debug("Writing profiling data to %s", pc.repository_path) + _logger.debug("Capturing profiling data for %s ms", pc.duration_ms) + _logger.debug("Writing profiling data to %s", pc.repository_path) jax.profiler.start_trace(pc.repository_path) time.sleep(pc.duration_ms / 1e3) jax.profiler.stop_trace() return {"response": "profiling completed"} - uvicorn.run(app, host="0.0.0.0", port=port, log_level="info") + uvicorn.run(app, host="0.0.0.0", port=port, log_level="debug") global _profiler_thread if _profiler_thread is not None: @@ -138,6 +147,44 @@ def stop_server(): raise ValueError("No active profiler server.") +def collect_profile( + port: int, + duration_ms: int, + host: str, + log_dir: str, +) -> bool: + """Collects a JAX profile and saves it to the specified directory. + + Args: + port: The port on which the JAX profiler server is running. + duration_ms: The duration in milliseconds for which to collect the profile. + host: The host on which the JAX profiler server is running. + log_dir: The GCS path to save the profile data. + + Returns: + True if the profile was collected successfully, False otherwise. + + Raises: + ValueError: If the log_dir is not a GCS path. + """ + if not log_dir.startswith("gs://"): + raise ValueError("log_dir must be a GCS path.") + + json = { + "duration_ms": duration_ms, + "repository_path": log_dir, + } + address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling") + try: + response = requests.post(address, json=json) + response.raise_for_status() + except requests.exceptions.RequestException as e: + _logger.error("Failed to collect profiling data: %s", e) + return False + + return True + + def monkey_patch_jax(): """Monkey patches JAX with Pathways versions of functions. @@ -158,25 +205,27 @@ def start_trace_patch( create_perfetto_link: bool = False, # pylint: disable=unused-argument create_perfetto_trace: bool = False, # pylint: disable=unused-argument ) -> None: - logger.debug("jax.profile.start_trace patched with pathways' start_trace") + _logger.debug("jax.profile.start_trace patched with pathways' start_trace") return start_trace(log_dir) jax.profiler.start_trace = start_trace_patch def stop_trace_patch() -> None: - logger.debug("jax.profile.stop_trace patched with pathways' stop_trace") + _logger.debug("jax.profile.stop_trace patched with pathways' stop_trace") return stop_trace() jax.profiler.stop_trace = stop_trace_patch def start_server_patch(port: int): - logger.debug("jax.profile.start_server patched with pathways' start_server") + _logger.debug( + "jax.profile.start_server patched with pathways' start_server" + ) return start_server(port) jax.profiler.start_server = start_server_patch def stop_server_patch(): - logger.debug("jax.profile.stop_server patched with pathways' stop_server") + _logger.debug("jax.profile.stop_server patched with pathways' stop_server") return stop_server() jax.profiler.stop_server = stop_server_patch diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py new file mode 100644 index 0000000..26a24d9 --- /dev/null +++ b/pathwaysutils/test/profiling_test.py @@ -0,0 +1,147 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from pathwaysutils import profiling +import requests + +from absl.testing import absltest +from absl.testing import parameterized + + +class ProfilingTest(parameterized.TestCase): + """Tests for Pathways on Cloud profiling.""" + + def setUp(self): + super().setUp() + self.mock_post = self.enter_context( + mock.patch.object(requests, "post", autospec=True) + ) + + @parameterized.parameters(8000, 1234) + def test_collect_profile_port(self, port): + profiling.collect_profile( + port=port, + duration_ms=1000, + host="127.0.0.1", + log_dir="gs://test_bucket/test_dir", + ) + + self.mock_post.assert_called_once_with( + f"http://127.0.0.1:{port}/profiling", + json={ + "duration_ms": 1000, + "repository_path": "gs://test_bucket/test_dir", + }, + ) + + @parameterized.parameters(1000, 1234) + def test_collect_profile_duration_ms(self, duration_ms): + profiling.collect_profile( + port=8000, + duration_ms=duration_ms, + host="127.0.0.1", + log_dir="gs://test_bucket/test_dir", + ) + + self.mock_post.assert_called_once_with( + "http://127.0.0.1:8000/profiling", + json={ + "duration_ms": duration_ms, + "repository_path": "gs://test_bucket/test_dir", + }, + ) + + @parameterized.parameters("127.0.0.1", "localhost", "192.168.1.1") + def test_collect_profile_host(self, host): + profiling.collect_profile( + port=8000, + duration_ms=1000, + host=host, + log_dir="gs://test_bucket/test_dir", + ) + + self.mock_post.assert_called_once_with( + f"http://{host}:8000/profiling", + json={ + "duration_ms": 1000, + "repository_path": "gs://test_bucket/test_dir", + }, + ) + + @parameterized.parameters( + "gs://test_bucket/test_log_dir", + "gs://test_bucket2", + "gs://test_bucket3/test/log/dir", + ) + def test_collect_profile_log_dir(self, log_dir): + profiling.collect_profile( + port=8000, duration_ms=1000, host="127.0.0.1", log_dir=log_dir + ) + + self.mock_post.assert_called_once_with( + "http://127.0.0.1:8000/profiling", + json={ + "duration_ms": 1000, + "repository_path": log_dir, + }, + ) + + @parameterized.parameters("/logs/test_log_dir", "relative_path/my_log_dir") + def test_collect_profile_log_dir_error(self, log_dir): + with self.assertRaises(ValueError): + profiling.collect_profile( + port=8000, duration_ms=1000, host="127.0.0.1", log_dir=log_dir + ) + + @parameterized.parameters( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.TooManyRedirects, + requests.exceptions.RequestException, + requests.exceptions.HTTPError, + ) + def test_collect_profile_request_error(self, exception_type): + self.mock_post.side_effect = exception_type + + result = profiling.collect_profile( + port=8000, + duration_ms=1000, + host="127.0.0.1", + log_dir="gs://test_bucket/test_dir", + ) + + self.assertFalse(result) + self.mock_post.assert_called_once() + + def test_collect_profile_success(self): + mock_response = mock.Mock() + mock_response.raise_for_status.return_value = None + self.mock_post.return_value = mock_response + + result = profiling.collect_profile( + port=8000, + duration_ms=1000, + host="127.0.0.1", + log_dir="gs://test_bucket/test_dir", + ) + + self.assertTrue(result) + self.mock_post.assert_called_once() + mock_response.raise_for_status.assert_called_once() + + +if __name__ == "__main__": + absltest.main()