Skip to content

Commit f22ae6c

Browse files
lukebaumanncopybara-github
authored andcommitted
Plumb session_id from ProfileOptions to traceSessionName.
PiperOrigin-RevId: 919283544
1 parent 557f1a6 commit f22ae6c

2 files changed

Lines changed: 112 additions & 51 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@
3535

3636

3737
class _ProfileState:
38+
"""Holds the state of an ongoing profiling session.
39+
40+
Attributes:
41+
executable: The `plugin_executable.PluginExecutable` instance used for the
42+
profiling session.
43+
profile_request: The mapping containing the profile request options.
44+
lock: A thread lock to protect access to the state.
45+
"""
3846
executable: plugin_executable.PluginExecutable | None = None
3947
profile_request: Mapping[str, Any] | None = None
4048
lock: threading.Lock
@@ -48,6 +56,37 @@ def reset(self) -> None:
4856
self.executable = None
4957
self.profile_request = None
5058

59+
def call_profile_executable(self) -> None:
60+
"""Calls the profiling executable and waits for the result."""
61+
if self.executable is None:
62+
raise RuntimeError(
63+
"_call_profile_executable called with no active executable."
64+
)
65+
# If the profile request contains xprofTraceOptions, then we need to pass
66+
# out_avals and out_shardings to the executable call because the
67+
# executable will return a future that needs to be resolved. This is true
68+
# for both starting and stopping a trace.
69+
if (
70+
self.profile_request is not None
71+
and "xprofTraceOptions" in self.profile_request
72+
):
73+
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
74+
out_shardings = [
75+
getattr(
76+
jax.sharding,
77+
"make_single_device_sharding",
78+
jax.sharding.SingleDeviceSharding,
79+
)(jax.devices()[0])
80+
]
81+
else:
82+
out_avals = ()
83+
out_shardings = ()
84+
85+
_, result_future = self.executable.call(
86+
out_avals=out_avals, out_shardings=out_shardings
87+
)
88+
result_future.result()
89+
5190

5291
_first_profile_start = True
5392
_profile_state = _ProfileState()
@@ -74,6 +113,7 @@ def _is_default_profile_options(
74113
== default_options.python_tracer_level
75114
and profiler_options.duration_ms == default_options.duration_ms
76115
and not getattr(profiler_options, "advanced_configuration", None)
116+
and not getattr(profiler_options, "session_id", None)
77117
)
78118

79119

