Skip to content

Commit b417c51

Browse files
committed
Fix and add example
1 parent 37633f6 commit b417c51

2 files changed

Lines changed: 87 additions & 1 deletion

File tree

examples/attitude.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
3+
import numpy as np
4+
5+
os.environ["SCIPY_ARRAY_API"] = "1"
6+
7+
from scipy.spatial.transform import Rotation as R
8+
9+
from crazyflow.control import Control
10+
from crazyflow.sim import Sim
11+
12+
kp = np.array([0.4, 0.4, 1.25])
13+
ki = np.array([0.05, 0.05, 0.05])
14+
kd = np.array([0.2, 0.2, 0.4])
15+
g = 9.81
16+
17+
18+
def control(
19+
t: float, obs: dict[str, np.ndarray], pos_start: np.ndarray, drone_mass: float
20+
) -> np.ndarray:
21+
des_pos = np.zeros(3)
22+
des_pos[..., :2] = pos_start[:2] + np.array([np.cos(t) - 1, np.sin(t)])
23+
des_pos[..., 2] = 0.2 * t
24+
des_vel = np.zeros_like(des_pos)
25+
des_yaw = t
26+
27+
# Calculate the deviations from the desired trajectory
28+
pos_error = des_pos - np.array(obs["pos"])
29+
vel_error = des_vel - np.array(obs["vel"])
30+
31+
# Compute target thrust
32+
target_thrust = np.zeros(3)
33+
target_thrust += kp * pos_error
34+
target_thrust += kd * vel_error
35+
target_thrust[2] += drone_mass * g
36+
37+
# Update z_axis to the current orientation of the drone
38+
z_axis = R.from_quat(obs["quat"]).as_matrix()[:, 2]
39+
40+
# update current thrust
41+
thrust_desired = target_thrust.dot(z_axis)
42+
43+
# update z_axis_desired
44+
z_axis_desired = target_thrust / np.linalg.norm(target_thrust)
45+
x_c_des = np.array([np.cos(des_yaw), np.sin(des_yaw), 0.0])
46+
y_axis_desired = np.cross(z_axis_desired, x_c_des)
47+
y_axis_desired /= np.linalg.norm(y_axis_desired)
48+
x_axis_desired = np.cross(y_axis_desired, z_axis_desired)
49+
50+
R_desired = np.vstack([x_axis_desired, y_axis_desired, z_axis_desired]).T
51+
euler_desired = R.from_matrix(R_desired).as_euler("xyz", degrees=False)
52+
53+
action = np.concatenate([euler_desired, [thrust_desired]], dtype=np.float32)
54+
55+
return action
56+
57+
58+
def main():
59+
sim = Sim(control=Control.attitude)
60+
sim.reset()
61+
duration = 6.5
62+
fps = 60
63+
64+
cmd = np.zeros((sim.n_worlds, sim.n_drones, 4)) # [roll, pitch, yaw, thrust]
65+
pos_start = sim.data.states.pos
66+
for i in range(int(duration * sim.control_freq)):
67+
obs = {
68+
"pos": sim.data.states.pos[0, 0],
69+
"vel": sim.data.states.vel[0, 0],
70+
"quat": sim.data.states.quat[0, 0],
71+
}
72+
cmd[0, 0, :] = control(
73+
i / sim.control_freq, obs, pos_start[0, 0], sim.data.params.mass[0, 0, 0]
74+
)
75+
sim.attitude_control(cmd)
76+
sim.step(sim.freq // sim.control_freq)
77+
if ((i * fps) % sim.control_freq) < fps:
78+
sim.render()
79+
sim.close()
80+
81+
82+
if __name__ == "__main__":
83+
main()

examples/change_pos.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ def main():
1313
# information on changing JAX arrays, see:
1414
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates
1515
sim.data = sim.data.replace(
16-
states=sim.data.states.replace(pos=sim.data.states.pos.at[0, 0].set(np.array([1, 1, 0.2])))
16+
states=sim.data.states.replace(
17+
pos=sim.data.states.pos.at[0, 0].set(np.array([0.5, 0.5, 0.2])),
18+
rotor_vel=sim.data.states.rotor_vel.at[0, 0].set(np.ones(4) * 20000),
19+
)
1720
)
1821
control = np.zeros((sim.n_worlds, sim.n_drones, 13))
1922
control[..., :3] = np.array([[0.0, 0.0, 0.3]])

0 commit comments

Comments
 (0)