diff --git a/torchx/tracker/test/api_test.py b/torchx/tracker/test/api_test.py index 91b679d5d..afd52e41c 100644 --- a/torchx/tracker/test/api_test.py +++ b/torchx/tracker/test/api_test.py @@ -83,8 +83,6 @@ def sources( run_id: str, artifact_name: str | None = None, ) -> Iterable[TrackerSource]: - source_data = self._sources[run_id] - sources = [] for artifact_name, source_id in self._sources[run_id].items(): if artifact_name == DEFAULT_SOURCE: @@ -289,6 +287,28 @@ def test_build_trackers_with_module(self) -> None: self.assertIsInstance(tracker, MLflowTracker) module.assert_called_once_with(config) + def test_build_trackers_with_non_callable_module(self) -> None: + """load_module returns a ModuleType when no :attr is specified. + Modules aren't callable -- build_trackers should skip them with a warning. + """ + from types import ModuleType + + module = ModuleType("fake_tracker_module") + with ( + patch("torchx.tracker.api.load_group", return_value=None), + patch( + "torchx.tracker.api.load_module", + return_value=module, + ), + ): + tracker_names = {"fake_tracker_module": "myconfig.txt"} + trackers = build_trackers(tracker_names) + self.assertEqual( + 0, + len(list(trackers)), + "non-callable module should be skipped, not called as a factory", + ) + def test_build_trackers(self) -> None: with patch("torchx.tracker.api.plugins") as plugins_mock: plugins_mock.registry.return_value.get.return_value = {