@@ -16,17 +16,18 @@ def test_state_interface(dynamics: Dynamics):
1616 sim = Sim (dynamics = dynamics , control = Control .state )
1717
1818 # Simple P controller for attitude to reach target height
19+ target_height = 0.5
1920 cmd = np .zeros ((1 , 1 , 13 ), dtype = np .float32 )
20- cmd [0 , 0 , 2 ] = 1.0 # Set z position target to 1.0
21+ cmd [0 , 0 , 2 ] = target_height
22+ steps = int (2 * sim .control_freq ) # Run simulation for 2 seconds
2123
22- for _ in range (int (2 * sim .control_freq )): # Run simulation for 2 seconds
24+ for i in range (steps ): # Run simulation for 2 seconds
25+ cmd [..., 2 ] = target_height * i / steps # Linearly interpolate target height
2326 sim .state_control (cmd )
2427 sim .step (sim .freq // sim .control_freq )
25- if np .linalg .norm (sim .data .states .pos [0 , 0 ] - np .array ([0.0 , 0.0 , 1.0 ])) < 0.1 :
26- break
2728
2829 # Check if drone reached target position
29- distance = np .linalg .norm (sim .data .states .pos [0 , 0 ] - np .array ([0.0 , 0.0 , 1.0 ]))
30+ distance = np .linalg .norm (sim .data .states .pos [0 , 0 ] - np .array ([0.0 , 0.0 , target_height ]))
3031 assert distance < 0.1 , f"Failed to reach target height with { dynamics } dynamics"
3132
3233
@@ -35,13 +36,15 @@ def test_state_interface(dynamics: Dynamics):
3536def test_attitude_interface (dynamics : Dynamics ):
3637 sim = Sim (dynamics = dynamics , control = Control .attitude )
3738 target_pos = np .array ([0.0 , 0.0 , 1.0 ])
38- jit_state2attitude = jax .jit (parametrize (state2attitude , drone = "cf2x_L250" ))
39+ jit_state2attitude = jax .jit (parametrize (state2attitude , drone = sim . drone ))
3940
4041 pos_err_i = np .zeros ((1 , 1 , 3 ))
4142 cmd = np .zeros ((1 , 1 , 13 ))
42- cmd [0 , 0 , 2 ] = 1.0 # Set z position target to 1.0
43+ cmd [0 , 0 , 2 ] = 1.0
44+ steps = int (3 * sim .control_freq )
4345
44- for _ in range (int (2 * sim .control_freq )): # Run simulation for 2 seconds
46+ for i in range (steps ):
47+ cmd [..., :3 ] = target_pos * i / steps # Linearly interpolate target position
4548 pos , vel , quat = sim .data .states .pos , sim .data .states .vel , sim .data .states .quat
4649 rpyt , pos_err_i = jit_state2attitude (pos , quat , vel , cmd , pos_err_i , ctrl_freq = 100 )
4750 sim .attitude_control (rpyt )
@@ -77,15 +80,19 @@ def test_rotor_vel_interface():
7780def test_swarm_control (dynamics : Dynamics ):
7881 n_worlds , n_drones = 2 , 3
7982 sim = Sim (n_worlds = n_worlds , n_drones = n_drones , dynamics = dynamics , control = Control .state )
83+ start_pos = np .asarray (sim .data .states .pos )
8084 target_pos = sim .data .states .pos + np .array ([0.3 , 0.3 , 0.3 ])
81-
8285 cmd = np .zeros ((n_worlds , n_drones , 13 ))
83- cmd [..., :3 ] = target_pos
84- sim .state_control (cmd )
85- sim .step (3 * sim .freq )
86- # Check if drone maintained hover position
86+ steps = int (3 * sim .control_freq )
87+
88+ for i in range (steps ):
89+ alpha = i / (steps )
90+ cmd [..., :3 ] = start_pos * (1 - alpha ) + target_pos * alpha
91+ sim .state_control (cmd )
92+ sim .step (sim .freq // sim .control_freq )
93+
8794 max_dist = np .max (np .linalg .norm (sim .data .states .pos - target_pos , axis = - 1 ))
88- assert max_dist < 0.05 , f"Failed to reach target, max dist: { max_dist } "
95+ assert max_dist < 0.08 , f"Failed to reach target, max dist: { max_dist } "
8996
9097
9198@pytest .mark .integration
0 commit comments