Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 63 additions & 4 deletions torchx/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,61 @@

import functools
import warnings
from typing import Any, Callable
from typing import Any, Callable, Iterable


# Entry-point groups that have ``torchx_plugins.*`` namespace-package
# alternatives. Only these groups trigger a deprecation warning.
# Keep in sync with ``torchx.plugins._registry.PluginType``.
_PLUGIN_GROUPS: frozenset[str] = frozenset(
{
"torchx.schedulers",
"torchx.named_resources",
"torchx.tracker",
}
)


def deprecated_entrypoint(
group: str,
ep_names: Iterable[str],
*,
stacklevel: int = 2,
) -> None:
"""Emit a deprecation warning for entry-point based plugins.

Only warns for groups that have ``torchx_plugins.*`` namespace-package
equivalents (i.e., groups listed in
:py:class:`~torchx.plugins.PluginType`). Groups without namespace
alternatives (e.g., ``"torchx.schedulers.orchestrator"``,
``"torchx.components"``) are silently ignored.

Args:
group: The entry-point group name (e.g., ``"torchx.schedulers"``).
ep_names: Names of the entry-point plugins that were loaded.
stacklevel: Stack level for :py:func:`warnings.warn`. Default ``2``
points at the caller of this function.

Example::

>>> # In _registry._find():
>>> deprecated_entrypoint("torchx.schedulers", ["mast_conda"])

"""
if group not in _PLUGIN_GROUPS:
return

names = ", ".join(sorted(ep_names))
namespace = f"torchx_plugins.{group.removeprefix('torchx.')}"
warnings.warn(
f"Entry-point plugins in group '{group}' are deprecated. "
f"Migrate to the '{namespace}' namespace package using "
f"the @register decorator. "
f"Set TORCHX_NO_ENTRYPOINTS=1 to opt out early. "
f"Deprecated entry-point plugins: {names}",
DeprecationWarning,
stacklevel=stacklevel,
)


