11from __future__ import annotations
22
33import jax .numpy as jnp
4- from drone_models .controller .mellinger import MellingerStateParams
4+ from drone_models .controller .mellinger import (
5+ MellingerAttitudeParams ,
6+ MellingerForceTorqueParams ,
7+ MellingerStateParams ,
8+ )
59from flax .struct import dataclass , field
610from jax import Array , Device
711
@@ -14,6 +18,8 @@ class MellingerStateData:
1418 A command consists of [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate].
1519 We currently do not use the acceleration and angle rate components. This is subject to change.
1620 """
21+ staged_cmd : Array # (N, M, 13)
22+ """Staging buffer to store the most recent command until the next controller tick."""
1723 steps : Array # (N, 1)
1824 """Last simulation steps that the state control command was applied."""
1925 freq : int = field (pytree_node = False )
@@ -29,39 +35,90 @@ def create(
2935 ) -> MellingerStateData :
3036 """Create a default set of state data for the simulation."""
3137 cmd = jnp .zeros ((n_worlds , n_drones , 13 ), device = device )
32- steps = jnp .zeros ((n_worlds , 1 ), dtype = jnp .int32 , device = device )
38+ steps = - jnp .ones ((n_worlds , 1 ), dtype = jnp .int32 , device = device )
3339 pos_err_i = jnp .zeros ((n_worlds , n_drones , 3 ), device = device )
3440 params = MellingerStateParams .load (drone_model )
3541 return MellingerStateData (
36- cmd = cmd , steps = steps , freq = freq , pos_err_i = pos_err_i , params = params
42+ cmd = cmd , staged_cmd = cmd , steps = steps , freq = freq , pos_err_i = pos_err_i , params = params
3743 )
3844
3945
40- # @dataclass
41- # class MellingerAttitudeData:
42- # cmd: Array # (N, M, 4)
43- # """Full attitude control command for the drone.
46+ @dataclass
47+ class MellingerAttitudeData :
48+ cmd : Array # (N, M, 4)
49+ """Full attitude control command for the drone.
50+
51+ A command consists of [roll, pitch, yaw, collective thrust].
52+ """
53+ staged_cmd : Array # (N, M, 4)
54+ """Staging buffer to store the most recent command until the next controller tick."""
55+ steps : Array # (N, 1)
56+ """Last simulation steps that the attitude control command was applied."""
57+ freq : int = field (pytree_node = False )
58+ """Frequency of the attitude control command."""
59+ r_int_error : Array # (N, M, 3)
60+ """Integral errors of the attitude control command."""
61+ last_ang_vel : Array # (N, M, 3)
62+ """Last angular velocity of the drone."""
63+ # Parameters for the attitude controller
64+ params : MellingerAttitudeParams
65+
66+ @staticmethod
67+ def create (
68+ n_worlds : int , n_drones : int , freq : int , drone_model : str , device : Device
69+ ) -> MellingerAttitudeData :
70+ """Create a default set of attitude data for the simulation."""
71+ cmd = jnp .zeros ((n_worlds , n_drones , 4 ), device = device )
72+ steps = - jnp .ones ((n_worlds , 1 ), dtype = jnp .int32 , device = device )
73+ zeros_3d = jnp .zeros ((n_worlds , n_drones , 3 ), device = device )
74+ params = MellingerAttitudeParams .load (drone_model )
75+ return MellingerAttitudeData (
76+ cmd = cmd ,
77+ staged_cmd = cmd ,
78+ steps = steps ,
79+ freq = freq ,
80+ r_int_error = zeros_3d ,
81+ last_ang_vel = zeros_3d ,
82+ params = params ,
83+ )
84+
4485
45- # A command consists of [collective thrust, roll, pitch, yaw].
46- # """
47- # steps: Array # (N, 1)
48- # """Last simulation steps that the attitude control command was applied."""
49- # freq: int = field(pytree_node=False)
50- # """Frequency of the attitude control command."""
51- # pos_err_i: Array # (N, M, 3)
52- # """Integral errors of the attitude control command."""
53- # # Parameters for the attitude controller
54- # params: MellingerAttitudeParams
86+ @dataclass
87+ class MellingerForceTorqueData :
88+ cmd_force : Array # (N, M, 1)
89+ """Force command for the drone.
90+
91+ A command consists of [fz].
92+ """
93+ cmd_torque : Array # (N, M, 3)
94+ """Torque command for the drone.
5595
56- # @staticmethod
57- # def create(
58- # n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device
59- # ) -> MellingerAttitudeData:
60- # """Create a default set of attitude data for the simulation."""
61- # cmd = jnp.zeros((n_worlds, n_drones, 4), device=device)
62- # steps = jnp.zeros((n_worlds, 1), dtype=jnp.int32, device=device)
63- # pos_err_i = jnp.zeros((n_worlds, n_drones, 3), device=device)
64- # params = MellingerAttitudeParams.load(drone_model)
65- # return MellingerAttitudeData(
66- # cmd=cmd, steps=steps, freq=freq, pos_err_i=pos_err_i, params=params
67- # )
96+ A command consists of [tx, ty, tz].
97+ """
98+ staged_cmd_force : Array # (N, M, 1)
99+ staged_cmd_torque : Array # (N, M, 3)
100+ """Staging buffer to store the most recent command until the next controller tick."""
101+ steps : Array # (N, 1)
102+ """Last simulation steps that the force and torque control command was applied."""
103+ freq : int = field (pytree_node = False )
104+ """Frequency of the force and torque control command."""
105+ # Parameters for the force and torque controller
106+ params : MellingerForceTorqueParams
107+
108+ @staticmethod
109+ def create (
110+ n_worlds : int , n_drones : int , freq : int , drone_model : str , device : Device
111+ ) -> MellingerForceTorqueData :
112+ zero_1d = jnp .zeros ((n_worlds , n_drones , 1 ), device = device )
113+ zero_3d = jnp .zeros ((n_worlds , n_drones , 3 ), device = device )
114+ steps = - jnp .ones ((n_worlds , 1 ), dtype = jnp .int32 , device = device )
115+ params = MellingerForceTorqueParams .load (drone_model )
116+ return MellingerForceTorqueData (
117+ cmd_force = zero_1d ,
118+ cmd_torque = zero_3d ,
119+ staged_cmd_force = zero_1d ,
120+ staged_cmd_torque = zero_3d ,
121+ steps = steps ,
122+ freq = freq ,
123+ params = params ,
124+ )
0 commit comments