Skip to content

Commit 0f1dd71

Browse files
Pathways-on-Cloud Teamcopybara-github
authored andcommitted
add support for max_num_hosts in start_trace. the default now is to trace one host.
PiperOrigin-RevId: 900387988
1 parent a5984c0 commit 0f1dd71

2 files changed

Lines changed: 62 additions & 1 deletion

File tree

pathwaysutils/profiling.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@ def _is_default_profile_options(
8080
def _create_profile_request(
8181
log_dir: os.PathLike[str] | str,
8282
profiler_options: jax.profiler.ProfileOptions | None = None,
83+
max_num_hosts: int = 1,
8384
) -> Mapping[str, Any]:
8485
"""Creates a profile request mapping from the given options."""
8586
profile_request: dict[str, Any] = {
8687
"traceLocation": str(log_dir),
88+
"maxNumHosts": max_num_hosts,
8789
}
8890

8991
if profiler_options is None or _is_default_profile_options(profiler_options):
@@ -173,6 +175,7 @@ def start_trace(
173175
create_perfetto_link: bool = False,
174176
create_perfetto_trace: bool = False,
175177
profiler_options: jax.profiler.ProfileOptions | None = None,
178+
max_num_hosts: int = 1,
176179
) -> None:
177180
"""Starts a profiler trace.
178181
@@ -201,6 +204,8 @@ def start_trace(
201204
This feature is experimental for Pathways on Cloud and may not be fully
202205
supported.
203206
profiler_options: Profiler options to configure the profiler for collection.
207+
max_num_hosts: An optional integer to limit the number of hosts profiled
208+
(defaults to 1).
204209
"""
205210
if not str(log_dir).startswith("gs://"):
206211
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
@@ -218,7 +223,9 @@ def start_trace(
218223
)
219224
profiler_options = None
220225

221-
profile_request = _create_profile_request(log_dir, profiler_options)
226+
profile_request = _create_profile_request(
227+
log_dir, profiler_options, max_num_hosts=max_num_hosts
228+
)
222229

223230
_logger.debug("Profile request: %s", profile_request)
224231

@@ -366,13 +373,15 @@ def start_trace_patch(
366373
create_perfetto_link: bool = False,
367374
create_perfetto_trace: bool = False,
368375
profiler_options: jax.profiler.ProfileOptions | None = None,
376+
max_num_hosts: int = 1,
369377
) -> None:
370378
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
371379
start_trace(
372380
log_dir,
373381
create_perfetto_link=create_perfetto_link,
374382
create_perfetto_trace=create_perfetto_trace,
375383
profiler_options=profiler_options,
384+
max_num_hosts=max_num_hosts,
376385
)
377386

378387
jax.profiler.start_trace = start_trace_patch

pathwaysutils/test/profiling_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def test_start_trace_success(self):
232232
json.dumps({
233233
"profileRequest": {
234234
"traceLocation": "gs://test_bucket/test_dir",
235+
"maxNumHosts": 1,
235236
}
236237
})
237238
)
@@ -243,6 +244,25 @@ def test_start_trace_success(self):
243244
)
244245
self.assertIsNotNone(profiling._profile_state.executable)
245246

247+
def test_start_trace_with_max_num_hosts(self):
248+
profiling.start_trace("gs://test_bucket/test_dir", max_num_hosts=10)
249+
250+
self.mock_toy_computation.assert_called_once()
251+
self.mock_plugin_executable_cls.assert_called_once_with(
252+
json.dumps({
253+
"profileRequest": {
254+
"traceLocation": "gs://test_bucket/test_dir",
255+
"maxNumHosts": 10,
256+
}
257+
})
258+
)
259+
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
260+
self.mock_original_start_trace.assert_called_once_with(
261+
log_dir="gs://test_bucket/test_dir",
262+
create_perfetto_link=False,
263+
create_perfetto_trace=False,
264+
)
265+
246266
def test_start_trace_no_toy_computation_second_time(self):
247267
profiling.start_trace("gs://test_bucket/test_dir")
248268
profiling.stop_trace()
@@ -408,6 +428,24 @@ def test_monkey_patched_start_trace(self, profiler_module):
408428
create_perfetto_link=False,
409429
create_perfetto_trace=False,
410430
profiler_options=None,
431+
max_num_hosts=1,
432+
)
433+
434+
@parameterized.named_parameters(
435+
dict(testcase_name="jax_profiler", profiler_module=jax.profiler),
436+
dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler),
437+
)
438+
def test_monkey_patched_start_trace_with_max_num_hosts(self, profiler_module):
439+
mocks = self._setup_monkey_patch()
440+
441+
profiler_module.start_trace("gs://bucket/dir", max_num_hosts=3)
442+
443+
mocks["start_trace"].assert_called_once_with(
444+
"gs://bucket/dir",
445+
create_perfetto_link=False,
446+
create_perfetto_trace=False,
447+
profiler_options=None,
448+
max_num_hosts=3,
411449
)
412450

413451
@parameterized.named_parameters(
@@ -444,6 +482,19 @@ def test_create_profile_request_default_options(self, profiler_options):
444482
request,
445483
{
446484
"traceLocation": "gs://bucket/dir",
485+
"maxNumHosts": 1,
486+
},
487+
)
488+
489+
def test_create_profile_request_with_max_num_hosts(self):
490+
request = profiling._create_profile_request(
491+
"gs://bucket/dir", max_num_hosts=5
492+
)
493+
self.assertEqual(
494+
request,
495+
{
496+
"traceLocation": "gs://bucket/dir",
497+
"maxNumHosts": 5,
447498
},
448499
)
449500

@@ -471,6 +522,7 @@ def test_create_profile_request_with_options(self):
471522
{
472523
"traceLocation": "gs://bucket/dir",
473524
"maxDurationSecs": 2.0,
525+
"maxNumHosts": 1,
474526
"xprofTraceOptions": {
475527
"traceDirectory": "gs://bucket/dir",
476528
"pwTraceOptions": {

0 commit comments

Comments
 (0)