def deprecated_module(
Expand Down Expand Up @@ -66,10 +120,13 @@ def deprecated_module(
)


_F = Callable[..., object]


def deprecated(
*,
replacement: str | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[[_F], _F]:
"""Mark a function or class as deprecated.

.. code-block:: python
Expand All @@ -88,17 +145,19 @@ def old_func():
on each call.
"""

def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
def decorator(fn: _F) -> _F:
parts = [f"[Deprecated] {fn.__qualname__} is deprecated"]
if replacement:
parts.append(f"-- use {replacement} instead")
msg: str = " ".join(parts) + "."

@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# pyre-ignore[3]: Wrapper preserves fn's signature at runtime via wraps
def wrapper(*args: Any, **kwargs: Any) -> object:
warnings.warn(msg, UserWarning, stacklevel=2)
return fn(*args, **kwargs)

# pyre-ignore[7]: wrapper has same runtime signature as fn via wraps
return wrapper

return decorator
3 changes: 3 additions & 0 deletions torchx/plugins/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from types import ModuleType
from typing import Any, Callable, overload

from torchx.deprecations import deprecated_entrypoint
from torchx.util import entrypoints

logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -421,6 +422,8 @@ def _find(self, plugin_type: PluginType) -> dict[str, Callable[..., Any]]:

if self._load_entrypoints:
ep_plugins = entrypoints.load_group(group) or {}
if ep_plugins:
deprecated_entrypoint(group, ep_plugins.keys(), stacklevel=4)
# Entry points override namespace plugins (higher priority).
merged.update(ep_plugins)
if ns_plugins and ep_plugins:
Expand Down
62 changes: 62 additions & 0 deletions torchx/plugins/test/registry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import sys
import types
import unittest
import warnings
from pathlib import Path
from typing import Any
from unittest.mock import patch
Expand Down Expand Up @@ -238,6 +239,67 @@ def test_load_entrypoints_false(self) -> None:
"namespace plugins should still be discovered",
)

@mock_install_torchx_plugins()
def test_find_emits_deprecation_warning_for_entrypoints(self) -> None:
"""_find() calls deprecated_entrypoint() when entry points are loaded."""
ep_result = {"my_sched": lambda: None}

with patch.object(entrypoints, "load_group", return_value=ep_result):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
registry().get(PluginType.SCHEDULER)

dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
self.assertEqual(
len(dep_warnings),
1,
"should emit exactly one DeprecationWarning when entry-point plugins are loaded",
)
msg = str(dep_warnings[0].message)
self.assertIn(
"torchx.schedulers",
msg,
"warning should mention the deprecated group",
)
self.assertIn(
"my_sched",
msg,
"warning should mention the entry-point plugin name",
)

@mock_install_torchx_plugins()
def test_find_no_warning_when_entrypoints_empty(self) -> None:
"""No deprecation warning when entry points return empty."""
with patch.object(entrypoints, "load_group", return_value=None):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
registry().get(PluginType.SCHEDULER)

dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
self.assertEqual(
len(dep_warnings),
0,
"should not emit DeprecationWarning when no entry-point plugins are found",
)

@mock_install_torchx_plugins()
def test_find_no_warning_when_entrypoints_disabled(self) -> None:
"""No deprecation warning when load_entrypoints=False."""
ep_result = {"my_sched": lambda: None}

with patch.object(entrypoints, "load_group", return_value=ep_result):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
reg = PluginRegistry(load_entrypoints=False)
reg.get(PluginType.SCHEDULER)

dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
self.assertEqual(
len(dep_warnings),
0,
"should not emit DeprecationWarning when entrypoints are disabled",
)

@mock_install_torchx_plugins()
def test_info_returns_all_groups(self) -> None:
"""info() returns dict[PluginType, dict[str, Callable]]."""
Expand Down
130 changes: 129 additions & 1 deletion torchx/test/deprecations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
import unittest
import warnings

from torchx.deprecations import deprecated, deprecated_module
from torchx.deprecations import (
_PLUGIN_GROUPS,
deprecated,
deprecated_entrypoint,
deprecated_module,
)


class DeprecatedModuleTest(unittest.TestCase):
Expand Down Expand Up @@ -137,3 +142,126 @@ def add(a: int, b: int, extra: int = 0) -> int:
result = add(1, 2, extra=3)

self.assertEqual(result, 6, "decorated function should pass args correctly")


class DeprecatedEntrypointTest(unittest.TestCase):
"""Tests for ``deprecated_entrypoint()``."""

def test_warns_for_scheduler_group(self) -> None:
"""Emits DeprecationWarning for the 'torchx.schedulers' group."""
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
deprecated_entrypoint("torchx.schedulers", ["local_cwd", "slurm"])

dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
self.assertEqual(
len(dep_warnings),
1,
"should emit exactly one DeprecationWarning for torchx.schedulers",
)
msg = str(dep_warnings[0].message)
self.assertIn(
"torchx.schedulers",
msg,
"warning should mention the deprecated group",
)
self.assertIn(
"torchx_plugins.schedulers",
msg,
"warning should mention the namespace-package alternative",
)
self.assertIn(
"TORCHX_NO_ENTRYPOINTS=1",
msg,
"warning should mention the opt-out env var",
)

def test_warns_for_named_resource_group(self) -> None:
"""Emits DeprecationWarning for the 'torchx.named_resources' group."""
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
deprecated_entrypoint("torchx.named_resources", ["aws_p5"])

dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
self.assertEqual(
len(dep_warnings),
1,
"should emit exactly one DeprecationWarning for torchx.named_resources",
)

def test_warns_for_tracker_group(self) -> None:
"""Emits DeprecationWarning for the 'torchx.tracker' group."""
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
deprecated_entrypoint("torchx.tracker", ["mlflow"])

dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
self.assertEqual(
len(dep_warnings),
1,
"should emit exactly one DeprecationWarning for torchx.tracker",
)

def test_no_warning_for_orchestrator_group(self) -> None:
"""No warning for 'torchx.schedulers.orchestrator' — no namespace alternative."""
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
deprecated_entrypoint("torchx.schedulers.orchestrator", ["fblearner"])

dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
self.assertEqual(
len(dep_warnings),
0,
"should not warn for groups without namespace-plugin alternatives",
)

def test_no_warning_for_components_group(self) -> None:
"""No warning for 'torchx.components' — no namespace alternative."""
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
deprecated_entrypoint("torchx.components", ["my_component"])

dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
self.assertEqual(
len(dep_warnings),
0,
"should not warn for torchx.components group",
)

def test_no_warning_for_file_group(self) -> None:
"""No warning for 'torchx.file' — no namespace alternative."""
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
deprecated_entrypoint("torchx.file", ["get_file_contents"])

dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
self.assertEqual(
len(dep_warnings),
0,
"should not warn for torchx.file group",
)

def test_warning_includes_sorted_plugin_names(self) -> None:
"""Warning message lists plugin names in sorted order."""
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
deprecated_entrypoint("torchx.schedulers", ["zz_scheduler", "aa_scheduler"])

msg = str(caught[0].message)
self.assertIn(
"aa_scheduler, zz_scheduler",
msg,
"plugin names should be sorted alphabetically in the warning",
)

def test_plugin_groups_matches_plugin_type(self) -> None:
"""_PLUGIN_GROUPS covers all PluginType values."""
from torchx.plugins._registry import PluginType

expected = {pt.value for pt in PluginType}
self.assertEqual(
_PLUGIN_GROUPS,
expected,
"_PLUGIN_GROUPS should match PluginType values exactly. "
"If you added a new PluginType, add it to _PLUGIN_GROUPS too.",
)
Loading