Skip to content

Commit 86c2eb1

Browse files
lukebaumanncopybara-github
authored andcommitted
Implement __eq__ monkey patch for JAX ProfileOptions.
PiperOrigin-RevId: 919254885
1 parent f22ae6c commit 86c2eb1

2 files changed

Lines changed: 121 additions & 35 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,29 @@ def toy_computation() -> None:
100100
x.block_until_ready()
101101

102102

103-
def _is_default_profile_options(
104-
profiler_options: jax.profiler.ProfileOptions,
105-
) -> bool:
106-
if jax.version.__version_info__ < (0, 9, 2):
107-
return True
108-
109-
default_options = jax.profiler.ProfileOptions()
103+
def _profile_options_eq(self, other: Any) -> bool:
104+
if not isinstance(other, jax.profiler.ProfileOptions):
105+
return NotImplemented
110106
return (
111-
profiler_options.host_tracer_level == default_options.host_tracer_level
112-
and profiler_options.python_tracer_level
113-
== default_options.python_tracer_level
114-
and profiler_options.duration_ms == default_options.duration_ms
115-
and not getattr(profiler_options, "advanced_configuration", None)
116-
and not getattr(profiler_options, "session_id", None)
107+
getattr(self, "include_dataset_ops", None)
108+
== getattr(other, "include_dataset_ops", None)
109+
and getattr(self, "host_tracer_level", None)
110+
== getattr(other, "host_tracer_level", None)
111+
and getattr(self, "python_tracer_level", None)
112+
== getattr(other, "python_tracer_level", None)
113+
and getattr(self, "enable_hlo_proto", None)
114+
== getattr(other, "enable_hlo_proto", None)
115+
and getattr(self, "start_timestamp_ns", None)
116+
== getattr(other, "start_timestamp_ns", None)
117+
and getattr(self, "duration_ms", None)
118+
== getattr(other, "duration_ms", None)
119+
and getattr(self, "raise_error_on_start_failure", None)
120+
== getattr(other, "raise_error_on_start_failure", None)
121+
and getattr(self, "advanced_configuration", None)
122+
== getattr(other, "advanced_configuration", None)
123+
and getattr(self, "repository_path", None)
124+
== getattr(other, "repository_path", None)
125+
and getattr(self, "session_id", None) == getattr(other, "session_id", None)
117126
)
118127

119128

@@ -128,7 +137,11 @@ def _create_profile_request(
128137
"maxNumHosts": max_num_hosts,
129138
}
130139

131-
if profiler_options is None or _is_default_profile_options(profiler_options):
140+
if (
141+
profiler_options is None
142+
or jax.version.__version_info__ < (0, 9, 2)
143+
or profiler_options == jax.profiler.ProfileOptions()
144+
):
132145
return profile_request
133146

134147
advanced_config = None
@@ -188,6 +201,10 @@ def _start_pathways_trace_from_profile_request(
188201
189202
Args:
190203
profile_request: A mapping containing the profile request options.
204+
205+
Raises:
206+
RuntimeError: If a trace is already active, or if starting the trace fails
207+
due to an incompatible Pathways backend version when using a session ID.
191208
"""
192209
with _profile_state.lock:
193210
global _first_profile_start
@@ -217,7 +234,7 @@ def _start_pathways_trace_from_profile_request(
217234
"the request, likely because the running Pathways server images "
218235
"do not support the trace session ID option. Please ensure you "
219236
"are running the latest versions of both Pathways server images "
220-
"and the pathwaysutils library."
237+
"and the pathwaysutils package."
221238
) from e
222239
_logger.exception("Failed to start trace")
223240
raise
@@ -318,6 +335,9 @@ def start_server(port: int) -> None:
318335
319336
Args:
320337
port : The port to start the server on.
338+
339+
Raises:
340+
RuntimeError: If a profiler server is already active.
321341
"""
322342
def server_loop(port: int):
323343
_logger.debug("Starting JAX profiler server on port %s", port)
@@ -448,3 +468,7 @@ def stop_server_patch() -> None:
448468
stop_server()
449469

450470
jax.profiler.stop_server = stop_server_patch
471+
472+
473+
if hasattr(jax.profiler, "ProfileOptions"):
474+
jax.profiler.ProfileOptions.__eq__ = _profile_options_eq

pathwaysutils/test/profiling_test.py

Lines changed: 82 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424
import requests
2525

2626

27+
def _make_profile_options(**kwargs) -> jax.profiler.ProfileOptions:
28+
options = jax.profiler.ProfileOptions()
29+
for k, v in kwargs.items():
30+
# Confirm that the attribute is one of the ProfileOptions attributes
31+
# that is supported by Pathways. This is not a test function so it cannot
32+
# use assertHasAttr.
33+
assert hasattr(options, k), f"ProfileOptions does not have attribute {k}"
34+
setattr(options, k, v)
35+
return options
36+
37+
2738
class ProfilingTest(parameterized.TestCase):
2839
"""Tests for Pathways on Cloud profiling."""
2940

@@ -305,8 +316,7 @@ def test_stop_trace_success(self):
305316
"ProfileOptions requires JAX 0.9.2 or newer",
306317
)
307318
def test_stop_trace_with_xprof_options_passes_out_avals(self):
308-
options = jax.profiler.ProfileOptions()
309-
options.duration_ms = 2000
319+
options = _make_profile_options(duration_ms=2000)
310320

