Skip to content

Commit e022d72

Browse files
committed
Improve tests
1 parent f179ca9 commit e022d72

3 files changed

Lines changed: 31 additions & 6 deletions

File tree

biped_walking_controller/preview_control.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ class State(Enum):
301301

302302

303303
@dataclass
304-
class WalkingFSMParams:
304+
class WalkingStateMachineParams:
305305
t_init: float = 2.0 # [s]
306306
t_end: float = 2.0 # [s]
307307
t_ss: float = 0.8 # [s]
@@ -310,7 +310,7 @@ class WalkingFSMParams:
310310

311311

312312
class WalkingStateMachine:
313-
def __init__(self, params: WalkingFSMParams, initial_state=State.INIT):
313+
def __init__(self, params: WalkingStateMachineParams, initial_state=State.INIT):
314314
self.params = params
315315
self.state = initial_state
316316
self.next_ss_state = State.SS_RIGHT

examples/example_5_walking_controller.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
compute_zmp_ref,
2323
cubic_spline_interpolation,
2424
WalkingStateMachine,
25-
WalkingFSMParams,
25+
WalkingStateMachineParams,
2626
State,
2727
)
2828

@@ -183,7 +183,9 @@ def main():
183183
traj_generator=BezierCurveFootPathGenerator(max_height_foot),
184184
)
185185

186-
params = WalkingFSMParams(t_init=t_init, t_end=t_end, t_ss=t_ss, t_ds=t_ds, force_threshold=50)
186+
params = WalkingStateMachineParams(
187+
t_init=t_init, t_end=t_end, t_ss=t_ss, t_ds=t_ds, force_threshold=50
188+
)
187189
state_machine = WalkingStateMachine(params=params, initial_state=State.INIT)
188190

189191
zmp_ref = compute_zmp_ref(

tests/test_preview_control.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from biped_walking_controller.preview_control import (
55
PreviewControllerParams,
66
compute_preview_control_matrices,
7-
WalkingFSMParams,
7+
WalkingStateMachineParams,
88
WalkingStateMachine,
99
State,
1010
)
@@ -100,31 +100,53 @@ def test_gains_change_with_Qe(self):
100100

101101
class TestWalkingPhase(unittest.TestCase):
102102
def setUp(self):
103-
self.params = WalkingFSMParams()
103+
self.params = WalkingStateMachineParams()
104+
self.steps = np.ones([5, 3]) # Five steps
105+
106+
def test_begin_init(self):
107+
wsm = WalkingStateMachine(self.params, initial_state=State.INIT)
108+
wsm.update_steps(self.steps)
109+
110+
wsm.update(t=0.0, rf_contact_force=100.0, lf_contact_force=100.0)
111+
112+
self.assertEqual(wsm.get_current_state(), State.INIT)
113+
114+
def test_init_to_ds(self):
115+
wsm = WalkingStateMachine(self.params, initial_state=State.INIT)
116+
wsm.update_steps(self.steps)
117+
wsm.update(t=0.0, rf_contact_force=100.0, lf_contact_force=100.0)
118+
119+
wsm.update(t=self.params.t_init + 0.1, rf_contact_force=0.0, lf_contact_force=100.0)
120+
121+
self.assertEqual(wsm.get_current_state(), State.SS_RIGHT)
104122

105123
def test_begin_double_support(self):
106124
wsm = WalkingStateMachine(self.params, initial_state=State.DS)
125+
wsm.update_steps(self.steps)
107126

108127
wsm.update(t=0.0, rf_contact_force=100.0, lf_contact_force=100.0)
109128

110129
self.assertEqual(wsm.get_current_state(), State.DS)
111130

112131
def test_switch_to_single_support(self):
113132
wsm = WalkingStateMachine(self.params, initial_state=State.DS)
133+
wsm.update_steps(self.steps)
114134

115135
wsm.update(t=self.params.t_ds + 0.1, rf_contact_force=0.0, lf_contact_force=0.0)
116136

117137
self.assertEqual(wsm.get_current_state(), State.SS_RIGHT)
118138

119139
def test_do_not_switch_to_ds_if_beginning_of_phase(self):
120140
wsm = WalkingStateMachine(self.params, initial_state=State.SS_RIGHT)
141+
wsm.update_steps(self.steps)
121142

122143
wsm.update(t=0.0, rf_contact_force=0.0, lf_contact_force=self.params.force_threshold + 10)
123144

124145
self.assertEqual(wsm.get_current_state(), State.SS_RIGHT)
125146

126147
def test_do_not_switch_to_ds_if_force_too_low(self):
127148
wsm = WalkingStateMachine(self.params, initial_state=State.SS_RIGHT)
149+
wsm.update_steps(self.steps)
128150

129151
wsm.update(
130152
t=self.params.t_ss * 0.75,
@@ -136,6 +158,7 @@ def test_do_not_switch_to_ds_if_force_too_low(self):
136158

137159
def test_switch_to_ds_if_force_too_low_and_phase_close_to_end(self):
138160
wsm = WalkingStateMachine(self.params, initial_state=State.SS_LEFT)
161+
wsm.update_steps(self.steps)
139162

140163
wsm.update(
141164
t=self.params.t_ss * 0.75,

0 commit comments

Comments
 (0)