diff --git a/pyosmo/config.py b/pyosmo/config.py index 6ca1a7e..0a7345a 100644 --- a/pyosmo/config.py +++ b/pyosmo/config.py @@ -1,5 +1,6 @@ import logging from random import Random, randint +from typing import Any from pyosmo.algorithm import RandomAlgorithm from pyosmo.algorithm.base import OsmoAlgorithm @@ -14,6 +15,99 @@ DEFAULT_SEED_MAX = 10000 +class ConfigurationError(ValueError): + """Raised when configuration validation fails.""" + + pass + + +class ConfigValidator: + """Validates configuration values with comprehensive error messages.""" + + @staticmethod + def validate_algorithm(algorithm: Any) -> None: + """Validate algorithm configuration. + + Args: + algorithm: Algorithm to validate + + Raises: + ConfigurationError: If algorithm is invalid + """ + if algorithm is None: + raise ConfigurationError('Algorithm cannot be None. Please provide a valid OsmoAlgorithm instance.') + + if not isinstance(algorithm, OsmoAlgorithm): + raise ConfigurationError( + f'Algorithm must be an instance of OsmoAlgorithm, ' + f'got {type(algorithm).__name__}. ' + f'Available algorithms: RandomAlgorithm, WeightedAlgorithm, BalancingAlgorithm.' + ) + + @staticmethod + def validate_end_condition(condition: Any, name: str = 'End condition') -> None: + """Validate end condition configuration. + + Args: + condition: End condition to validate + name: Name of the condition for error messages + + Raises: + ConfigurationError: If end condition is invalid + """ + if condition is None: + raise ConfigurationError(f'{name} cannot be None. Please provide a valid OsmoEndCondition instance.') + + if not isinstance(condition, OsmoEndCondition): + raise ConfigurationError( + f'{name} must be an instance of OsmoEndCondition, ' + f'got {type(condition).__name__}. ' + f'Available conditions: Length, Time, StepCoverage, Endless, And, Or.' + ) + + @staticmethod + def validate_error_strategy(strategy: Any, name: str = 'Error strategy') -> None: + """Validate error strategy configuration. + + Args: + strategy: Error strategy to validate + name: Name of the strategy for error messages + + Raises: + ConfigurationError: If error strategy is invalid + """ + if strategy is None: + raise ConfigurationError(f'{name} cannot be None. Please provide a valid OsmoErrorStrategy instance.') + + if not isinstance(strategy, OsmoErrorStrategy): + raise ConfigurationError( + f'{name} must be an instance of OsmoErrorStrategy, ' + f'got {type(strategy).__name__}. ' + f'Available strategies: AlwaysRaise, AlwaysIgnore, IgnoreAsserts, AllowCount.' + ) + + @staticmethod + def validate_seed(seed: Any) -> None: + """Validate random seed value. + + Args: + seed: Seed value to validate + + Raises: + ConfigurationError: If seed is invalid + """ + if not isinstance(seed, int): + raise ConfigurationError(f'Seed must be an integer, got {type(seed).__name__}.') + + if seed < 0: + raise ConfigurationError(f'Seed must be non-negative, got {seed}.') + + if seed > 2**32 - 1: + raise ConfigurationError( + f'Seed must fit in 32 bits (max {2**32 - 1}), got {seed}. Use a smaller seed value for reproducibility.' + ) + + class OsmoConfig: """Osmo run configuration object""" @@ -35,10 +129,16 @@ def algorithm(self) -> OsmoAlgorithm: return self._algorithm @algorithm.setter - def algorithm(self, value: OsmoAlgorithm): - """Set test generation algorithm""" - if not isinstance(value, OsmoAlgorithm): - raise AttributeError('algorithm needs to be OsmoAlgorithm') + def algorithm(self, value: OsmoAlgorithm) -> None: + """Set test generation algorithm with validation. + + Args: + value: Algorithm instance + + Raises: + ConfigurationError: If algorithm is invalid + """ + ConfigValidator.validate_algorithm(value) self._algorithm = value @property @@ -46,10 +146,16 @@ def test_end_condition(self) -> OsmoEndCondition: return self._test_end_condition @test_end_condition.setter - def test_end_condition(self, value: OsmoEndCondition): - """Set test generation test_end_condition""" - if not isinstance(value, OsmoEndCondition): - raise AttributeError('test_end_condition needs to be OsmoEndCondition') + def test_end_condition(self, value: OsmoEndCondition) -> None: + """Set test end condition with validation. + + Args: + value: End condition instance + + Raises: + ConfigurationError: If end condition is invalid + """ + ConfigValidator.validate_end_condition(value, 'Test end condition') self._test_end_condition = value @property @@ -57,10 +163,16 @@ def test_suite_end_condition(self) -> OsmoEndCondition: return self._test_suite_end_condition @test_suite_end_condition.setter - def test_suite_end_condition(self, value: OsmoEndCondition): - """Set test generation test_suite_end_condition""" - if not isinstance(value, OsmoEndCondition): - raise AttributeError('test_suite_end_condition needs to be OsmoEndCondition') + def test_suite_end_condition(self, value: OsmoEndCondition) -> None: + """Set test suite end condition with validation. + + Args: + value: End condition instance + + Raises: + ConfigurationError: If end condition is invalid + """ + ConfigValidator.validate_end_condition(value, 'Test suite end condition') self._test_suite_end_condition = value @property @@ -68,10 +180,16 @@ def test_error_strategy(self) -> OsmoErrorStrategy: return self._test_error_strategy @test_error_strategy.setter - def test_error_strategy(self, value: OsmoErrorStrategy): - """Set test generation test_suite_end_condition""" - if not isinstance(value, OsmoErrorStrategy): - raise AttributeError('test_error_strategy needs to be OsmoErrorStrategy') + def test_error_strategy(self, value: OsmoErrorStrategy) -> None: + """Set test error strategy with validation. + + Args: + value: Error strategy instance + + Raises: + ConfigurationError: If error strategy is invalid + """ + ConfigValidator.validate_error_strategy(value, 'Test error strategy') self._test_error_strategy = value @property @@ -79,8 +197,14 @@ def test_suite_error_strategy(self) -> OsmoErrorStrategy: return self._test_suite_error_strategy @test_suite_error_strategy.setter - def test_suite_error_strategy(self, value: OsmoErrorStrategy): - """Set test generation test_suite_end_condition""" - if not isinstance(value, OsmoErrorStrategy): - raise AttributeError('test_suite_error_strategy needs to be OsmoErrorStrategy') + def test_suite_error_strategy(self, value: OsmoErrorStrategy) -> None: + """Set test suite error strategy with validation. + + Args: + value: Error strategy instance + + Raises: + ConfigurationError: If error strategy is invalid + """ + ConfigValidator.validate_error_strategy(value, 'Test suite error strategy') self._test_suite_error_strategy = value diff --git a/pyosmo/discovery/__init__.py b/pyosmo/discovery/__init__.py new file mode 100644 index 0000000..b39869c --- /dev/null +++ b/pyosmo/discovery/__init__.py @@ -0,0 +1,18 @@ +"""Model discovery strategies for PyOsmo. + +This module provides extensible discovery mechanisms for finding test steps, +guards, and weights in model classes. +""" + +from pyosmo.discovery.base import DiscoveryStrategy, ModelMetadata +from pyosmo.discovery.decorator import DecoratorBasedDiscovery +from pyosmo.discovery.naming import NamingConventionDiscovery +from pyosmo.discovery.orchestrator import ModelDiscovery + +__all__ = [ + 'DiscoveryStrategy', + 'ModelMetadata', + 'DecoratorBasedDiscovery', + 'NamingConventionDiscovery', + 'ModelDiscovery', +] diff --git a/pyosmo/discovery/base.py b/pyosmo/discovery/base.py new file mode 100644 index 0000000..6ec22b5 --- /dev/null +++ b/pyosmo/discovery/base.py @@ -0,0 +1,67 @@ +"""Base classes for model discovery strategies.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class StepMetadata: + """Metadata for a discovered test step.""" + + name: str # Step name (without 'step_' prefix) + function_name: str # Full function name + method: Any # The actual method object + is_decorator_based: bool = False + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ModelMetadata: + """Complete metadata for a discovered model.""" + + steps: list[StepMetadata] + model: object + + def get_step_names(self) -> list[str]: + """Get list of step names.""" + return [step.name for step in self.steps] + + def get_step_by_name(self, name: str) -> StepMetadata | None: + """Get step metadata by name.""" + for step in self.steps: + if step.name == name: + return step + return None + + +class DiscoveryStrategy(ABC): + """Base class for model discovery strategies. + + Discovery strategies are responsible for finding test steps, + guards, and weights in model classes using different mechanisms + (naming conventions, decorators, annotations, etc.). + """ + + @abstractmethod + def discover_steps(self, model: object) -> list[StepMetadata]: + """Discover test steps in the model. + + Args: + model: Model instance to inspect + + Returns: + List of discovered step metadata + """ + pass + + def get_priority(self) -> int: + """Get priority for this discovery strategy. + + Lower numbers = higher priority (checked first). + Default is 100. + + Returns: + Priority value + """ + return 100 diff --git a/pyosmo/discovery/decorator.py b/pyosmo/discovery/decorator.py new file mode 100644 index 0000000..f86154d --- /dev/null +++ b/pyosmo/discovery/decorator.py @@ -0,0 +1,51 @@ +"""Decorator-based discovery strategy.""" + +import inspect + +from pyosmo.discovery.base import DiscoveryStrategy, StepMetadata + + +class DecoratorBasedDiscovery(DiscoveryStrategy): + """Discover steps via @step, @guard decorators. + + This strategy finds methods decorated with @step and extracts + their metadata. It has higher priority than naming convention + to allow explicit override of naming-based discovery. + """ + + def discover_steps(self, model: object) -> list[StepMetadata]: + """Discover steps via decorators. + + Args: + model: Model instance to inspect + + Returns: + List of step metadata for decorator-based steps + """ + steps = [] + + for attr_name, method in inspect.getmembers(model, predicate=callable): + # Skip private/protected methods + if attr_name.startswith('_'): + continue + + # Check for @step decorator + if hasattr(method, '_osmo_step'): + step_name = getattr(method, '_osmo_step_name', attr_name) + metadata = getattr(method, '_osmo_metadata', {}) + + steps.append( + StepMetadata( + name=step_name, + function_name=attr_name, + method=method, + is_decorator_based=True, + metadata=metadata, + ) + ) + + return steps + + def get_priority(self) -> int: + """Decorator-based discovery has high priority (10).""" + return 10 diff --git a/pyosmo/discovery/naming.py b/pyosmo/discovery/naming.py new file mode 100644 index 0000000..f5406eb --- /dev/null +++ b/pyosmo/discovery/naming.py @@ -0,0 +1,55 @@ +"""Naming convention discovery strategy.""" + +import inspect + +from pyosmo.discovery.base import DiscoveryStrategy, StepMetadata + + +class NamingConventionDiscovery(DiscoveryStrategy): + """Discover steps via step_* naming convention. + + This strategy finds methods that follow the naming convention: + - step_* for test steps + - guard_* for guards + - weight_* for weights + + It has lower priority than decorator-based discovery to allow + decorators to override naming-based behavior. + """ + + def discover_steps(self, model: object) -> list[StepMetadata]: + """Discover steps via naming convention. + + Args: + model: Model instance to inspect + + Returns: + List of step metadata for naming convention steps + """ + steps = [] + + for attr_name, method in inspect.getmembers(model, predicate=callable): + # Skip private/protected methods + if attr_name.startswith('_'): + continue + + # Check for step_* naming convention + if attr_name.startswith('step_'): + # Extract step name (remove 'step_' prefix) + step_name = attr_name[5:] + + steps.append( + StepMetadata( + name=step_name, + function_name=attr_name, + method=method, + is_decorator_based=False, + metadata={}, + ) + ) + + return steps + + def get_priority(self) -> int: + """Naming convention discovery has lower priority (50).""" + return 50 diff --git a/pyosmo/discovery/orchestrator.py b/pyosmo/discovery/orchestrator.py new file mode 100644 index 0000000..72a6af1 --- /dev/null +++ b/pyosmo/discovery/orchestrator.py @@ -0,0 +1,63 @@ +"""Discovery orchestrator that combines multiple strategies.""" + +from pyosmo.discovery.base import DiscoveryStrategy, ModelMetadata, StepMetadata +from pyosmo.discovery.decorator import DecoratorBasedDiscovery +from pyosmo.discovery.naming import NamingConventionDiscovery + + +class ModelDiscovery: + """Orchestrates multiple discovery strategies. + + This class manages multiple discovery strategies and combines their + results, handling deduplication and priority ordering. + """ + + def __init__(self, strategies: list[DiscoveryStrategy] | None = None): + """Initialize with discovery strategies. + + Args: + strategies: List of discovery strategies. If None, uses default + strategies (decorator-based and naming convention). + """ + if strategies is None: + strategies = [ + DecoratorBasedDiscovery(), + NamingConventionDiscovery(), + ] + + # Sort strategies by priority (lower number = higher priority) + self.strategies = sorted(strategies, key=lambda s: s.get_priority()) + + def discover(self, model: object) -> ModelMetadata: + """Discover all model components using all strategies. + + Args: + model: Model instance to inspect + + Returns: + ModelMetadata with all discovered steps + """ + all_steps: list[StepMetadata] = [] + seen_names: set[str] = set() + + # Run strategies in priority order + for strategy in self.strategies: + steps = strategy.discover_steps(model) + + for step in steps: + # Skip if we've already found this step (higher priority strategy won) + if step.function_name not in seen_names: + all_steps.append(step) + seen_names.add(step.function_name) + + return ModelMetadata(steps=all_steps, model=model) + + def add_strategy(self, strategy: DiscoveryStrategy) -> None: + """Add a new discovery strategy. + + Args: + strategy: Discovery strategy to add + """ + self.strategies.append(strategy) + # Re-sort by priority + self.strategies.sort(key=lambda s: s.get_priority()) diff --git a/pyosmo/model.py b/pyosmo/model.py index d60a7d6..226bd3d 100644 --- a/pyosmo/model.py +++ b/pyosmo/model.py @@ -71,8 +71,8 @@ def weight(self) -> float: return float(weight_function.execute()) # Check weight attribute (legacy decorator) - if 'weight' in dir(self.func): - return float(self.func.weight) # type: ignore[attr-defined] + if hasattr(self.func, 'weight'): + return float(self.func.weight) return self.default_weight # Default value @@ -105,12 +105,18 @@ def is_available(self) -> bool: return True if self.guard_function is None else bool(self.guard_function.execute()) def _find_decorator_guard(self) -> Optional['ModelFunction']: - """Find a guard method decorated with @guard("step_name") for this step.""" - for attr_name in dir(self.object_instance): - method = getattr(self.object_instance, attr_name) + """Find a guard method decorated with @guard("step_name") for this step. + + Uses inspect.getmembers() for robust introspection. + Supports both instance methods and static methods. + """ + for attr_name, method in inspect.getmembers(self.object_instance, predicate=callable): + # Skip private/protected methods + if attr_name.startswith('_'): + continue + if ( - callable(method) - and hasattr(method, '_osmo_guard') + hasattr(method, '_osmo_guard') and hasattr(method, '_osmo_guard_for') and method._osmo_guard_for == self.name ): @@ -123,9 +129,14 @@ def guard_function(self) -> Optional['ModelFunction']: return self.return_function_if_exists(self.guard_name) def return_function_if_exists(self, name: str) -> Optional['ModelFunction']: - """Return ModelFunction if method exists in the model instance, otherwise None.""" - if name in dir(self.object_instance): - return ModelFunction(name, self.object_instance) + """Return ModelFunction if method exists in the model instance, otherwise None. + + Uses hasattr() instead of dir() for robust attribute checking. + """ + if hasattr(self.object_instance, name): + attr = getattr(self.object_instance, name) + if callable(attr): + return ModelFunction(name, self.object_instance) return None @@ -136,57 +147,88 @@ def __init__(self) -> None: # Format: functions[function_name] = link_of_instance self.sub_models: list[object] = [] self.debug: bool = False + # Performance optimization: cache discovered steps + self._steps_cache: list[TestStep] | None = None + self._cache_valid: bool = False def _discover_steps(self, sub_model: object) -> Iterator[TestStep]: - """Discover steps using both naming convention and decorators.""" + """Discover steps using both naming convention and decorators. + + Uses inspect.getmembers() for robust introspection, avoiding + fragile dir() patterns and properly handling inherited methods. + Supports both instance methods and static methods. + """ discovered_step_names = set() # First, discover decorator-based steps - for attr_name in dir(sub_model): - method = getattr(sub_model, attr_name) - if callable(method) and hasattr(method, '_osmo_step'): - step_name = method._osmo_step_name + for attr_name, method in inspect.getmembers(sub_model, predicate=callable): + # Skip private/protected methods + if attr_name.startswith('_'): + continue + + if hasattr(method, '_osmo_step'): + step_name = method._osmo_step_name # type: ignore[attr-defined] discovered_step_names.add(attr_name) yield TestStep(attr_name, sub_model, step_name, is_decorator_based=True) # Then, discover naming convention steps (skip if already found via decorator) - for attr_name in dir(sub_model): - if attr_name in discovered_step_names: + for attr_name, _method in inspect.getmembers(sub_model, predicate=callable): + # Skip if already discovered or is private/protected + if attr_name in discovered_step_names or attr_name.startswith('_'): continue - if callable(getattr(sub_model, attr_name)) and attr_name.startswith('step_'): + + if attr_name.startswith('step_'): yield TestStep(attr_name, sub_model) @property def all_steps(self) -> Iterator[TestStep]: - return (step for sub_model in self.sub_models for step in self._discover_steps(sub_model)) + """Get all discovered steps (with caching for performance). + + Steps are discovered once and cached until models are added/modified. + This improves performance when repeatedly accessing all_steps. + """ + if not self._cache_valid or self._steps_cache is None: + # Rebuild cache + self._steps_cache = [step for sub_model in self.sub_models for step in self._discover_steps(sub_model)] + self._cache_valid = True + + return iter(self._steps_cache) def get_step_by_name(self, name: str) -> TestStep | None: - """Get step by function name""" - steps = ( - TestStep(f, sub_model) - for sub_model in self.sub_models - for f in dir(sub_model) - if callable(getattr(sub_model, f)) and f == name - ) - for step in steps: - return step - return None # noqa + """Get step by function name. + + Uses inspect.getmembers() for robust introspection. + Supports both instance methods and static methods. + """ + for sub_model in self.sub_models: + for attr_name, _method in inspect.getmembers(sub_model, predicate=callable): + if attr_name == name: + return TestStep(attr_name, sub_model) + return None def functions_by_name(self, name: str) -> Iterator[ModelFunction]: - return ( - ModelFunction(f, sub_model) - for sub_model in self.sub_models - for f in dir(sub_model) - if callable(getattr(sub_model, f)) and f == name - ) + """Get all functions with a specific name from all sub-models. + + Uses inspect.getmembers() for robust introspection. + Supports both instance methods and static methods. + """ + for sub_model in self.sub_models: + for attr_name, _method in inspect.getmembers(sub_model, predicate=callable): + if attr_name == name: + yield ModelFunction(attr_name, sub_model) def add_model(self, model: object) -> None: - """Add model for osmo""" + """Add model for osmo. + + Invalidates step cache since new model may add steps. + """ # Check if model is a class (not an instance) and instantiate it if inspect.isclass(model): model = model() self.sub_models.append(model) + # Invalidate cache since we added a model + self._cache_valid = False logger.debug(f'Loaded model: {model.__class__}') def execute_optional(self, function_name: str) -> None: diff --git a/pyosmo/osmo.py b/pyosmo/osmo.py index eedda7c..d0f1c51 100644 --- a/pyosmo/osmo.py +++ b/pyosmo/osmo.py @@ -47,10 +47,18 @@ def seed(self) -> int: @seed.setter def seed(self, value: int) -> None: - """Set random seed for test generation""" + """Set random seed for test generation with validation. + + Args: + value: Random seed value (must be non-negative int that fits in 32 bits) + + Raises: + ConfigurationError: If seed value is invalid + """ + from pyosmo.config import ConfigValidator + + ConfigValidator.validate_seed(value) logger.debug(f'Set seed: {value}') - if not isinstance(value, int): - raise AttributeError('Seed value must be an integer.') self._seed = value self._random = Random(self._seed) # update osmo_random in all models @@ -73,18 +81,52 @@ def add_model(self, model: object) -> None: def _run_step(self, step: TestStep) -> None: """ - Run step and save it to the history - :param step: Test step - :return: + Run step and save it to the history. + + Args: + step: Test step to execute + + Raises: + KeyboardInterrupt: User interrupted execution (preserved) + Exception: Any error during step execution (with proper error chaining) """ logger.debug(f'Run step: {step}') start_time = datetime.now() try: step.execute() self.history.add_step(step, datetime.now() - start_time) + except KeyboardInterrupt: + # Preserve keyboard interrupt for user cancellation + # Note: KeyboardInterrupt is BaseException, not Exception, so we can't log it + duration = datetime.now() - start_time + self.history.add_step(step, duration, None) + raise + except AssertionError as error: + # Test assertion failed + duration = datetime.now() - start_time + self.history.add_step(step, duration, error) + logger.debug(f'Step {step} assertion failed: {error}') + raise + except AttributeError as error: + # Missing attribute/method in model or step + duration = datetime.now() - start_time + self.history.add_step(step, duration, error) + raise RuntimeError( + f"Step '{step}' tried to access missing attribute. Check your model implementation." + ) from error + except TypeError as error: + # Method signature or type issue + duration = datetime.now() - start_time + self.history.add_step(step, duration, error) + raise RuntimeError( + f"Step '{step}' has invalid signature or type mismatch. Check your step method implementation." + ) from error except Exception as error: - self.history.add_step(step, datetime.now() - start_time, error) - raise error + # Other runtime errors + duration = datetime.now() - start_time + self.history.add_step(step, duration, error) + logger.debug(f'Step {step} failed with {type(error).__name__}: {error}') + raise def run(self) -> None: """Same as generate but in online usage this sounds more natural""" @@ -113,7 +155,11 @@ def generate(self) -> None: self.model.execute_optional(f'pre_{step}') try: self._run_step(step) + except KeyboardInterrupt: + # User interrupted - re-raise immediately + raise except BaseException as error: + # Let error strategy decide how to handle self.test_error_strategy.failure_in_test(self.history, self.model, error) self.model.execute_optional(f'post_{step.name}') # General after step which is run after each step @@ -122,7 +168,11 @@ def generate(self) -> None: if self.test_end_condition.end_test(self.history, self.model): break self.model.execute_optional('after_test') + except KeyboardInterrupt: + # User interrupted - re-raise immediately without error strategy processing + raise except BaseException as error: + # Let suite error strategy decide how to handle self.test_suite_error_strategy.failure_in_suite(self.history, self.model, error) if self.test_suite_end_condition.end_suite(self.history, self.model): break diff --git a/pyosmo/plugins/__init__.py b/pyosmo/plugins/__init__.py new file mode 100644 index 0000000..8a66f4a --- /dev/null +++ b/pyosmo/plugins/__init__.py @@ -0,0 +1,13 @@ +"""Plugin registry system for PyOsmo. + +This module provides a centralized registry for algorithms, end conditions, +and error strategies, enabling easy discovery and extension. +""" + +from pyosmo.plugins.registry import PluginRegistry, get_registry, register_algorithm + +__all__ = [ + 'PluginRegistry', + 'get_registry', + 'register_algorithm', +] diff --git a/pyosmo/plugins/registry.py b/pyosmo/plugins/registry.py new file mode 100644 index 0000000..609d950 --- /dev/null +++ b/pyosmo/plugins/registry.py @@ -0,0 +1,311 @@ +"""Central registry for all plugins.""" + +import logging +from typing import Any + +logger = logging.getLogger('osmo') + + +class PluginError(Exception): + """Raised when plugin registration or retrieval fails.""" + + pass + + +class PluginRegistry: + """Central registry for all PyOsmo plugins. + + This registry maintains collections of algorithms, end conditions, + and error strategies, enabling discoverability and extensibility. + """ + + def __init__(self) -> None: + """Initialize empty registry.""" + self._algorithms: dict[str, tuple[type[Any], str]] = {} + self._end_conditions: dict[str, tuple[type[Any], str]] = {} + self._error_strategies: dict[str, tuple[type[Any], str]] = {} + + def register_algorithm( + self, + name: str, + algorithm_class: type[Any], + *, + description: str = '', + replace: bool = False, + ) -> None: + """Register a test generation algorithm. + + Args: + name: Unique name for this algorithm + algorithm_class: Algorithm class to register + description: Human-readable description + replace: If True, replace existing registration (default: False) + + Raises: + PluginError: If name already registered and replace=False + """ + if name in self._algorithms and not replace: + raise PluginError(f"Algorithm '{name}' already registered. Use replace=True to override.") + + from pyosmo.algorithm.base import OsmoAlgorithm + + if not (isinstance(algorithm_class, type) and issubclass(algorithm_class, OsmoAlgorithm)): + raise PluginError(f'Algorithm class must extend OsmoAlgorithm, got {algorithm_class}') + + self._algorithms[name] = (algorithm_class, description) + logger.debug(f'Registered algorithm: {name}') + + def register_end_condition( + self, + name: str, + condition_class: type[Any], + *, + description: str = '', + replace: bool = False, + ) -> None: + """Register an end condition. + + Args: + name: Unique name for this condition + condition_class: End condition class to register + description: Human-readable description + replace: If True, replace existing registration (default: False) + + Raises: + PluginError: If name already registered and replace=False + """ + if name in self._end_conditions and not replace: + raise PluginError(f"End condition '{name}' already registered. Use replace=True to override.") + + from pyosmo.end_conditions.base import OsmoEndCondition + + if not (isinstance(condition_class, type) and issubclass(condition_class, OsmoEndCondition)): + raise PluginError(f'End condition class must extend OsmoEndCondition, got {condition_class}') + + self._end_conditions[name] = (condition_class, description) + logger.debug(f'Registered end condition: {name}') + + def register_error_strategy( + self, + name: str, + strategy_class: type[Any], + *, + description: str = '', + replace: bool = False, + ) -> None: + """Register an error strategy. + + Args: + name: Unique name for this strategy + strategy_class: Error strategy class to register + description: Human-readable description + replace: If True, replace existing registration (default: False) + + Raises: + PluginError: If name already registered and replace=False + """ + if name in self._error_strategies and not replace: + raise PluginError(f"Error strategy '{name}' already registered. Use replace=True to override.") + + from pyosmo.error_strategy.base import OsmoErrorStrategy + + if not (isinstance(strategy_class, type) and issubclass(strategy_class, OsmoErrorStrategy)): + raise PluginError(f'Error strategy class must extend OsmoErrorStrategy, got {strategy_class}') + + self._error_strategies[name] = (strategy_class, description) + logger.debug(f'Registered error strategy: {name}') + + def get_algorithm(self, name: str) -> type[Any]: + """Get algorithm class by name. + + Args: + name: Algorithm name + + Returns: + Algorithm class + + Raises: + PluginError: If algorithm not found + """ + if name not in self._algorithms: + available = ', '.join(sorted(self._algorithms.keys())) + raise PluginError(f"Algorithm '{name}' not found. Available: {available or 'none'}") + return self._algorithms[name][0] + + def get_end_condition(self, name: str) -> type[Any]: + """Get end condition class by name. + + Args: + name: End condition name + + Returns: + End condition class + + Raises: + PluginError: If end condition not found + """ + if name not in self._end_conditions: + available = ', '.join(sorted(self._end_conditions.keys())) + raise PluginError(f"End condition '{name}' not found. Available: {available or 'none'}") + return self._end_conditions[name][0] + + def get_error_strategy(self, name: str) -> type[Any]: + """Get error strategy class by name. + + Args: + name: Error strategy name + + Returns: + Error strategy class + + Raises: + PluginError: If error strategy not found + """ + if name not in self._error_strategies: + available = ', '.join(sorted(self._error_strategies.keys())) + raise PluginError(f"Error strategy '{name}' not found. Available: {available or 'none'}") + return self._error_strategies[name][0] + + def list_algorithms(self) -> dict[str, str]: + """List all registered algorithms with descriptions. + + Returns: + Dict mapping algorithm names to descriptions + """ + return {name: desc for name, (_, desc) in self._algorithms.items()} + + def list_end_conditions(self) -> dict[str, str]: + """List all registered end conditions with descriptions. + + Returns: + Dict mapping end condition names to descriptions + """ + return {name: desc for name, (_, desc) in self._end_conditions.items()} + + def list_error_strategies(self) -> dict[str, str]: + """List all registered error strategies with descriptions. + + Returns: + Dict mapping error strategy names to descriptions + """ + return {name: desc for name, (_, desc) in self._error_strategies.items()} + + +# Global registry instance +_global_registry: PluginRegistry | None = None + + +def get_registry() -> PluginRegistry: + """Get the global plugin registry. + + Returns: + The global PluginRegistry instance + """ + global _global_registry + if _global_registry is None: + _global_registry = PluginRegistry() + # Register built-in plugins + _register_builtin_plugins(_global_registry) + return _global_registry + + +def _register_builtin_plugins(registry: PluginRegistry) -> None: + """Register built-in algorithms, end conditions, and error strategies. + + Args: + registry: Registry to populate + """ + # Register built-in algorithms + from pyosmo.algorithm import BalancingAlgorithm, RandomAlgorithm, WeightedAlgorithm + + registry.register_algorithm( + 'random', + RandomAlgorithm, + description='Purely random step selection', + ) + registry.register_algorithm( + 'weighted', + WeightedAlgorithm, + description='Weight-based random selection', + ) + registry.register_algorithm( + 'balancing', + BalancingAlgorithm, + description='Coverage-balancing algorithm', + ) + + # Register built-in end conditions + from pyosmo.end_conditions import Endless, Length, StepCoverage, Time + + registry.register_end_condition( + 'length', + Length, + description='Stop after N steps', + ) + registry.register_end_condition( + 'time', + Time, + description='Stop after elapsed time', + ) + registry.register_end_condition( + 'coverage', + StepCoverage, + description='Stop when coverage threshold reached', + ) + registry.register_end_condition( + 'endless', + Endless, + description='Run forever (online testing)', + ) + + # Register built-in error strategies + from pyosmo.error_strategy import ( + AllowCount, + AlwaysIgnore, + AlwaysRaise, + IgnoreAsserts, + ) + + registry.register_error_strategy( + 'raise', + AlwaysRaise, + description='Fail fast on any error', + ) + registry.register_error_strategy( + 'ignore', + AlwaysIgnore, + description='Continue on all errors', + ) + registry.register_error_strategy( + 'ignore_asserts', + IgnoreAsserts, + description='Ignore assertion errors only', + ) + registry.register_error_strategy( + 'allow_count', + AllowCount, + description='Allow up to N errors', + ) + + +def register_algorithm(name: str, description: str = '') -> Any: + """Decorator to register an algorithm with the global registry. + + Args: + name: Unique name for the algorithm + description: Human-readable description + + Returns: + Decorator function + + Example: + @register_algorithm("my_algo", "Custom algorithm") + class MyAlgorithm(OsmoAlgorithm): + ... + """ + + def decorator(cls: type[Any]) -> type[Any]: + get_registry().register_algorithm(name, cls, description=description) + return cls + + return decorator diff --git a/pyosmo/tests/test_config.py b/pyosmo/tests/test_config.py index 1288d9b..accb829 100644 --- a/pyosmo/tests/test_config.py +++ b/pyosmo/tests/test_config.py @@ -2,6 +2,7 @@ from pyosmo import Osmo from pyosmo.algorithm import RandomAlgorithm +from pyosmo.config import ConfigurationError from pyosmo.end_conditions import Length @@ -33,14 +34,14 @@ def test_wrong_config_objects(): osmo = Osmo(OneStepModel()) try: osmo.test_end_condition = RandomAlgorithm() # type: ignore[assignment] - except AttributeError: + except ConfigurationError: pass - except: - raise + except Exception as e: + raise AssertionError(f'Expected ConfigurationError, got {type(e).__name__}: {e}') from e try: osmo.algorithm = Length(1) # type: ignore[assignment] - except AttributeError: + except ConfigurationError: pass - except: - raise + except Exception as e: + raise AssertionError(f'Expected ConfigurationError, got {type(e).__name__}: {e}') from e