1717import unittest
1818from unittest import mock
1919
20- from maxtext .common .managed_mldiagnostics import ManagedMLDiagnostics
20+ from maxtext .common .managed_mldiagnostics import ManagedMLDiagnostics , mldiag
2121import pytest
2222
2323
@@ -34,7 +34,7 @@ def test_not_enabled_noop(self):
3434 mock_config = mock .MagicMock ()
3535 mock_config .managed_mldiagnostics = False
3636
37- with mock .patch .object (ManagedMLDiagnostics . mldiag , "machinelearning_run" ) as mock_run :
37+ with mock .patch .object (mldiag , "machinelearning_run" ) as mock_run :
3838 ManagedMLDiagnostics (mock_config )
3939 mock_run .assert_not_called ()
4040
@@ -45,15 +45,17 @@ def test_enabled_empty_region_passes_none(self):
4545 mock_config .run_name = "test_run"
4646 mock_config .managed_mldiagnostics_run_group = "test_group"
4747 mock_config .managed_mldiagnostics_dir = "gs://test_dir"
48+ mock_config .managed_mldiagnostics_on_demand_profiling = False
4849 mock_config .get_keys .return_value = {"key1" : "val1" }
4950
50- with mock .patch .object (ManagedMLDiagnostics . mldiag , "machinelearning_run" ) as mock_run :
51+ with mock .patch .object (mldiag , "machinelearning_run" ) as mock_run :
5152 ManagedMLDiagnostics (mock_config )
5253 mock_run .assert_called_once_with (
5354 name = "test_run" ,
5455 run_group = "test_group" ,
5556 configs = {"key1" : "val1" },
5657 gcs_path = "gs://test_dir" ,
58+ on_demand_xprof = False ,
5759 region = None ,
5860 )
5961
@@ -64,15 +66,17 @@ def test_enabled_populated_region_passes_region(self):
6466 mock_config .run_name = "test_run"
6567 mock_config .managed_mldiagnostics_run_group = "test_group"
6668 mock_config .managed_mldiagnostics_dir = "gs://test_dir"
69+ mock_config .managed_mldiagnostics_on_demand_profiling = False
6770 mock_config .get_keys .return_value = {"key1" : "val1" }
6871
69- with mock .patch .object (ManagedMLDiagnostics . mldiag , "machinelearning_run" ) as mock_run :
72+ with mock .patch .object (mldiag , "machinelearning_run" ) as mock_run :
7073 ManagedMLDiagnostics (mock_config )
7174 mock_run .assert_called_once_with (
7275 name = "test_run" ,
7376 run_group = "test_group" ,
7477 configs = {"key1" : "val1" },
7578 gcs_path = "gs://test_dir" ,
79+ on_demand_xprof = False ,
7680 region = "us-east1" ,
7781 )
7882
0 commit comments