@@ -18,7 +18,7 @@ def state_control(data: SimData, controls: Array) -> SimData:
1818 """State control function."""
1919 assert data .controls .mode == Control .state , f"control type { data .controls .mode } not enabled"
2020 assert controls .shape == (data .core .n_worlds , data .core .n_drones , 13 ), "controls shape mismatch"
21- controls = to_device (controls , data .core .steps . device )
21+ controls = to_device (controls , data .core .device )
2222 data = data .replace (
2323 controls = data .controls .replace (state = data .controls .state .replace (staged_cmd = controls ))
2424 )
@@ -36,7 +36,7 @@ def attitude_control(data: SimData, controls: Array) -> SimData:
3636 """
3737 assert data .controls .mode == Control .attitude , f"control type { data .controls .mode } not enabled"
3838 assert controls .shape == (data .core .n_worlds , data .core .n_drones , 4 ), "controls shape mismatch"
39- controls = to_device (controls , data .core .steps . device )
39+ controls = to_device (controls , data .core .device )
4040 data = data .replace (
4141 controls = data .controls .replace (attitude = data .controls .attitude .replace (staged_cmd = controls ))
4242 )
@@ -49,7 +49,7 @@ def force_torque_control(data: SimData, controls: Array) -> SimData:
4949 f"control type { data .controls .mode } not enabled"
5050 )
5151 assert controls .shape == (data .core .n_worlds , data .core .n_drones , 4 ), "controls shape mismatch"
52- controls = to_device (controls , data .core .steps . device )
52+ controls = to_device (controls , data .core .device )
5353 data = data .replace (
5454 controls = data .controls .replace (
5555 force_torque = data .controls .force_torque .replace (staged_cmd = controls )
@@ -65,7 +65,7 @@ def rotor_vel_control(data: SimData, controls: Array) -> SimData:
6565 """
6666 assert data .controls .mode == Control .rotor_vel , f"control type { data .controls .mode } not enabled"
6767 assert controls .shape == (data .core .n_worlds , data .core .n_drones , 4 ), "controls shape mismatch"
68- controls = to_device (controls , data .core .steps . device )
68+ controls = to_device (controls , data .core .device )
6969 return data .replace (controls = data .controls .replace (rotor_vel = controls ))
7070
7171
@@ -81,7 +81,7 @@ def controllable(data: SimData) -> Array:
8181 control_steps = controls .force_torque .steps
8282 control_freq = controls .force_torque .freq
8383 case Control .rotor_vel :
84- return jnp .ones ((data .core .n_worlds , 1 ), dtype = bool , device = data .core .steps . device )
84+ return jnp .ones ((data .core .n_worlds , 1 ), dtype = bool , device = data .core .device )
8585 case _:
8686 raise NotImplementedError (f"Control mode { data .controls .mode } not implemented" )
8787 return _controllable (data .core .steps , data .core .freq , control_steps , control_freq )
0 commit comments