2222import os
2323import pathlib
2424import pkgutil
25+ from enum import auto
2526from types import ModuleType
2627from typing import Any , Callable , overload
2728
@@ -48,6 +49,19 @@ class PluginType(str, enum.Enum):
4849 TRACKER = "torchx.tracker"
4950
5051
52+ class PluginSource (enum .IntFlag ):
53+ """Bitmask of plugin discovery channels.
54+
55+ Combine with ``|`` and test with ``in``. The value of the
56+ ``TORCHX_PLUGINS_SOURCE`` env var is parsed as the integer
57+ representation of this flag.
58+ """
59+
60+ NONE = 0
61+ NAMESPACE_PKG = auto ()
62+ ENTRYPOINT = auto ()
63+
64+
5165# ── Public API ────────────────────────────────────────────────────────────────
5266
5367
@@ -80,19 +94,24 @@ class PluginRegistry:
8094 scheds = reg.get(plugins.PluginType.SCHEDULER)
8195 print(reg)
8296
83- Namespace plugins (``torchx_plugins.*``) are **always** loaded.
84- The *load_entrypoints* flag only controls whether ``importlib.metadata``
85- entry points are additionally merged in.
97+ Discovery channels are selected via *plugin_sources*, a :py:class:`PluginSource`
98+ bitmask. Defaults to all channels enabled; pass :py:attr:`PluginSource.NONE`
99+ for an empty registry or any combination of ``NAMESPACE_PKG`` / ``ENTRYPOINT``
100+ to enable a subset.
86101
87102 Args:
88- load_entrypoints: Whether to **also** load plugins from
89- ``importlib.metadata`` entry points. Set to ``False`` to
90- disable entry-point loading (namespace plugins are still
91- discovered). Defaults to ``True`` for backward compatibility.
103+ plugin_sources: Bitmask of discovery channels to enable.
104+ Defaults to ``NAMESPACE_PKG | ENTRYPOINT``.
92105 """
93106
94- def __init__ (self , * , load_entrypoints : bool = True ) -> None :
95- self ._load_entrypoints : bool = load_entrypoints
107+ def __init__ (
108+ self ,
109+ * ,
110+ plugin_sources : PluginSource = (
111+ PluginSource .NAMESPACE_PKG | PluginSource .ENTRYPOINT
112+ ),
113+ ) -> None :
114+ self ._plugin_sources : PluginSource = plugin_sources
96115 # pyre-ignore[4]: plugin factories are heterogeneously typed
97116 self ._cache : dict [PluginType , dict [str , Callable [..., Any ]]] = {}
98117 self ._errors : list [RegistrationError ] = []
@@ -476,28 +495,20 @@ def _find(self, plugin_type: PluginType) -> dict[str, Callable[..., Any]]:
476495
477496 Merge priority (highest → lowest):
478497
479- 1. ``importlib.metadata`` entry points (when *load_entrypoints* is True)
480- 2. ``torchx_plugins.<suffix>`` namespace submodules
498+ 1. ``importlib.metadata`` entry points (when ``ENTRYPOINT`` is set)
499+ 2. ``torchx_plugins.<suffix>`` namespace submodules (when
500+ ``NAMESPACE_PKG`` is set)
481501 """
482502 group : str = plugin_type .value
483503 namespace = self ._namespace_for_type (plugin_type )
484- ns_plugins = self ._find_namespace_plugins (namespace , plugin_type )
485- merged = dict (ns_plugins )
486-
487- if self ._load_entrypoints :
488- ep_plugins = entrypoints .load_group (group ) or {}
489- # Entry points override namespace plugins (higher priority).
490- merged .update (ep_plugins )
491- if ns_plugins and ep_plugins :
492- new_keys = set (ns_plugins ) - set (ep_plugins )
493- if new_keys :
494- logger .debug (
495- "namespace plugins added keys %s for group `%s`" ,
496- new_keys ,
497- group ,
498- )
499-
500- return merged
504+ plugins = (
505+ self ._find_namespace_plugins (namespace , plugin_type )
506+ if PluginSource .NAMESPACE_PKG in self ._plugin_sources
507+ else {}
508+ )
509+ if PluginSource .ENTRYPOINT in self ._plugin_sources :
510+ plugins |= entrypoints .load_group (group ) or {}
511+ return plugins
501512
502513
503514@functools .lru_cache (maxsize = 1 )
@@ -507,9 +518,11 @@ def registry() -> PluginRegistry:
507518 The registry lazily discovers plugins per-group on first
508519 :py:meth:`~PluginRegistry.get` access and caches the results.
509520
510- Namespace plugins (``torchx_plugins.*``) are **always** loaded.
511- Entry points from ``importlib.metadata`` are additionally merged in
512- unless ``TORCHX_NO_ENTRYPOINTS=1`` is set.
521+ The ``TORCHX_PLUGINS_SOURCE`` environment variable selects which
522+ discovery channels are enabled. Its value is parsed as the integer
523+ representation of a :py:class:`PluginSource` bitmask: ``0`` for
524+ none, ``1`` for namespace package only, ``2`` for entry points only,
525+ ``3`` for both. Defaults to all channels enabled when unset.
513526
514527 Returns:
515528 The cached :py:class:`PluginRegistry` instance.
@@ -523,5 +536,17 @@ def registry() -> PluginRegistry:
523536 named = reg.get(plugins.PluginType.NAMED_RESOURCE)
524537 print(reg)
525538 """
526- load_entrypoints = os .environ .get ("TORCHX_NO_ENTRYPOINTS" ) != "1"
527- return PluginRegistry (load_entrypoints = load_entrypoints )
539+ all_sources = PluginSource .NAMESPACE_PKG | PluginSource .ENTRYPOINT
540+ raw = os .environ .get ("TORCHX_PLUGINS_SOURCE" , str (int (all_sources )))
541+ try :
542+ value = int (raw )
543+ if not 0 <= value <= int (all_sources ):
544+ raise ValueError
545+ plugin_sources = PluginSource (value )
546+ except ValueError as e :
547+ raise ValueError (
548+ f"TORCHX_PLUGINS_SOURCE={ raw !r} is invalid; expected one of"
549+ " 0 (none), 1 (namespace only), 2 (entry points only),"
550+ " 3 (both)."
551+ ) from e
552+ return PluginRegistry (plugin_sources = plugin_sources )
0 commit comments