Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/mocked_plugin_as_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class VariantFeatureConfig:


namespace = "module_namespace"
is_build_plugin = False


def get_all_configs() -> list[VariantFeatureConfigType]:
Expand Down
26 changes: 12 additions & 14 deletions tests/mocked_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class MockedPluginA(PluginType):

is_build_plugin = True

def get_all_configs(self) -> list[VariantFeatureConfigType]:
@staticmethod
def get_all_configs() -> list[VariantFeatureConfigType]:
return [
VariantFeatureConfig(
"name1", ["val1a", "val1b", "val1c", "val1d"], multi_value=False
Expand All @@ -30,7 +31,8 @@ def get_all_configs(self) -> list[VariantFeatureConfigType]:
),
]

def get_supported_configs(self) -> list[VariantFeatureConfigType]:
@staticmethod
def get_supported_configs() -> list[VariantFeatureConfigType]:
return [
VariantFeatureConfig("name1", ["val1a", "val1b"], multi_value=False),
VariantFeatureConfig(
Expand All @@ -49,14 +51,16 @@ def get_supported_configs(self) -> list[VariantFeatureConfigType]:
class MockedPluginB:
namespace = "second_namespace"

def get_all_configs(self) -> list[MyVariantFeatureConfig]:
@classmethod
def get_all_configs(cls) -> list[MyVariantFeatureConfig]:
return [
MyVariantFeatureConfig(
"name3", ["val3a", "val3b", "val3c"], multi_value=False
),
]

def get_supported_configs(self) -> list[MyVariantFeatureConfig]:
@classmethod
def get_supported_configs(cls) -> list[MyVariantFeatureConfig]:
return [
MyVariantFeatureConfig("name3", ["val3a"], multi_value=False),
]
Expand All @@ -74,26 +78,20 @@ def __init__(self, name: str) -> None:
class MockedPluginC(PluginType):
namespace = "incompatible_namespace"

def get_all_configs(self) -> list[VariantFeatureConfigType]:
@classmethod
def get_all_configs(cls) -> list[VariantFeatureConfigType]:
return [
MyVariantFeatureConfig(x, ["on"], multi_value=False)
for x in ("flag1", "flag2", "flag3", "flag4")
]

def get_supported_configs(self) -> list[VariantFeatureConfigType]:
@staticmethod
def get_supported_configs() -> list[VariantFeatureConfigType]:
return []


class IndirectPath:
class MoreIndirection:
@classmethod
def plugin_a(cls) -> MockedPluginA:
return MockedPluginA()

@staticmethod
def plugin_b() -> MockedPluginB:
return MockedPluginB()

object_a = MockedPluginA()


Expand Down
78 changes: 16 additions & 62 deletions tests/plugins/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,16 @@
class ClashingPlugin(PluginType):
namespace = "test_namespace" # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]

def get_all_configs(self) -> list[VariantFeatureConfigType]:
@classmethod
def get_all_configs(cls) -> list[VariantFeatureConfigType]:
return [
VariantFeatureConfig(
"name1", ["val1a", "val1b", "val1c", "val1d"], multi_value=False
),
]

def get_supported_configs(self) -> list[VariantFeatureConfigType]:
@classmethod
def get_supported_configs(cls) -> list[VariantFeatureConfigType]:
return []


Expand All @@ -64,11 +66,13 @@ class ExceptionPluginBase(PluginType):

returned_value: list[VariantFeatureConfigType]

def get_all_configs(self) -> list[VariantFeatureConfigType]:
return self.returned_value
@classmethod
def get_all_configs(cls) -> list[VariantFeatureConfigType]:
return cls.returned_value

def get_supported_configs(self) -> list[VariantFeatureConfigType]:
return self.returned_value
@classmethod
def get_supported_configs(cls) -> list[VariantFeatureConfigType]:
return cls.returned_value


def test_get_all_configs(
Expand Down Expand Up @@ -275,7 +279,8 @@ def test_namespace_incorrect_name() -> None:
class IncompletePlugin:
namespace = "incomplete_plugin"

def get_supported_configs(self) -> list[VariantFeatureConfigType]:
@classmethod
def get_supported_configs(cls) -> list[VariantFeatureConfigType]:
return []


Expand All @@ -292,48 +297,7 @@ def test_namespace_incorrect_type() -> None:
pass


class RaisingInstantiationPlugin:
namespace = "raising_plugin"

def __init__(self) -> None:
raise RuntimeError("I failed to initialize")

def get_all_configs(self) -> list[VariantFeatureConfigType]:
return []

def get_supported_configs(self) -> list[VariantFeatureConfigType]:
return []


def test_namespace_instantiation_raises() -> None:
with (
pytest.raises(
PluginError,
match=(
"Instantiating the plugin from "
r"'tests.plugins.test_loader:RaisingInstantiationPlugin' failed: "
"I failed to initialize"
),
),
ListPluginLoader(["tests.plugins.test_loader:RaisingInstantiationPlugin"]),
):
pass


class CrossTypeInstantiationPlugin:
namespace = "cross_plugin"

def __new__(cls) -> IncompletePlugin: # type: ignore[misc]
return IncompletePlugin()

def get_all_configs(self) -> list[VariantFeatureConfigType]:
return []

def get_supported_configs(self) -> list[VariantFeatureConfigType]:
return []


@pytest.mark.parametrize("cls", ["IncompletePlugin", "CrossTypeInstantiationPlugin"])
@pytest.mark.parametrize("cls", ["IncompletePlugin"])
def test_namespace_instantiation_returns_incorrect_type(
cls: type,
) -> None:
Expand All @@ -342,9 +306,9 @@ def test_namespace_instantiation_returns_incorrect_type(
PluginError,
match=re.escape(
f"'tests.plugins.test_loader:{cls}' does not meet the PluginType "
"prototype: <tests.plugins.test_loader.IncompletePlugin object at"
)
+ r".*(missing attributes: get_all_configs)",
"prototype: <class 'tests.plugins.test_loader.IncompletePlugin'> "
"(missing attributes: get_all_configs)"
),
),
ListPluginLoader([f"tests.plugins.test_loader:{cls}"]),
):
Expand All @@ -361,16 +325,6 @@ def test_namespaces(
]


def test_non_class_attrs() -> None:
with ListPluginLoader(
[
"tests.mocked_plugins:IndirectPath.MoreIndirection.plugin_a",
"tests.mocked_plugins:IndirectPath.MoreIndirection.plugin_b",
]
) as loader:
assert loader.namespaces == ["test_namespace", "second_namespace"]


def test_non_callable_plugin() -> None:
with ListPluginLoader(
[
Expand Down
74 changes: 74 additions & 0 deletions tests/test_protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

import pytest
from variantlib.models.provider import VariantFeatureConfig
from variantlib.protocols import PluginType
from variantlib.protocols import VariantFeatureConfigType

from tests import mocked_plugin_as_module
from tests.mocked_plugins import MockedPluginA
from tests.mocked_plugins import MockedPluginC
from tests.mocked_plugins import MyVariantFeatureConfig


class VariantFeatureConfigTypeSubclass(VariantFeatureConfigType):
name = "a"
values = ["b"]
multi_value = False

def __init__(self, *args):
pass


@pytest.mark.parametrize(
"cls",
[VariantFeatureConfig, MyVariantFeatureConfig, VariantFeatureConfigTypeSubclass],
)
def test_variant_feature_config_type(cls: type) -> None:
# TODO: why do we need to instantiate it? VariantFeatureConfig fails otherwise.
assert isinstance(
cls(name="x", values=["y"], multi_value=False), VariantFeatureConfigType
)


@pytest.mark.parametrize("missing", ["name", "values", "multi_value"])
def test_variant_feature_config_type_abstract(missing: str) -> None:
class PartialVariantFeatureConfigTypeSubclass(VariantFeatureConfigType):
if missing != "name":
name = "a"
if missing != "values":
values = ["b"]
if missing != "multi_value":
multi_value = False

with pytest.raises(TypeError):
PartialVariantFeatureConfigTypeSubclass()


@pytest.mark.parametrize("cls", [MockedPluginA, MockedPluginC, mocked_plugin_as_module])
def test_plugin_type(cls: type) -> None:
assert isinstance(cls, PluginType)


@pytest.mark.parametrize(
"missing", ["namespace", "get_all_configs", "get_supported_configs"]
)
def test_plugin_type_abstract(missing: str) -> None:
class PartialPluginTypeSubclass(PluginType):
if missing != "namespace":
namespace = "ns"

if missing != "get_all_configs":

@staticmethod
def get_all_configs() -> list[VariantFeatureConfigType]:
return []

if missing != "get_supported_configs":

@staticmethod
def get_supported_configs() -> list[VariantFeatureConfigType]:
return []

with pytest.raises(TypeError):
PartialPluginTypeSubclass()
20 changes: 4 additions & 16 deletions variantlib/plugins/_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,12 @@ def load_plugins(plugin_apis: list[str]) -> Generator[PluginType]:
try:
module = importlib.import_module(import_name)
attr_chain = attr_path.split(".") if attr_path else []
plugin_callable = reduce(getattr, attr_chain, module)
plugin_instance = reduce(getattr, attr_chain, module)
except Exception as exc:
raise RuntimeError(
f"Loading the plugin from {plugin_api!r} failed: {exc}"
) from exc

# plugin-api can either be a callable (e.g. a class to instantiate
# or a function to call) or a ready object
plugin_instance: PluginType
if callable(plugin_callable):
try:
# Instantiate the plugin
plugin_instance = plugin_callable()
except Exception as exc:
raise RuntimeError(
f"Instantiating the plugin from {plugin_api!r} failed: {exc}"
) from exc
else:
plugin_instance = plugin_callable # pyright: ignore[reportAssignmentType]

# We cannot use isinstance() here since some of the PluginType methods
# are optional. Instead, we use @abstractmethod decorator to naturally
# annotate required methods, and the remaining methods are optional.
Expand All @@ -66,7 +52,9 @@ def load_plugins(plugin_apis: list[str]) -> Generator[PluginType]:
f"{', '.join(sorted(missing_attributes))})"
)

yield plugin_instance
# TODO: mypy complains about ModuleType -> PluginType conversion when PluginType
# contains class methods but not otherwise. Need to confirm if it's a mypy bug.
yield plugin_instance # type: ignore[misc]


def process_configs(
Expand Down
10 changes: 8 additions & 2 deletions variantlib/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def name(self) -> VariantFeatureName:
raise NotImplementedError

@property
@abstractmethod
def multi_value(self) -> bool:
"""Does this property allow multiple values per variant?"""
raise NotImplementedError
Expand All @@ -66,6 +67,9 @@ def values(self) -> list[VariantFeatureValue]:
class PluginType(Protocol):
"""A protocol for plugin classes"""

# Note: properties are used here for docstring purposes, these must
# be actually implemented as attributes.

@property
@abstractmethod
def namespace(self) -> VariantNamespace:
Expand All @@ -88,12 +92,14 @@ def is_build_plugin(self) -> bool:
"""
return False

@classmethod
@abstractmethod
def get_all_configs(self) -> list[VariantFeatureConfigType]:
def get_all_configs(cls) -> list[VariantFeatureConfigType]:
"""Get all valid configs for the plugin"""
raise NotImplementedError

@classmethod
@abstractmethod
def get_supported_configs(self) -> list[VariantFeatureConfigType]:
def get_supported_configs(cls) -> list[VariantFeatureConfigType]:
"""Get supported configs for the current system"""
raise NotImplementedError
Loading