Skip to content

Commit d4f6ed9

Browse files
committed
Add device kwarg to load_params
1 parent 003c97c commit d4f6ed9

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

drone_controllers/core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ def parametrize(
6464

6565

6666
def load_params(
67-
controller: str, fn_name: str, drone_model: str, xp: ModuleType | None = None
67+
controller: str,
68+
fn_name: str,
69+
drone_model: str,
70+
xp: ModuleType | None = None,
71+
device: str | None = None,
6872
) -> dict[str, Any]:
6973
"""Load and merge controller parameters for a specific function.
7074
@@ -77,6 +81,7 @@ def load_params(
7781
fn_name: Name of the controller function, e.g. ``"state2attitude"``.
7882
drone_model: Name of the drone configuration, e.g. ``"cf2x_L250"``.
7983
xp: The array API module to use. If not provided, numpy is used.
84+
device: The device to use. If None, the device is inferred from the xp module.
8085
8186
Returns:
8287
A flat dict mapping parameter names to arrays in the requested array namespace.
@@ -91,4 +96,4 @@ def load_params(
9196
raise KeyError(f"Drone model `{drone_model}` not found in {controller}/params.toml")
9297
model_params = params[drone_model]
9398
merged = model_params.get("core", {}) | model_params.get(fn_name, {})
94-
return {k: xp.asarray(v) for k, v in merged.items()}
99+
return {k: xp.asarray(v, device=device) for k, v in merged.items()}

0 commit comments

Comments
 (0)