Skip to content

Commit 02ef7ed

Browse files
authored
Replace boolean plugin-source flags with TORCHX_PLUGINS_SOURCE + PluginSource IntFlag (#1286)
Differential Revision: D101807646 Pull Request resolved: #1286
1 parent 5143c9f commit 02ef7ed

4 files changed

Lines changed: 152 additions & 71 deletions

File tree

docs/source/advanced.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ sufficient.
2020
``setup.py`` / ``pyproject.toml``) is deprecated. Use the ``@register``
2121
decorator with ``torchx_plugins.*`` namespace packages instead.
2222

23-
Namespace plugins are **always** loaded. By default
24-
(``TORCHX_NO_ENTRYPOINTS=0`` or unset), entry points are also loaded
25-
for backward compatibility. Set ``TORCHX_NO_ENTRYPOINTS=1`` to load
26-
only namespace plugins. Entry-point loading will be removed in a
27-
future release. See :doc:`plugins` for the full API reference.
23+
By default both namespace plugins and entry points are loaded for
24+
backward compatibility. Set ``TORCHX_PLUGINS_SOURCE=1`` (namespace
25+
package only) to opt out of entry-point discovery early; entry-point
26+
loading will be removed in a future release. See
27+
:doc:`plugins` for the full API reference.
2828

2929
.. code-block:: text
3030

torchx/plugins/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def my_scheduler(session_name: str, **kwargs) -> Scheduler:
2828
2929
.. deprecated::
3030
Entry-point based registration (``[torchx.*]`` in ``pyproject.toml``)
31-
is deprecated. Set ``TORCHX_NO_ENTRYPOINTS=1`` to opt out early.
31+
is deprecated. Set ``TORCHX_PLUGINS_SOURCE=1`` (namespace package
32+
only) to opt out of entry-point discovery early.
3233
"""
3334

3435
from torchx.plugins._registration import (

torchx/plugins/_registry.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import os
2323
import pathlib
2424
import pkgutil
25+
from enum import auto
2526
from types import ModuleType
2627
from typing import Any, Callable, overload
2728

@@ -48,6 +49,19 @@ class PluginType(str, enum.Enum):
4849
TRACKER = "torchx.tracker"
4950

5051

52+
class PluginSource(enum.IntFlag):
53+
"""Bitmask of plugin discovery channels.
54+
55+
Combine with ``|`` and test with ``in``. The value of the
56+
``TORCHX_PLUGINS_SOURCE`` env var is parsed as the integer
57+
representation of this flag.
58+
"""
59+
60+
NONE = 0
61+
NAMESPACE_PKG = auto()
62+
ENTRYPOINT = auto()
63+
64+
5165
# ── Public API ────────────────────────────────────────────────────────────────
5266

5367

@@ -80,19 +94,24 @@ class PluginRegistry:
8094
scheds = reg.get(plugins.PluginType.SCHEDULER)
8195
print(reg)
8296
83-
Namespace plugins (``torchx_plugins.*``) are **always** loaded.
84-
The *load_entrypoints* flag only controls whether ``importlib.metadata``
85-
entry points are additionally merged in.
97+
Discovery channels are selected via *plugin_sources*, a :py:class:`PluginSource`
98+
bitmask. Defaults to all channels enabled; pass :py:attr:`PluginSource.NONE`
99+
for an empty registry or any combination of ``NAMESPACE_PKG`` / ``ENTRYPOINT``
100+
to enable a subset.
86101
87102
Args:
88-
load_entrypoints: Whether to **also** load plugins from
89-
``importlib.metadata`` entry points. Set to ``False`` to
90-
disable entry-point loading (namespace plugins are still
91-
discovered). Defaults to ``True`` for backward compatibility.
103+
plugin_sources: Bitmask of discovery channels to enable.
104+
Defaults to ``NAMESPACE_PKG | ENTRYPOINT``.
92105
"""
93106

