Skip to content

Commit dc5b298

Browse files
kiukchungfacebook-github-bot
authored andcommitted
Fix type bug in build_trackers -- skip non-callable modules (#1249)
Summary: Fix Pyre type error in `build_trackers()` where `load_module()` can return a `ModuleType` that is then called as a factory function. **Problem:** - `load_module(factory_name)` returns `ModuleType | Callable | None` - When the user passes a bare module path (no `:attr` suffix), a `ModuleType` is returned, which is not callable - Calling `factory(config)` on a `ModuleType` is a runtime `TypeError` and a Pyre type error [29] **Fix:** - Add a `callable()` check after resolving the factory - Non-callable modules are skipped with a descriptive warning guiding the user to use `module.path:factory_function` syntax - Rename intermediate variable to `resolved` so Pyre can narrow the type before assigning to `factory` **Cleanup:** - Convert f-string log calls to `%s` formatting per TorchX conventions Differential Revision: D95932184
1 parent b3b5388 commit dc5b298

2 files changed

Lines changed: 37 additions & 7 deletions

File tree

torchx/tracker/api.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,16 +155,26 @@ def build_trackers(
155155
logger.warning("No 'torchx.tracker' entry_points are defined.")
156156

157157
for factory_name, config in factory_and_config.items():
158-
factory = entrypoint_factories.get(factory_name) or load_module(factory_name)
159-
if not factory:
158+
resolved = entrypoint_factories.get(factory_name) or load_module(factory_name)
159+
if not resolved:
160160
logger.warning(
161-
f"No tracker factory `{factory_name}` found in entry_points or modules. See https://meta-pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
161+
"no tracker factory `%s` found in entry_points or modules."
162+
" See https://meta-pytorch.org/torchx/main/tracker.html#module-torchx.tracker",
163+
factory_name,
162164
)
163165
continue
166+
if not callable(resolved):
167+
logger.warning(
168+
"tracker factory `%s` resolved to a module, not a callable."
169+
" Use `module.path:factory_function` syntax to specify the factory",
170+
factory_name,
171+
)
172+
continue
173+
factory = resolved
164174
if config:
165-
logger.info(f"Tracker config found for `{factory_name}` as `{config}`")
175+
logger.info("tracker config found for `%s` as `%s`", factory_name, config)
166176
else:
167-
logger.info(f"No tracker config specified for `{factory_name}`")
177+
logger.info("no tracker config specified for `%s`", factory_name)
168178
tracker = factory(config)
169179
trackers.append(tracker)
170180
return trackers

torchx/tracker/test/api_test.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,6 @@ def sources(
8383
run_id: str,
8484
artifact_name: str | None = None,
8585
) -> Iterable[TrackerSource]:
86-
source_data = self._sources[run_id]
87-
8886
sources = []
8987
for artifact_name, source_id in self._sources[run_id].items():
9088
if artifact_name == DEFAULT_SOURCE:
@@ -289,6 +287,28 @@ def test_build_trackers_with_module(self) -> None:
289287
self.assertIsInstance(tracker, MLflowTracker)
290288
module.assert_called_once_with(config)
291289

290+
def test_build_trackers_with_non_callable_module(self) -> None:
291+
"""load_module returns a ModuleType when no :attr is specified.
292+
Modules aren't callable -- build_trackers should skip them with a warning.
293+
"""
294+
from types import ModuleType
295+
296+
module = ModuleType("fake_tracker_module")
297+
with (
298+
patch("torchx.tracker.api.load_group", return_value=None),
299+
patch(
300+
"torchx.tracker.api.load_module",
301+
return_value=module,
302+
),
303+
):
304+
tracker_names = {"fake_tracker_module": "myconfig.txt"}
305+
trackers = build_trackers(tracker_names)
306+
self.assertEqual(
307+
0,
308+
len(list(trackers)),
309+
"non-callable module should be skipped, not called as a factory",
310+
)
311+
292312
def test_build_trackers(self) -> None:
293313
with patch(
294314
"torchx.tracker.api.load_group",

0 commit comments

Comments
 (0)