@@ -232,6 +232,7 @@ def test_start_trace_success(self):
232232 json .dumps ({
233233 "profileRequest" : {
234234 "traceLocation" : "gs://test_bucket/test_dir" ,
235+ "maxNumHosts" : 1 ,
235236 }
236237 })
237238 )
@@ -243,6 +244,25 @@ def test_start_trace_success(self):
243244 )
244245 self .assertIsNotNone (profiling ._profile_state .executable )
245246
247+ def test_start_trace_with_max_num_hosts (self ):
248+ profiling .start_trace ("gs://test_bucket/test_dir" , max_num_hosts = 10 )
249+
250+ self .mock_toy_computation .assert_called_once ()
251+ self .mock_plugin_executable_cls .assert_called_once_with (
252+ json .dumps ({
253+ "profileRequest" : {
254+ "traceLocation" : "gs://test_bucket/test_dir" ,
255+ "maxNumHosts" : 10 ,
256+ }
257+ })
258+ )
259+ self .mock_plugin_executable_cls .return_value .call .assert_called_once ()
260+ self .mock_original_start_trace .assert_called_once_with (
261+ log_dir = "gs://test_bucket/test_dir" ,
262+ create_perfetto_link = False ,
263+ create_perfetto_trace = False ,
264+ )
265+
246266 def test_start_trace_no_toy_computation_second_time (self ):
247267 profiling .start_trace ("gs://test_bucket/test_dir" )
248268 profiling .stop_trace ()
@@ -408,6 +428,24 @@ def test_monkey_patched_start_trace(self, profiler_module):
408428 create_perfetto_link = False ,
409429 create_perfetto_trace = False ,
410430 profiler_options = None ,
431+ max_num_hosts = 1 ,
432+ )
433+
434+ @parameterized .named_parameters (
435+ dict (testcase_name = "jax_profiler" , profiler_module = jax .profiler ),
436+ dict (testcase_name = "jax_src_profiler" , profiler_module = jax ._src .profiler ),
437+ )
438+ def test_monkey_patched_start_trace_with_max_num_hosts (self , profiler_module ):
439+ mocks = self ._setup_monkey_patch ()
440+
441+ profiler_module .start_trace ("gs://bucket/dir" , max_num_hosts = 3 )
442+
443+ mocks ["start_trace" ].assert_called_once_with (
444+ "gs://bucket/dir" ,
445+ create_perfetto_link = False ,
446+ create_perfetto_trace = False ,
447+ profiler_options = None ,
448+ max_num_hosts = 3 ,
411449 )
412450
413451 @parameterized .named_parameters (
@@ -444,6 +482,19 @@ def test_create_profile_request_default_options(self, profiler_options):
444482 request ,
445483 {
446484 "traceLocation" : "gs://bucket/dir" ,
485+ "maxNumHosts" : 1 ,
486+ },
487+ )
488+
489+ def test_create_profile_request_with_max_num_hosts (self ):
490+ request = profiling ._create_profile_request (
491+ "gs://bucket/dir" , max_num_hosts = 5
492+ )
493+ self .assertEqual (
494+ request ,
495+ {
496+ "traceLocation" : "gs://bucket/dir" ,
497+ "maxNumHosts" : 5 ,
447498 },
448499 )
449500
@@ -471,6 +522,7 @@ def test_create_profile_request_with_options(self):
471522 {
472523 "traceLocation" : "gs://bucket/dir" ,
473524 "maxDurationSecs" : 2.0 ,
525+ "maxNumHosts" : 1 ,
474526 "xprofTraceOptions" : {
475527 "traceDirectory" : "gs://bucket/dir" ,
476528 "pwTraceOptions" : {
0 commit comments