Skip to content

Commit 4a937ac

Browse files
committed
fix phase
1 parent 6640b4b commit 4a937ac

4 files changed

Lines changed: 16 additions & 6 deletions

File tree

src/bitbots_motion/bitbots_rl_motion/bitbots_rl_motion/phase.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,22 @@ class PhaseObject:
88
_phase_dt: float
99

1010
def __init__(self, config):
11-
1211
self._control_dt = config["phase"]["control_dt"]
1312
self._gait_frequency = config["phase"]["gait_frequency"]
1413
self._phase_dt = 2 * np.pi * self._gait_frequency * self._control_dt
14+
self._obs_phase = None
1515

1616
def set_phase(self, new_phase):
1717
self._phase = new_phase
1818

19+
def set_obs_phase(self, new_obs_phase):
20+
self._obs_phase = new_obs_phase
21+
1922
def get_phase(self):
2023
return self._phase
2124

2225
def get_phase_dt(self):
2326
return self._phase_dt
27+
28+
def get_obs_phase(self):
29+
return self._obs_phase

src/bitbots_motion/bitbots_rl_motion/nodes/kick_node.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,20 @@ def _ball_pos_callback(self, msg):
4646

4747
# observations
4848
def obs(self):
49-
return np.hstack(
49+
observation = np.hstack(
5050
[
5151
self._gyro_handler.get_gyro(),
5252
self._gravity_handler.get_gravity(),
5353
self._joint_handler.get_velocity_data(),
5454
self._joint_handler.get_angle_data(),
5555
self._previous_action.get_previous_action(),
56-
self._phase.get_phase(),
56+
self._phase.get_obs_phase(),
5757
self._ball_handler.get_ball_pos(),
5858
]
5959
).astype(np.float32)
6060

61+
return observation
62+
6163
# load phase function
6264
def load_phase(self):
6365
pass

src/bitbots_motion/bitbots_rl_motion/nodes/rl_node.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,17 @@ def _timer_callback(self):
6262

6363
# TODO consider IMU mounting offset
6464

65-
self._phase.set_phase(
65+
self._phase.set_obs_phase(
6666
np.array(
6767
[np.cos(self._phase.get_phase()), np.sin(self._phase.get_phase())],
6868
dtype=np.float32,
6969
).flatten()
7070
)
7171

72+
observation = self.obs()
73+
7274
# Run the ONNX model
73-
onnx_input = {self._onnx_input_name[0]: self.obs().reshape(1, -1)} # TODO: Improve input
75+
onnx_input = {self._onnx_input_name[0]: observation.reshape(1, -1)} # TODO: Improve input
7476
onnx_pred = self._onnx_session.run(self._onnx_output_name, onnx_input)[0][0]
7577
self._previous_action.set_previous_action(onnx_pred)
7678

src/bitbots_motion/bitbots_rl_motion/nodes/walk_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def obs(self):
5656
self._joint_handler.get_velocity_data(),
5757
self._joint_handler.get_angle_data(),
5858
self._previous_action.get_previous_action(),
59-
self._phase.get_phase(),
59+
self._phase.get_obs_phase(),
6060
]
6161
).astype(np.float32)
6262

0 commit comments

Comments
 (0)