Skip to content

Commit eaff354

Browse files
authored
Merge branch 'main' into elastic-readme
2 parents b8c1471 + 4dd4f54 commit eaff354

4 files changed

Lines changed: 37 additions & 20 deletions

File tree

CHANGELOG.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
## [0.1.1] - 2025-04-25
27+
28+
* Port the `collect_profile` script from JAX to PathwaysUtils
29+
* Remove support for legacy initialize
30+
* Add collect_profile as a script of pathwaysutils
31+
* Make CloudPathwaysArrayHandler compatible with async directory creation feature in orbax
32+
2633
## [0.1.0] - 2025-04-07
2734
* Bump the JAX requirement to 0.5.1
2835
* Introduce `pathwaysutils.initialize()` to remove relying on side-effects from `import pathwaysutils`. by @copybara-service in https://github.com/AI-Hypercomputer/pathways-utils/pull/47
@@ -47,5 +54,5 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
4754
* Persistence enabled
4855
* General argument type fixes
4956

50-
[Unreleased]: https://github.com/AI-Hypercomputer/pathways-utils/compare/v0.1.0...HEAD
57+
[Unreleased]: https://github.com/AI-Hypercomputer/pathways-utils/compare/v0.1.1...HEAD
5158

pathwaysutils/sidecar/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ spec:
218218
- -c
219219
- |
220220
pip install --upgrade pip
221-
pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
221+
pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
222222
pip install pathwaysutils
223223
python -c "import jax; import pathwaysutils; print(\"Number of JAX devices is\", len(jax.devices()))"
224224
```

pathwaysutils/test/profiling_test.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
1516
from unittest import mock
1617

1718
from pathwaysutils import profiling
@@ -32,13 +33,14 @@ def setUp(self):
3233

3334
@parameterized.parameters(8000, 1234)
3435
def test_collect_profile_port(self, port):
35-
profiling.collect_profile(
36+
result = profiling.collect_profile(
3637
port=port,
3738
duration_ms=1000,
3839
host="127.0.0.1",
3940
log_dir="gs://test_bucket/test_dir",
4041
)
4142

43+
self.assertTrue(result)
4244
self.mock_post.assert_called_once_with(
4345
f"http://127.0.0.1:{port}/profiling",
4446
json={
@@ -49,13 +51,14 @@ def test_collect_profile_port(self, port):
4951

5052
@parameterized.parameters(1000, 1234)
5153
def test_collect_profile_duration_ms(self, duration_ms):
52-
profiling.collect_profile(
54+
result = profiling.collect_profile(
5355
port=8000,
5456
duration_ms=duration_ms,
5557
host="127.0.0.1",
5658
log_dir="gs://test_bucket/test_dir",
5759
)
5860

61+
self.assertTrue(result)
5962
self.mock_post.assert_called_once_with(
6063
"http://127.0.0.1:8000/profiling",
6164
json={
@@ -66,13 +69,14 @@ def test_collect_profile_duration_ms(self, duration_ms):
6669

6770
@parameterized.parameters("127.0.0.1", "localhost", "192.168.1.1")
6871
def test_collect_profile_host(self, host):
69-
profiling.collect_profile(
72+
result = profiling.collect_profile(
7073
port=8000,
7174
duration_ms=1000,
7275
host=host,
7376
log_dir="gs://test_bucket/test_dir",
7477
)
7578

79+
self.assertTrue(result)
7680
self.mock_post.assert_called_once_with(
7781
f"http://{host}:8000/profiling",
7882
json={
@@ -87,10 +91,11 @@ def test_collect_profile_host(self, host):
8791
"gs://test_bucket3/test/log/dir",
8892
)
8993
def test_collect_profile_log_dir(self, log_dir):
90-
profiling.collect_profile(
94+
result = profiling.collect_profile(
9195
port=8000, duration_ms=1000, host="127.0.0.1", log_dir=log_dir
9296
)
9397

98+
self.assertTrue(result)
9499
self.mock_post.assert_called_once_with(
95100
"http://127.0.0.1:8000/profiling",
96101
json={
@@ -107,22 +112,27 @@ def test_collect_profile_log_dir_error(self, log_dir):
107112
)
108113

109114
@parameterized.parameters(
110-
requests.exceptions.ConnectionError,
111-
requests.exceptions.Timeout,
112-
requests.exceptions.TooManyRedirects,
113-
requests.exceptions.RequestException,
114-
requests.exceptions.HTTPError,
115+
requests.exceptions.ConnectionError("Connection error"),
116+
requests.exceptions.Timeout("Timeout"),
117+
requests.exceptions.TooManyRedirects("Too many redirects"),
118+
requests.exceptions.RequestException("Request exception"),
119+
requests.exceptions.HTTPError("HTTP error"),
115120
)
116-
def test_collect_profile_request_error(self, exception_type):
117-
self.mock_post.side_effect = exception_type
121+
def test_collect_profile_request_error(self, exception):
122+
self.mock_post.side_effect = exception
123+
124+
with self.assertLogs(profiling._logger, level=logging.ERROR) as logs:
125+
result = profiling.collect_profile(
126+
port=8000,
127+
duration_ms=1000,
128+
host="127.0.0.1",
129+
log_dir="gs://test_bucket/test_dir",
130+
)
118131

119-
result = profiling.collect_profile(
120-
port=8000,
121-
duration_ms=1000,
122-
host="127.0.0.1",
123-
log_dir="gs://test_bucket/test_dir",
132+
self.assertLen(logs.output, 1)
133+
self.assertIn(
134+
f"Failed to collect profiling data: {exception}", logs.output[0]
124135
)
125-
126136
self.assertFalse(result)
127137
self.mock_post.assert_called_once()
128138

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ keywords = []
1414
dependencies = [
1515
"absl-py",
1616
"fastapi",
17-
"google-cloud-logging",
1817
"jax>=0.4.26",
1918
"orbax-checkpoint",
2019
"uvicorn",
20+
"requests",
2121
]
2222

2323
# `version` is automatically set by flit to use `my_project.__version__`

0 commit comments

Comments
 (0)