Skip to content

Commit b4984d6

Browse files
committed
Fix mldiag attribute error
1 parent e92ca84 commit b4984d6

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

tests/unit/managed_mldiagnostics_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import unittest
1818
from unittest import mock
1919

20-
from maxtext.common.managed_mldiagnostics import ManagedMLDiagnostics
20+
from maxtext.common.managed_mldiagnostics import ManagedMLDiagnostics, mldiag
2121
import pytest
2222

2323

@@ -34,7 +34,7 @@ def test_not_enabled_noop(self):
3434
mock_config = mock.MagicMock()
3535
mock_config.managed_mldiagnostics = False
3636

37-
with mock.patch.object(ManagedMLDiagnostics.mldiag, "machinelearning_run") as mock_run:
37+
with mock.patch.object(mldiag, "machinelearning_run") as mock_run:
3838
ManagedMLDiagnostics(mock_config)
3939
mock_run.assert_not_called()
4040

@@ -45,15 +45,17 @@ def test_enabled_empty_region_passes_none(self):
4545
mock_config.run_name = "test_run"
4646
mock_config.managed_mldiagnostics_run_group = "test_group"
4747
mock_config.managed_mldiagnostics_dir = "gs://test_dir"
48+
mock_config.managed_mldiagnostics_on_demand_profiling = False
4849
mock_config.get_keys.return_value = {"key1": "val1"}
4950

50-
with mock.patch.object(ManagedMLDiagnostics.mldiag, "machinelearning_run") as mock_run:
51+
with mock.patch.object(mldiag, "machinelearning_run") as mock_run:
5152
ManagedMLDiagnostics(mock_config)
5253
mock_run.assert_called_once_with(
5354
name="test_run",
5455
run_group="test_group",
5556
configs={"key1": "val1"},
5657
gcs_path="gs://test_dir",
58+
on_demand_xprof=False,
5759
region=None,
5860
)
5961

@@ -64,15 +66,17 @@ def test_enabled_populated_region_passes_region(self):
6466
mock_config.run_name = "test_run"
6567
mock_config.managed_mldiagnostics_run_group = "test_group"
6668
mock_config.managed_mldiagnostics_dir = "gs://test_dir"
69+
mock_config.managed_mldiagnostics_on_demand_profiling = False
6770
mock_config.get_keys.return_value = {"key1": "val1"}
6871

69-
with mock.patch.object(ManagedMLDiagnostics.mldiag, "machinelearning_run") as mock_run:
72+
with mock.patch.object(mldiag, "machinelearning_run") as mock_run:
7073
ManagedMLDiagnostics(mock_config)
7174
mock_run.assert_called_once_with(
7275
name="test_run",
7376
run_group="test_group",
7477
configs={"key1": "val1"},
7578
gcs_path="gs://test_dir",
79+
on_demand_xprof=False,
7680
region="us-east1",
7781
)
7882

0 commit comments

Comments
 (0)