|
1 | | -"""Core functionalities for controller parametrization and registration.""" |
| 1 | +"""Core functionalities for controller parametrization.""" |
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import inspect |
| 6 | +import tomllib |
5 | 7 | from functools import partial |
6 | | -from typing import Any, Callable, ParamSpec, Protocol, TypeVar, runtime_checkable |
| 8 | +from pathlib import Path |
| 9 | +from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar |
7 | 10 |
|
8 | | -P = ParamSpec("P") |
9 | | -R = TypeVar("R") |
| 11 | +import numpy as np |
10 | 12 |
|
| 13 | +if TYPE_CHECKING: |
| 14 | + from types import ModuleType |
11 | 15 |
|
12 | | -controller_parameter_registry: dict[str, type[ControllerParams]] = {} |
| 16 | +P = ParamSpec("P") |
| 17 | +R = TypeVar("R") |
13 | 18 |
|
14 | 19 |
|
15 | | -def parametrize(fn: Callable[P, R], drone_model: str) -> Callable[P, R]: |
16 | | - """Parametrize a controller function with the default controller parameters for a drone model. |
| 20 | +def parametrize( |
| 21 | + fn: Callable[P, R], drone_model: str, xp: ModuleType | None = None, device: str | None = None |
| 22 | +) -> Callable[P, R]: |
| 23 | + """Parametrize a controller function with the default parameters for a drone model. |
17 | 24 |
|
18 | 25 | Args: |
19 | 26 | fn: The controller function to parametrize. |
20 | 27 | drone_model: The drone model to use. |
| 28 | + xp: The array API module to use. If not provided, numpy is used. |
| 29 | + device: The device to use. If none, the device is inferred from the xp module. |
21 | 30 |
|
22 | 31 | Example: |
23 | | - >>> from drone_models.controller import parametrize |
24 | | - >>> from drone_models.controller.mellinger import state2attitude |
25 | | - >>> controller_fn = parametrize(state2attitude, drone_model="cf2x_L250") |
26 | | - >>> command_rpyt, int_pos_err = controller_fn( |
27 | | - ... pos=pos, |
28 | | - ... quat=quat, |
29 | | - ... vel=vel, |
30 | | - ... ang_vel=ang_vel, |
31 | | - ... cmd=cmd, |
32 | | - ... ctrl_errors=(int_pos_err,), |
33 | | - ... ctrl_freq=100, |
| 32 | + >>> from drone_controllers import parametrize |
| 33 | + >>> from drone_controllers.mellinger import state2attitude |
| 34 | + >>> controller = parametrize(state2attitude, drone_model="cf2x_L250") |
| 35 | + >>> command_rpyt, int_pos_err = controller( |
| 36 | + ... pos=pos, quat=quat, vel=vel, ang_vel=ang_vel, cmd=cmd, ctrl_freq=100 |
34 | 37 | ... ) |
35 | 38 |
|
36 | 39 | Returns: |
37 | | - The parametrized controller function with all keyword argument only parameters filled in. |
| 40 | + The parametrized controller function with all keyword-only parameters filled in. |
38 | 41 | """ |
39 | | - controller_id = fn.__module__ + "." + fn.__name__ |
| 42 | + xp = np if xp is None else xp |
| 43 | + controller = fn.__module__.split(".")[-2] |
| 44 | + fn_name = fn.__name__ |
40 | 45 | try: |
41 | | - params = controller_parameter_registry[controller_id].load(drone_model) |
| 46 | + sig = inspect.signature(fn) |
| 47 | + kwonly_params = [ |
| 48 | + name |
| 49 | + for name, param in sig.parameters.items() |
| 50 | + if param.kind == inspect.Parameter.KEYWORD_ONLY |
| 51 | + ] |
| 52 | + params = load_params(controller, fn_name, drone_model, xp=xp) |
| 53 | + params = {k: xp.asarray(v, device=device) for k, v in params.items() if k in kwonly_params} |
42 | 54 | except KeyError as e: |
43 | 55 | raise KeyError( |
44 | | - f"Controller `{controller_id}` does not exist in the parameter registry" |
| 56 | + f"Drone model `{drone_model}` not found for controller `{controller}.{fn_name}`" |
45 | 57 | ) from e |
46 | | - except ValueError as e: |
47 | | - raise ValueError(f"Drone model `{drone_model}` not supported for `{fn.__name__}`") from e |
48 | | - return partial(fn, **params._asdict()) |
49 | | - |
50 | | - |
51 | | -@runtime_checkable |
52 | | -class ControllerParams(Protocol): |
53 | | - """Protocol for controller parameters.""" |
| 58 | + return partial(fn, **params) |
54 | 59 |
|
55 | | - @staticmethod |
56 | | - def load(drone_model: str) -> ControllerParams: |
57 | | - """Load the parameters from the config file.""" |
58 | 60 |
|
59 | | - def _asdict(self) -> dict[str, Any]: |
60 | | - """Convert the parameters to a dictionary.""" |
| 61 | +def load_params( |
| 62 | + controller: str, fn_name: str, drone_model: str, xp: ModuleType | None = None |
| 63 | +) -> dict[str, Any]: |
| 64 | + """Load and merge parameters for a controller function and drone model. |
61 | 65 |
|
62 | | - |
63 | | -def register_controller_parameters( |
64 | | - params: ControllerParams | type[ControllerParams], |
65 | | -) -> Callable[[Callable[P, R]], Callable[P, R]]: |
66 | | - """Register the default controller parameters for this controller. |
67 | | -
|
68 | | - Warning: |
69 | | - The controller parameters **must** be a named tuple with a function `load` that takes in the |
70 | | - drone model name and returns an instance of itself, or a class that implements the |
71 | | - ControllerParams protocol. |
| 66 | + Reads parameters from the controller's ``params.toml`` file, merging the |
| 67 | + shared ``[drone_model.core]`` section with the function-specific |
| 68 | + ``[drone_model.{fn_name}]`` section (if it exists). |
72 | 69 |
|
73 | 70 | Args: |
74 | | - params: The controller parameter type. |
| 71 | + controller: Name of the controller sub-package, e.g. ``"mellinger"``. |
| 72 | + fn_name: Name of the controller function, e.g. ``"state2attitude"``. |
| 73 | + drone_model: Name of the drone configuration, e.g. ``"cf2x_L250"``. |
| 74 | + xp: Array API module used to convert parameter values. If ``None``, |
| 75 | + NumPy is used. |
75 | 76 |
|
76 | 77 | Returns: |
77 | | - A decorator function that registers the parameters and returns the function unchanged. |
78 | | - """ |
79 | | - if not isinstance(params, ControllerParams): |
80 | | - raise ValueError(f"{params} does not implement the ControllerParams protocol") |
| 78 | + A flat dict mapping parameter names to arrays in the requested array namespace. |
81 | 79 |
|
82 | | - def decorator(fn: Callable[P, R]) -> Callable[P, R]: |
83 | | - controller_id = fn.__module__ + "." + fn.__name__ |
84 | | - if controller_id in controller_parameter_registry: |
85 | | - raise ValueError(f"Controller `{controller_id}` already registered") |
86 | | - controller_parameter_registry[controller_id] = params |
87 | | - return fn |
88 | | - |
89 | | - return decorator |
| 80 | + Raises: |
| 81 | + KeyError: If ``drone_model`` is not found in the TOML file. |
| 82 | + """ |
| 83 | + xp = np if xp is None else xp |
| 84 | + with open(Path(__file__).parent / f"{controller}/params.toml", "rb") as f: |
| 85 | + all_params = tomllib.load(f) |
| 86 | + if drone_model not in all_params: |
| 87 | + raise KeyError(f"Drone model `{drone_model}` not found in {controller}/params.toml") |
| 88 | + drone_params = all_params[drone_model] |
| 89 | + params = dict(drone_params.get("core", {})) |
| 90 | + params |= drone_params.get(fn_name, {}) |
| 91 | + return {k: xp.asarray(v) for k, v in params.items()} |
0 commit comments