Skip to content

Commit 0bf23bb

Browse files
lukebaumanncopybara-github
authored andcommitted
Clean up ProfileOptions getattr usages in pathwaysutils.
PiperOrigin-RevId: 926483861
1 parent 9e2bd20 commit 0bf23bb

2 files changed

Lines changed: 119 additions & 38 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import asyncio
1717
from collections.abc import Mapping
1818
import dataclasses
19+
import datetime
1920
import json
2021
import logging
2122
import os
@@ -112,8 +113,8 @@ def _is_default_profile_options(
112113
and profiler_options.python_tracer_level
113114
== default_options.python_tracer_level
114115
and profiler_options.duration_ms == default_options.duration_ms
115-
and not getattr(profiler_options, "advanced_configuration", None)
116-
and not getattr(profiler_options, "session_id", None)
116+
and not profiler_options.advanced_configuration
117+
and not profiler_options.session_id
117118
)
118119

119120

@@ -132,9 +133,9 @@ def _create_profile_request(
132133
return profile_request
133134

134135
advanced_config = None
135-
if getattr(profiler_options, "advanced_configuration", None):
136+
if profiler_options.advanced_configuration:
136137
advanced_config = {}
137-
for k, v in getattr(profiler_options, "advanced_configuration").items():
138+
for k, v in profiler_options.advanced_configuration.items():
138139
# Convert python dict to tensorflow.ProfileOptions.AdvancedConfigValue
139140
# json-compatible dict
140141
if isinstance(v, bool):
@@ -168,7 +169,7 @@ def _create_profile_request(
168169
if pw_trace_opts:
169170
xprof_options["pwTraceOptions"] = pw_trace_opts
170171

171-
if getattr(profiler_options, "session_id", None):
172+
if profiler_options.session_id:
172173
xprof_options["traceSessionName"] = profiler_options.session_id
173174

174175
profile_request["xprofTraceOptions"] = xprof_options
@@ -270,26 +271,44 @@ def start_trace(
270271
"features for Pathways on Cloud and may not be fully supported."
271272
)
272273

273-
if jax.version.__version_info__ < (0, 9, 2) and profiler_options is not None:
274-
_logger.warning(
275-
"ProfileOptions are not supported until JAX 0.9.2 and will be omitted. "
276-
"Some options can be specified via command line flags."
277-
)
278-
profiler_options = None
274+
if jax.version.__version_info__ < (0, 9, 2):
275+
if profiler_options is not None:
276+
_logger.warning(
277+
"ProfileOptions are not supported until JAX 0.9.2 and will be omitted. "
278+
"Some options can be specified via command line flags."
279+
)
280+
profiler_options = None
281+
else:
282+
if profiler_options is None:
283+
profiler_options = jax.profiler.ProfileOptions()
284+
if not profiler_options.session_id:
285+
profiler_options.session_id = datetime.datetime.now().strftime(
286+
"%Y_%m_%d_%H_%M_%S"
287+
)
279288

280289
profile_request = _create_profile_request(
281-
log_dir, profiler_options, max_num_hosts=max_num_hosts
290+
log_dir,
291+
profiler_options,
292+
max_num_hosts=max_num_hosts,
282293
)
283294

284295
_logger.debug("Profile request: %s", profile_request)
285296

286297
_start_pathways_trace_from_profile_request(profile_request)
287298

288-
_original_start_trace(
289-
log_dir=log_dir,
290-
create_perfetto_link=create_perfetto_link,
291-
create_perfetto_trace=create_perfetto_trace,
292-
)
299+
if jax.version.__version_info__ >= (0, 9, 2):
300+
_original_start_trace(
301+
log_dir=log_dir,
302+
create_perfetto_link=create_perfetto_link,
303+
create_perfetto_trace=create_perfetto_trace,
304+
profiler_options=profiler_options,
305+
)
306+
else:
307+
_original_start_trace(
308+
log_dir=log_dir,
309+
create_perfetto_link=create_perfetto_link,
310+
create_perfetto_trace=create_perfetto_trace,
311+
)
293312

294313

295314
def stop_trace() -> None:

pathwaysutils/test/profiling_test.py

Lines changed: 83 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import json
1616
import logging
1717
from unittest import mock
18+
from typing import Any
1819

1920
from absl.testing import absltest
2021
from absl.testing import parameterized
@@ -53,6 +54,40 @@ def setUp(self):
5354
self.mock_original_stop_trace = self.enter_context(
5455
mock.patch.object(profiling, "_original_stop_trace", autospec=True)
5556
)
57+
self.mock_datetime = self.enter_context(
58+
mock.patch.object(profiling.datetime, "datetime", autospec=True)
59+
)
60+
self.mock_datetime.now.return_value.strftime.return_value = (
61+
"2026_06_04_05_29_33"
62+
)
63+
64+
def _get_expected_profile_request(
65+
self,
66+
trace_location: str,
67+
max_num_hosts: int = 1,
68+
session_id: str = "2026_06_04_05_29_33",
69+
) -> dict[str, Any]:
70+
if jax.version.__version_info__ >= (0, 9, 2):
71+
return {
72+
"profileRequest": {
73+
"traceLocation": trace_location,
74+
"maxNumHosts": max_num_hosts,
75+
"xprofTraceOptions": {
76+
"traceDirectory": trace_location,
77+
"pwTraceOptions": {
78+
"enablePythonTracer": True,
79+
},
80+
"traceSessionName": session_id,
81+
},
82+
}
83+
}
84+
else:
85+
return {
86+
"profileRequest": {
87+
"traceLocation": trace_location,
88+
"maxNumHosts": max_num_hosts,
89+
}
90+
}
5691

5792
@parameterized.parameters(8000, 1234)
5893
def test_collect_profile_port(self, port):
@@ -228,40 +263,67 @@ def test_start_trace_success(self):
228263
profiling.start_trace("gs://test_bucket/test_dir")
229264

230265
self.mock_toy_computation.assert_called_once()
266+
expected_request = self._get_expected_profile_request(
267+
"gs://test_bucket/test_dir", max_num_hosts=1
268+
)
231269
self.mock_plugin_executable_cls.assert_called_once_with(
232-
json.dumps({
233-
"profileRequest": {
234-
"traceLocation": "gs://test_bucket/test_dir",
235-
"maxNumHosts": 1,
236-
}
237-
})
270+
json.dumps(expected_request)
238271
)
239272
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
240-
self.mock_original_start_trace.assert_called_once_with(
241-
log_dir="gs://test_bucket/test_dir",
242-
create_perfetto_link=False,
243-
create_perfetto_trace=False,
244-
)
273+
self.mock_original_start_trace.assert_called_once()
274+
call_args = self.mock_original_start_trace.call_args[1]
275+
self.assertEqual(call_args["log_dir"], "gs://test_bucket/test_dir")
276+
self.assertFalse(call_args["create_perfetto_link"])
277+
self.assertFalse(call_args["create_perfetto_trace"])
278+
if jax.version.__version_info__ >= (0, 9, 2):
279+
self.assertEqual(
280+
call_args["profiler_options"].session_id, "2026_06_04_05_29_33"
281+
)
245282
self.assertIsNotNone(profiling._profile_state.executable)
246283

247284
def test_start_trace_with_max_num_hosts(self):
248285
profiling.start_trace("gs://test_bucket/test_dir", max_num_hosts=10)
249286

250287
self.mock_toy_computation.assert_called_once()
288+
expected_request = self._get_expected_profile_request(
289+
"gs://test_bucket/test_dir", max_num_hosts=10
290+
)
251291
self.mock_plugin_executable_cls.assert_called_once_with(
252-
json.dumps({
253-
"profileRequest": {
254-
"traceLocation": "gs://test_bucket/test_dir",
255-
"maxNumHosts": 10,
256-
}
257-
})
292+
json.dumps(expected_request)
258293
)
259294
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
260-
self.mock_original_start_trace.assert_called_once_with(
261-
log_dir="gs://test_bucket/test_dir",
262-
create_perfetto_link=False,
263-
create_perfetto_trace=False,
295+
self.mock_original_start_trace.assert_called_once()
296+
call_args = self.mock_original_start_trace.call_args[1]
297+
self.assertEqual(call_args["log_dir"], "gs://test_bucket/test_dir")
298+
self.assertFalse(call_args["create_perfetto_link"])
299+
self.assertFalse(call_args["create_perfetto_trace"])
300+
if jax.version.__version_info__ >= (0, 9, 2):
301+
self.assertEqual(
302+
call_args["profiler_options"].session_id, "2026_06_04_05_29_33"
303+
)
304+
305+
@absltest.skipIf(
306+
jax.version.__version_info__ < (0, 9, 2),
307+
"ProfileOptions requires JAX 0.9.2 or newer",
308+
)
309+
def test_start_trace_with_session_id_in_options(self):
310+
options = jax.profiler.ProfileOptions()
311+
options.session_id = "options_session"
312+
profiling.start_trace("gs://test_bucket/test_dir", profiler_options=options)
313+
314+
expected_request = self._get_expected_profile_request(
315+
"gs://test_bucket/test_dir", max_num_hosts=1, session_id="options_session"
264316
)
317+
self.mock_plugin_executable_cls.assert_called_once_with(
318+
json.dumps(expected_request)
319+
)
320+
self.assertEqual(options.session_id, "options_session")
321+
self.mock_original_start_trace.assert_called_once()
322+
call_args = self.mock_original_start_trace.call_args[1]
323+
self.assertEqual(call_args["log_dir"], "gs://test_bucket/test_dir")
324+
self.assertFalse(call_args["create_perfetto_link"])
325+
self.assertFalse(call_args["create_perfetto_trace"])
326+
self.assertEqual(call_args["profiler_options"].session_id, "options_session")
265327

266328
def test_start_trace_no_toy_computation_second_time(self):
267329
profiling.start_trace("gs://test_bucket/test_dir")

0 commit comments

Comments
 (0)