Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions torchx/tracker/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ def sources(
run_id: str,
artifact_name: str | None = None,
) -> Iterable[TrackerSource]:
source_data = self._sources[run_id]

sources = []
for artifact_name, source_id in self._sources[run_id].items():
if artifact_name == DEFAULT_SOURCE:
Expand Down Expand Up @@ -289,6 +287,28 @@ def test_build_trackers_with_module(self) -> None:
self.assertIsInstance(tracker, MLflowTracker)
module.assert_called_once_with(config)

def test_build_trackers_with_non_callable_module(self) -> None:
"""load_module returns a ModuleType when no :attr is specified.
Modules aren't callable -- build_trackers should skip them with a warning.
"""
from types import ModuleType

module = ModuleType("fake_tracker_module")
with (
patch("torchx.tracker.api.load_group", return_value=None),
patch(
"torchx.tracker.api.load_module",
return_value=module,
),
):
tracker_names = {"fake_tracker_module": "myconfig.txt"}
trackers = build_trackers(tracker_names)
self.assertEqual(
0,
len(list(trackers)),
"non-callable module should be skipped, not called as a factory",
)

def test_build_trackers(self) -> None:
with patch("torchx.tracker.api.plugins") as plugins_mock:
plugins_mock.registry.return_value.get.return_value = {
Expand Down
Loading