|
1 | 1 | import itertools |
2 | 2 | from importlib import import_module |
| 3 | +from typing import Dict |
3 | 4 |
|
4 | 5 | from backports.entry_points_selectable import entry_points # backport for Python 3.9 |
5 | 6 |
|
|
12 | 13 |
|
13 | 14 | _PLUGINS: list[Plugin] = [] |
14 | 15 |
|
| 16 | +_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"} |
15 | 17 |
|
16 | | -def load_plugins(enabled_plugins: list[str]): |
17 | | - _PLUGINS.clear() |
18 | | - plugins_entrypoints = entry_points(group="dstack.plugins") |
19 | | - plugins_to_load = enabled_plugins.copy() |
20 | | - for entrypoint in plugins_entrypoints: |
21 | | - if entrypoint.name not in enabled_plugins: |
22 | | - logger.info( |
23 | | - ("Found not enabled plugin %s. Plugin will not be loaded."), |
24 | | - entrypoint.name, |
25 | | - ) |
26 | | - continue |
| 18 | + |
| 19 | +class PluginEntrypoint: |
| 20 | + def __init__(self, name: str, import_path: str, is_builtin: bool = False): |
| 21 | + self.name = name |
| 22 | + self.import_path = import_path |
| 23 | + self.is_builtin = is_builtin |
| 24 | + |
| 25 | + def load(self): |
| 26 | + module_path, _, class_name = self.import_path.partition(":") |
27 | 27 | try: |
28 | | - module_path, _, class_name = entrypoint.value.partition(":") |
29 | 28 | module = import_module(module_path) |
| 29 | + plugin_class = getattr(module, class_name, None) |
| 30 | + if plugin_class is None: |
| 31 | + logger.warning( |
| 32 | + ("Failed to load plugin %s: plugin class %s not found in module %s."), |
| 33 | + self.name, |
| 34 | + class_name, |
| 35 | + module_path, |
| 36 | + ) |
| 37 | + return None |
| 38 | + if not issubclass(plugin_class, Plugin): |
| 39 | + logger.warning( |
| 40 | + ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."), |
| 41 | + self.name, |
| 42 | + class_name, |
| 43 | + ) |
| 44 | + return None |
| 45 | + return plugin_class() |
30 | 46 | except ImportError: |
31 | 47 | logger.warning( |
32 | 48 | ( |
33 | 49 | "Failed to load plugin %s when importing %s." |
34 | 50 | " Ensure the module is on the import path." |
35 | 51 | ), |
36 | | - entrypoint.name, |
37 | | - entrypoint.value, |
| 52 | + self.name, |
| 53 | + self.import_path, |
38 | 54 | ) |
39 | | - continue |
40 | | - plugin_class = getattr(module, class_name, None) |
41 | | - if plugin_class is None: |
42 | | - logger.warning( |
43 | | - ("Failed to load plugin %s: plugin class %s not found in module %s."), |
| 55 | + return None |
| 56 | + |
| 57 | + |
| 58 | +def load_plugins(enabled_plugins: list[str]): |
| 59 | + _PLUGINS.clear() |
| 60 | + entrypoints: dict[str, PluginEntrypoint] = {} |
| 61 | + plugins_to_load = enabled_plugins.copy() |
| 62 | + for entrypoint in entry_points(group="dstack.plugins"): |
| 63 | + if entrypoint.name not in enabled_plugins: |
| 64 | + logger.info( |
| 65 | + ("Found not enabled plugin %s. Plugin will not be loaded."), |
44 | 66 | entrypoint.name, |
45 | | - class_name, |
46 | | - module_path, |
47 | 67 | ) |
48 | 68 | continue |
49 | | - if not issubclass(plugin_class, Plugin): |
50 | | - logger.warning( |
51 | | - ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."), |
52 | | - entrypoint.name, |
53 | | - class_name, |
| 69 | + else: |
| 70 | + entrypoints[entrypoint.name] = PluginEntrypoint( |
| 71 | + entrypoint.name, entrypoint.value, is_builtin=False |
54 | 72 | ) |
55 | | - continue |
56 | | - plugins_to_load.remove(entrypoint.name) |
57 | | - _PLUGINS.append(plugin_class()) |
58 | | - logger.info("Loaded plugin %s", entrypoint.name) |
| 73 | + |
| 74 | + for name, import_path in _BUILTIN_PLUGINS.items(): |
| 75 | + if name not in enabled_plugins: |
| 76 | + logger.info( |
| 77 | + ("Found not enabled builtin plugin %s. Plugin will not be loaded."), |
| 78 | + name, |
| 79 | + ) |
| 80 | + else: |
| 81 | + entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True) |
| 82 | + |
| 83 | + for plugin_name, plugin_entrypoint in entrypoints.items(): |
| 84 | + plugin_instance = plugin_entrypoint.load() |
| 85 | + if plugin_instance is not None: |
| 86 | + _PLUGINS.append(plugin_instance) |
| 87 | + plugins_to_load.remove(plugin_name) |
| 88 | + logger.info("Loaded plugin %s", plugin_name) |
| 89 | + |
59 | 90 | if plugins_to_load: |
60 | 91 | logger.warning("Enabled plugins not found: %s", plugins_to_load) |
61 | 92 |
|
|
0 commit comments