@@ -128,6 +168,9 @@ def _create_profile_request(
128168
if pw_trace_opts:
129169
xprof_options["pwTraceOptions"] = pw_trace_opts
130170

171+
if getattr(profiler_options, "session_id", None):
172+
xprof_options["traceSessionName"] = profiler_options.session_id
173+
131174
profile_request["xprofTraceOptions"] = xprof_options
132175

133176
if profiler_options.duration_ms > 0:
@@ -153,19 +196,30 @@ def _start_pathways_trace_from_profile_request(
153196
toy_computation()
154197

155198
if _profile_state.executable is not None:
156-
raise ValueError(
199+
raise RuntimeError(
157200
"start_trace called while a trace is already being taken!"
158201
)
159-
_profile_state.executable = plugin_executable.PluginExecutable(
160-
json.dumps({"profileRequest": profile_request})
161-
)
162-
_profile_state.profile_request = profile_request
163202
try:
164-
_, result_future = _profile_state.executable.call()
165-
result_future.result()
166-
except Exception:
167-
_logger.exception("Failed to start trace")
203+
_profile_state.executable = plugin_executable.PluginExecutable(
204+
json.dumps({"profileRequest": profile_request})
205+
)
206+
_profile_state.profile_request = profile_request
207+
_profile_state.call_profile_executable()
208+
except Exception as e:
168209
_profile_state.reset()
210+
if (
211+
"xprofTraceOptions" in profile_request
212+
and "traceSessionName" in profile_request["xprofTraceOptions"]
213+
):
214+
if "Bad PluginProgram" in str(e):
215+
raise RuntimeError(
216+
"Failed to start Pathways trace. The Pathways backend rejected "
217+
"the request, likely because the running Pathways server images "
218+
"do not support the trace session ID option. Please ensure you "
219+
"are running the latest versions of both Pathways server images "
220+
"and the pathwaysutils library."
221+
) from e
222+
_logger.exception("Failed to start trace")
169223
raise
170224

171225

@@ -243,28 +297,9 @@ def stop_trace() -> None:
243297
try:
244298
with _profile_state.lock:
245299
if _profile_state.executable is None:
246-
raise ValueError("stop_trace called before a trace is being taken!")
300+
raise RuntimeError("stop_trace called before a trace is being taken!")
247301
try:
248-
if (
249-
_profile_state.profile_request is not None
250-
and "xprofTraceOptions" in _profile_state.profile_request
251-
):
252-
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
253-
out_shardings = [
254-
getattr(
255-
jax.sharding,
256-
"make_single_device_sharding",
257-
lambda x: jax.sharding.SingleDeviceSharding(x),
258-
)(jax.devices()[0])
259-
]
260-
else:
261-
out_avals = ()
262-
out_shardings = ()
263-
264-
_, result_future = _profile_state.executable.call(
265-
out_avals=out_avals, out_shardings=out_shardings
266-
)
267-
result_future.result()
302+
_profile_state.call_profile_executable()
268303
finally:
269304
_profile_state.reset()
270305
finally:
@@ -306,7 +341,7 @@ async def profiling(pc: ProfilingConfig) -> Mapping[str, str]:
306341

307342
global _profiler_thread
308343
if _profiler_thread is not None:
309-
raise ValueError("Only one profiler server can be active at a time.")
344+
raise RuntimeError("Only one profiler server can be active at a time.")
310345

311346
_profiler_thread = threading.Thread(target=server_loop, args=(port,))
312347
_profiler_thread.start()
@@ -318,7 +353,7 @@ def stop_server() -> None:
318353
Pathways profiling servers are not stoppable at this time.
319354
"""
320355
if _profiler_thread is None:
321-
raise ValueError("No active profiler server.")
356+
raise RuntimeError("No active profiler server.")
322357

323358

324359
def collect_profile(

pathwaysutils/test/profiling_test.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_start_trace_no_toy_computation_second_time(self):
280280

281281
def test_start_trace_while_running_error(self):
282282
profiling.start_trace("gs://test_bucket/test_dir")
283-
with self.assertRaisesRegex(ValueError, "trace is already being taken"):
283+
with self.assertRaisesRegex(RuntimeError, "trace is already being taken"):
284284
profiling.start_trace("gs://test_bucket/test_dir2")
285285

286286
def test_stop_trace_success(self):
@@ -308,22 +308,16 @@ def test_stop_trace_with_xprof_options_passes_out_avals(self):
308308
options = jax.profiler.ProfileOptions()
309309
options.duration_ms = 2000
310310

311-
with mock.patch.object(
312-
profiling, "_profile_state", autospec=True
313-
) as mock_profile_state:
314-
request = profiling._create_profile_request(
315-
"gs://test_bucket/test_dir", options
316-
)
317-
mock_profile_state.profile_request = request
318-
mock_profile_state.executable = (
319-
self.mock_plugin_executable_cls.return_value
320-
)
321-
mock_profile_state.lock = mock.MagicMock()
322-
mock_profile_state.lock.locked.return_value = True
323-
mock_profile_state.lock.__enter__.return_value = None
324-
mock_profile_state.lock.__exit__.return_value = None
311+
request = profiling._create_profile_request(
312+
"gs://test_bucket/test_dir", options
313+
)
314+
profiling._profile_state.profile_request = request
315+
profiling._profile_state.executable = (
316+
self.mock_plugin_executable_cls.return_value
317+
)
318+
self.addCleanup(profiling._profile_state.reset)
325319

326-
profiling.stop_trace()
320+
profiling.stop_trace()
327321

328322
with self.subTest("plugin_executable_called"):
329323
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
@@ -340,7 +334,7 @@ def test_stop_trace_with_xprof_options_passes_out_avals(self):
340334

341335
def test_stop_trace_before_start_error(self):
342336
with self.assertRaisesRegex(
343-
ValueError, "stop_trace called before a trace is being taken!"
337+
RuntimeError, "stop_trace called before a trace is being taken!"
344338
):
345339
profiling.stop_trace()
346340

@@ -359,12 +353,12 @@ def test_start_server_twice_raises_error(self):
359353
)
360354
profiling.start_server(9000)
361355
with self.assertRaisesRegex(
362-
ValueError, "Only one profiler server can be active"
356+
RuntimeError, "Only one profiler server can be active"
363357
):
364358
profiling.start_server(9001)
365359

366360
def test_stop_server_no_server_raises_error(self):
367-
with self.assertRaisesRegex(ValueError, "No active profiler server"):
361+
with self.assertRaisesRegex(RuntimeError, "No active profiler server"):
368362
profiling.stop_server()
369363

370364
def test_stop_server_does_nothing_if_server_exists(self):
@@ -508,6 +502,7 @@ def test_create_profile_request_with_options(self):
508502
options.python_tracer_level = 1
509503
options.duration_ms = 2000
510504
options.start_timestamp_ns = 123456789
505+
options.session_id = "test_session"
511506
options.advanced_configuration = {
512507
"tpu_num_chips_to_profile_per_task": 3,
513508
"tpu_num_sparse_core_tiles_to_trace": 5,
@@ -525,6 +520,7 @@ def test_create_profile_request_with_options(self):
525520
"maxNumHosts": 1,
526521
"xprofTraceOptions": {
527522
"traceDirectory": "gs://bucket/dir",
523+
"traceSessionName": "test_session",
528524
"pwTraceOptions": {
529525
"enablePythonTracer": True,
530526
"advancedConfiguration": {
@@ -609,6 +605,36 @@ def test_jax_profiler_trace_calls_patched_functions(self):
609605
mocks["start_trace"].assert_called_once()
610606
mocks["stop_trace"].assert_called_once()
611607

608+
@absltest.skipIf(
609+
jax.version.__version_info__ < (0, 9, 2),
610+
"ProfileOptions requires JAX 0.9.2 or newer",
611+
)
612+
def test_is_default_profile_options_with_session_id(self):
613+
options = jax.profiler.ProfileOptions()
614+
options.session_id = "test_session"
615+
self.assertFalse(profiling._is_default_profile_options(options))
616+
617+
@absltest.skipIf(
618+
jax.version.__version_info__ < (0, 9, 2),
619+
"ProfileOptions requires JAX 0.9.2 or newer",
620+
)
621+
def test_start_trace_compatibility_error(self):
622+
self.mock_plugin_executable_cls.side_effect = RuntimeError(
623+
"Bad PluginProgram"
624+
)
625+
626+
options = jax.profiler.ProfileOptions()
627+
options.session_id = "test_session"
628+
629+
with self.assertRaisesRegex(
630+
RuntimeError,
631+
"likely because the running Pathways server images do not support the"
632+
" trace session ID option",
633+
):
634+
profiling.start_trace(
635+
"gs://test_bucket/test_dir", profiler_options=options
636+
)
637+
612638

613639
if __name__ == "__main__":
614640
absltest.main()

0 commit comments

Comments
 (0)