2424import requests
2525
2626
27+ def _make_profile_options (** kwargs ) -> jax .profiler .ProfileOptions :
28+ options = jax .profiler .ProfileOptions ()
29+ for k , v in kwargs .items ():
30+ # Confirm that the attribute is one of the ProfileOptions attributes
31+ # that is supported by Pathways. This is not a test function so it cannot
32+ # use assertHasAttr.
33+ assert hasattr (options , k ), f"ProfileOptions does not have attribute { k } "
34+ setattr (options , k , v )
35+ return options
36+
37+
2738class ProfilingTest (parameterized .TestCase ):
2839 """Tests for Pathways on Cloud profiling."""
2940
@@ -305,8 +316,7 @@ def test_stop_trace_success(self):
305316 "ProfileOptions requires JAX 0.9.2 or newer" ,
306317 )
307318 def test_stop_trace_with_xprof_options_passes_out_avals (self ):
308- options = jax .profiler .ProfileOptions ()
309- options .duration_ms = 2000
319+ options = _make_profile_options (duration_ms = 2000 )
310320
311321 request = profiling ._create_profile_request (
312322 "gs://test_bucket/test_dir" , options
@@ -467,7 +477,7 @@ def test_monkey_patched_stop_server(self):
467477
468478 mocks ["stop_server" ].assert_called_once ()
469479
470- @parameterized .parameters (None , jax . profiler . ProfileOptions ())
480+ @parameterized .parameters (None , _make_profile_options ())
471481 def test_create_profile_request_default_options (self , profiler_options ):
472482 request = profiling ._create_profile_request (
473483 "gs://bucket/dir" , profiler_options = profiler_options
@@ -497,17 +507,18 @@ def test_create_profile_request_with_max_num_hosts(self):
497507 "ProfileOptions requires JAX 0.9.2 or newer" ,
498508 )
499509 def test_create_profile_request_with_options (self ):
500- options = jax .profiler .ProfileOptions ()
501- options .host_tracer_level = 2
502- options .python_tracer_level = 1
503- options .duration_ms = 2000
504- options .start_timestamp_ns = 123456789
505- options .session_id = "test_session"
506- options .advanced_configuration = {
507- "tpu_num_chips_to_profile_per_task" : 3 ,
508- "tpu_num_sparse_core_tiles_to_trace" : 5 ,
509- "tpu_trace_mode" : "TRACE_COMPUTE" ,
510- }
510+ options = _make_profile_options (
511+ host_tracer_level = 2 ,
512+ python_tracer_level = 1 ,
513+ duration_ms = 2000 ,
514+ start_timestamp_ns = 123456789 ,
515+ session_id = "test_session" ,
516+ advanced_configuration = {
517+ "tpu_num_chips_to_profile_per_task" : 3 ,
518+ "tpu_num_sparse_core_tiles_to_trace" : 5 ,
519+ "tpu_trace_mode" : "TRACE_COMPUTE" ,
520+ },
521+ )
511522
512523 request = profiling ._create_profile_request (
513524 "gs://bucket/dir" , profiler_options = options
@@ -609,10 +620,62 @@ def test_jax_profiler_trace_calls_patched_functions(self):
609620 jax .version .__version_info__ < (0 , 9 , 2 ),
610621 "ProfileOptions requires JAX 0.9.2 or newer" ,
611622 )
612- def test_is_default_profile_options_with_session_id (self ):
613- options = jax .profiler .ProfileOptions ()
614- options .session_id = "test_session"
615- self .assertFalse (profiling ._is_default_profile_options (options ))
623+ @parameterized .named_parameters (
624+ dict (
625+ testcase_name = "default_equal" ,
626+ options1 = _make_profile_options (),
627+ options2 = _make_profile_options (),
628+ ),
629+ dict (
630+ testcase_name = "session_id_equal" ,
631+ options1 = _make_profile_options (session_id = "test" ),
632+ options2 = _make_profile_options (session_id = "test" ),
633+ ),
634+ dict (
635+ testcase_name = "host_tracer_level_equal" ,
636+ options1 = _make_profile_options (host_tracer_level = 3 ),
637+ options2 = _make_profile_options (host_tracer_level = 3 ),
638+ ),
639+ dict (
640+ testcase_name = "advanced_config_equal" ,
641+ options1 = _make_profile_options (
642+ advanced_configuration = {"foo" : "bar" }
643+ ),
644+ options2 = _make_profile_options (
645+ advanced_configuration = {"foo" : "bar" }
646+ ),
647+ ),
648+ )
649+ def test_profile_options_equal (self , options1 , options2 ):
650+ self .assertEqual (options1 , options2 )
651+
652+ @absltest .skipIf (
653+ jax .version .__version_info__ < (0 , 9 , 2 ),
654+ "ProfileOptions requires JAX 0.9.2 or newer" ,
655+ )
656+ @parameterized .named_parameters (
657+ dict (
658+ testcase_name = "session_id_diff" ,
659+ options1 = _make_profile_options (session_id = "test1" ),
660+ options2 = _make_profile_options (session_id = "test2" ),
661+ ),
662+ dict (
663+ testcase_name = "host_tracer_level_diff" ,
664+ options1 = _make_profile_options (host_tracer_level = 1 ),
665+ options2 = _make_profile_options (host_tracer_level = 2 ),
666+ ),
667+ dict (
668+ testcase_name = "advanced_config_diff" ,
669+ options1 = _make_profile_options (
670+ advanced_configuration = {"foo" : "bar" }
671+ ),
672+ options2 = _make_profile_options (
673+ advanced_configuration = {"foo" : "baz" }
674+ ),
675+ ),
676+ )
677+ def test_profile_options_not_equal (self , options1 , options2 ):
678+ self .assertNotEqual (options1 , options2 )
616679
617680 @absltest .skipIf (
618681 jax .version .__version_info__ < (0 , 9 , 2 ),
@@ -623,8 +686,7 @@ def test_start_trace_compatibility_error(self):
623686 "Bad PluginProgram"
624687 )
625688
626- options = jax .profiler .ProfileOptions ()
627- options .session_id = "test_session"
689+ options = _make_profile_options (session_id = "test_session" )
628690
629691 with self .assertRaisesRegex (
630692 RuntimeError ,
0 commit comments