94-
def __init__(self, *, load_entrypoints: bool = True) -> None:
95-
self._load_entrypoints: bool = load_entrypoints
107+
def __init__(
108+
self,
109+
*,
110+
plugin_sources: PluginSource = (
111+
PluginSource.NAMESPACE_PKG | PluginSource.ENTRYPOINT
112+
),
113+
) -> None:
114+
self._plugin_sources: PluginSource = plugin_sources
96115
# pyre-ignore[4]: plugin factories are heterogeneously typed
97116
self._cache: dict[PluginType, dict[str, Callable[..., Any]]] = {}
98117
self._errors: list[RegistrationError] = []
@@ -476,28 +495,20 @@ def _find(self, plugin_type: PluginType) -> dict[str, Callable[..., Any]]:
476495
477496
Merge priority (highest → lowest):
478497
479-
1. ``importlib.metadata`` entry points (when *load_entrypoints* is True)
480-
2. ``torchx_plugins.<suffix>`` namespace submodules
498+
1. ``importlib.metadata`` entry points (when ``ENTRYPOINT`` is set)
499+
2. ``torchx_plugins.<suffix>`` namespace submodules (when
500+
``NAMESPACE_PKG`` is set)
481501
"""
482502
group: str = plugin_type.value
483503
namespace = self._namespace_for_type(plugin_type)
484-
ns_plugins = self._find_namespace_plugins(namespace, plugin_type)
485-
merged = dict(ns_plugins)
486-
487-
if self._load_entrypoints:
488-
ep_plugins = entrypoints.load_group(group) or {}
489-
# Entry points override namespace plugins (higher priority).
490-
merged.update(ep_plugins)
491-
if ns_plugins and ep_plugins:
492-
new_keys = set(ns_plugins) - set(ep_plugins)
493-
if new_keys:
494-
logger.debug(
495-
"namespace plugins added keys %s for group `%s`",
496-
new_keys,
497-
group,
498-
)
499-
500-
return merged
504+
plugins = (
505+
self._find_namespace_plugins(namespace, plugin_type)
506+
if PluginSource.NAMESPACE_PKG in self._plugin_sources
507+
else {}
508+
)
509+
if PluginSource.ENTRYPOINT in self._plugin_sources:
510+
plugins |= entrypoints.load_group(group) or {}
511+
return plugins
501512

502513

503514
@functools.lru_cache(maxsize=1)
@@ -507,9 +518,11 @@ def registry() -> PluginRegistry:
507518
The registry lazily discovers plugins per-group on first
508519
:py:meth:`~PluginRegistry.get` access and caches the results.
509520
510-
Namespace plugins (``torchx_plugins.*``) are **always** loaded.
511-
Entry points from ``importlib.metadata`` are additionally merged in
512-
unless ``TORCHX_NO_ENTRYPOINTS=1`` is set.
521+
The ``TORCHX_PLUGINS_SOURCE`` environment variable selects which
522+
discovery channels are enabled. Its value is parsed as the integer
523+
representation of a :py:class:`PluginSource` bitmask: ``0`` for
524+
none, ``1`` for namespace package only, ``2`` for entry points only,
525+
``3`` for both. Defaults to all channels enabled when unset.
513526
514527
Returns:
515528
The cached :py:class:`PluginRegistry` instance.
@@ -523,5 +536,17 @@ def registry() -> PluginRegistry:
523536
named = reg.get(plugins.PluginType.NAMED_RESOURCE)
524537
print(reg)
525538
"""
526-
load_entrypoints = os.environ.get("TORCHX_NO_ENTRYPOINTS") != "1"
527-
return PluginRegistry(load_entrypoints=load_entrypoints)
539+
all_sources = PluginSource.NAMESPACE_PKG | PluginSource.ENTRYPOINT
540+
raw = os.environ.get("TORCHX_PLUGINS_SOURCE", str(int(all_sources)))
541+
try:
542+
value = int(raw)
543+
if not 0 <= value <= int(all_sources):
544+
raise ValueError
545+
plugin_sources = PluginSource(value)
546+
except ValueError as e:
547+
raise ValueError(
548+
f"TORCHX_PLUGINS_SOURCE={raw!r} is invalid; expected one of"
549+
" 0 (none), 1 (namespace only), 2 (entry points only),"
550+
" 3 (both)."
551+
) from e
552+
return PluginRegistry(plugin_sources=plugin_sources)

0 commit comments

Comments
 (0)