@@ -64,7 +64,11 @@ def parametrize(
6464
6565
6666def load_params (
67- controller : str , fn_name : str , drone_model : str , xp : ModuleType | None = None
67+ controller : str ,
68+ fn_name : str ,
69+ drone_model : str ,
70+ xp : ModuleType | None = None ,
71+ device : str | None = None ,
6872) -> dict [str , Any ]:
6973 """Load and merge controller parameters for a specific function.
7074
@@ -77,6 +81,7 @@ def load_params(
7781 fn_name: Name of the controller function, e.g. ``"state2attitude"``.
7882 drone_model: Name of the drone configuration, e.g. ``"cf2x_L250"``.
7983 xp: The array API module to use. If not provided, numpy is used.
84+ device: The device to use. If None, the device is inferred from the xp module.
8085
8186 Returns:
8287 A flat dict mapping parameter names to arrays in the requested array namespace.
@@ -91,4 +96,4 @@ def load_params(
9196 raise KeyError (f"Drone model `{ drone_model } ` not found in { controller } /params.toml" )
9297 model_params = params [drone_model ]
9398 merged = model_params .get ("core" , {}) | model_params .get (fn_name , {})
94- return {k : xp .asarray (v ) for k , v in merged .items ()}
99+ return {k : xp .asarray (v , device = device ) for k , v in merged .items ()}
0 commit comments