311321
request = profiling._create_profile_request(
312322
"gs://test_bucket/test_dir", options
@@ -467,7 +477,7 @@ def test_monkey_patched_stop_server(self):
467477

468478
mocks["stop_server"].assert_called_once()
469479

470-
@parameterized.parameters(None, jax.profiler.ProfileOptions())
480+
@parameterized.parameters(None, _make_profile_options())
471481
def test_create_profile_request_default_options(self, profiler_options):
472482
request = profiling._create_profile_request(
473483
"gs://bucket/dir", profiler_options=profiler_options
@@ -497,17 +507,18 @@ def test_create_profile_request_with_max_num_hosts(self):
497507
"ProfileOptions requires JAX 0.9.2 or newer",
498508
)
499509
def test_create_profile_request_with_options(self):
500-
options = jax.profiler.ProfileOptions()
501-
options.host_tracer_level = 2
502-
options.python_tracer_level = 1
503-
options.duration_ms = 2000
504-
options.start_timestamp_ns = 123456789
505-
options.session_id = "test_session"
506-
options.advanced_configuration = {
507-
"tpu_num_chips_to_profile_per_task": 3,
508-
"tpu_num_sparse_core_tiles_to_trace": 5,
509-
"tpu_trace_mode": "TRACE_COMPUTE",
510-
}
510+
options = _make_profile_options(
511+
host_tracer_level=2,
512+
python_tracer_level=1,
513+
duration_ms=2000,
514+
start_timestamp_ns=123456789,
515+
session_id="test_session",
516+
advanced_configuration={
517+
"tpu_num_chips_to_profile_per_task": 3,
518+
"tpu_num_sparse_core_tiles_to_trace": 5,
519+
"tpu_trace_mode": "TRACE_COMPUTE",
520+
},
521+
)
511522

512523
request = profiling._create_profile_request(
513524
"gs://bucket/dir", profiler_options=options
@@ -609,10 +620,62 @@ def test_jax_profiler_trace_calls_patched_functions(self):
609620
jax.version.__version_info__ < (0, 9, 2),
610621
"ProfileOptions requires JAX 0.9.2 or newer",
611622
)
612-
def test_is_default_profile_options_with_session_id(self):
613-
options = jax.profiler.ProfileOptions()
614-
options.session_id = "test_session"
615-
self.assertFalse(profiling._is_default_profile_options(options))
623+
@parameterized.named_parameters(
624+
dict(
625+
testcase_name="default_equal",
626+
options1=_make_profile_options(),
627+
options2=_make_profile_options(),
628+
),
629+
dict(
630+
testcase_name="session_id_equal",
631+
options1=_make_profile_options(session_id="test"),
632+
options2=_make_profile_options(session_id="test"),
633+
),
634+
dict(
635+
testcase_name="host_tracer_level_equal",
636+
options1=_make_profile_options(host_tracer_level=3),
637+
options2=_make_profile_options(host_tracer_level=3),
638+
),
639+
dict(
640+
testcase_name="advanced_config_equal",
641+
options1=_make_profile_options(
642+
advanced_configuration={"foo": "bar"}
643+
),
644+
options2=_make_profile_options(
645+
advanced_configuration={"foo": "bar"}
646+
),
647+
),
648+
)
649+
def test_profile_options_equal(self, options1, options2):
650+
self.assertEqual(options1, options2)
651+
652+
@absltest.skipIf(
653+
jax.version.__version_info__ < (0, 9, 2),
654+
"ProfileOptions requires JAX 0.9.2 or newer",
655+
)
656+
@parameterized.named_parameters(
657+
dict(
658+
testcase_name="session_id_diff",
659+
options1=_make_profile_options(session_id="test1"),
660+
options2=_make_profile_options(session_id="test2"),
661+
),
662+
dict(
663+
testcase_name="host_tracer_level_diff",
664+
options1=_make_profile_options(host_tracer_level=1),
665+
options2=_make_profile_options(host_tracer_level=2),
666+
),
667+
dict(
668+
testcase_name="advanced_config_diff",
669+
options1=_make_profile_options(
670+
advanced_configuration={"foo": "bar"}
671+
),
672+
options2=_make_profile_options(
673+
advanced_configuration={"foo": "baz"}
674+
),
675+
),
676+
)
677+
def test_profile_options_not_equal(self, options1, options2):
678+
self.assertNotEqual(options1, options2)
616679

617680
@absltest.skipIf(
618681
jax.version.__version_info__ < (0, 9, 2),
@@ -623,8 +686,7 @@ def test_start_trace_compatibility_error(self):
623686
"Bad PluginProgram"
624687
)
625688

626-
options = jax.profiler.ProfileOptions()
627-
options.session_id = "test_session"
689+
options = _make_profile_options(session_id="test_session")
628690

629691
with self.assertRaisesRegex(
630692
RuntimeError,

0 commit comments

Comments
 (0)