|
| 1 | +"""Consistency checks for integration wrappers. |
| 2 | +
|
| 3 | +Ensures that integration wrappers do not mix base-class ecosystems: |
| 4 | +if a class uses skbase conventions (``_tags`` dict or ``get_test_params``), |
| 5 | +it must also inherit from ``skbase.base.BaseObject``. Otherwise the |
| 6 | +conventions are dead code and the class drops out of hyperactive's |
| 7 | +registry-based test coverage. |
| 8 | +""" |
| 9 | + |
| 10 | +import importlib |
| 11 | +import inspect |
| 12 | +import pkgutil |
| 13 | + |
| 14 | +from skbase.base import BaseObject |
| 15 | + |
| 16 | +import hyperactive.integrations as _integrations_pkg |
| 17 | + |
| 18 | + |
| 19 | +def _iter_integration_classes(): |
| 20 | + """Yield classes defined under ``hyperactive.integrations``. |
| 21 | +
|
| 22 | + Test modules and classes re-exported from other packages are skipped. |
| 23 | + Modules that fail to import (e.g. due to missing optional deps at |
| 24 | + collection time) are skipped as well so this test does not mask soft-dep |
| 25 | + handling done inside the modules themselves. |
| 26 | + """ |
| 27 | + prefix = _integrations_pkg.__name__ + "." |
| 28 | + for mod_info in pkgutil.walk_packages(_integrations_pkg.__path__, prefix=prefix): |
| 29 | + name = mod_info.name |
| 30 | + if ".tests" in name or name.endswith(".tests"): |
| 31 | + continue |
| 32 | + try: |
| 33 | + module = importlib.import_module(name) |
| 34 | + except ImportError: |
| 35 | + continue |
| 36 | + for _, cls in inspect.getmembers(module, inspect.isclass): |
| 37 | + if cls.__module__ == name: |
| 38 | + yield cls |
| 39 | + |
| 40 | + |
| 41 | +def test_integrations_do_not_mix_base_class_ecosystems(): |
| 42 | + """No Mischling wrappers: skbase conventions require skbase inheritance. |
| 43 | +
|
| 44 | + A class that defines ``_tags`` as a dict or its own ``get_test_params`` |
| 45 | + classmethod signals that it wants to participate in skbase-based |
| 46 | + registry and testing. Such a class must inherit from |
| 47 | + ``skbase.base.BaseObject``, otherwise the conventions are silently |
| 48 | + ignored. |
| 49 | + """ |
| 50 | + offenders = [] |
| 51 | + for cls in _iter_integration_classes(): |
| 52 | + uses_skbase_conventions = ( |
| 53 | + isinstance(cls.__dict__.get("_tags"), dict) |
| 54 | + or "get_test_params" in cls.__dict__ |
| 55 | + ) |
| 56 | + if uses_skbase_conventions and not issubclass(cls, BaseObject): |
| 57 | + offenders.append(f"{cls.__module__}.{cls.__name__}") |
| 58 | + |
| 59 | + assert not offenders, ( |
| 60 | + "The following integration classes use skbase conventions (_tags " |
| 61 | + "or get_test_params) but do not inherit from skbase.base.BaseObject. " |
| 62 | + "Either inherit from BaseObject (or a skbase-based base class), or " |
| 63 | + "remove the skbase-style conventions:\n - " + "\n - ".join(sorted(offenders)) |
| 64 | + ) |
0 commit comments