Skip to content

Commit 027ac43

Browse files
committed
phase in child class
1 parent c4f5003 commit 027ac43

4 files changed

Lines changed: 79 additions & 26 deletions

File tree

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
3+
CONTROL_DT = 0.02 # Control loop frequency in seconds
4+
GAIT_FREQUENCY = 1.5 # Gait frequency in Hz
5+
6+
7+
class TimerPhaseConfg:
8+
_phase: np.ndarray = np.array([0.0, np.pi], dtype=np.float32)
9+
_phase_dt: float
10+
11+
def __init__(self):
12+
# Phase time
13+
self._control_dt = CONTROL_DT
14+
self._gait_frequency = GAIT_FREQUENCY
15+
self._phase_dt = 2 * np.pi * GAIT_FREQUENCY * CONTROL_DT
16+
17+
def get_control_dt(self):
18+
return self._control_dt
19+
20+
def get_gait_frequency(self):
21+
return self._gait_frequency
22+
23+
def get_phase(self):
24+
return self._phase
25+
26+
def get_phase_dt(self):
27+
return self._phase_dt

src/bitbots_rl_walk/handler/joint_handler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,16 @@ class JointHandler(Handler):
5353
def __init__(self, ordered_relevant_joint_names=ORDERED_RELEVANT_JOINT_NAMES, walkready_state=WALKREADY_STATE):
5454
self._ordered_relevant_joint_names = ordered_relevant_joint_names
5555
self._walkready_state = walkready_state
56+
self._previous_action: np.ndarray = np.zeros(len(self._ordered_relevant_joint_names), dtype=np.float32)
5657
self._joint_state = None
58+
self._obs_phase = None
59+
self._phase = None
60+
61+
def set_obs_phase(self, phase):
62+
self._obs_phase = phase
63+
64+
def set_phase(self, phase):
65+
self._phase = phase
5766

5867
def get_angle_data(self):
5968
joint_angles = (
@@ -92,6 +101,8 @@ def get_walkready_joint_command(self, timestamp):
92101
joint_command.header.stamp = timestamp.to_msg()
93102
joint_command.positions = self._walkready_state
94103

104+
self._previous_action = joint_command
105+
95106
return joint_command
96107

97108
def get_joint_commands(self, onnx_pred):
@@ -103,7 +114,18 @@ def get_joint_commands(self, onnx_pred):
103114
joint_command.accelerations = [-1.0] * len(self._ordered_relevant_joint_names)
104115
joint_command.max_currents = [-1.0] * len(self._ordered_relevant_joint_names)
105116

117+
self._previous_action = joint_command
118+
106119
return joint_command
107120

121+
def get_previous_action(self):
122+
return self._previous_action
123+
124+
def get_obs_phase(self):
125+
return self._obs_phase
126+
127+
def get_phase(self):
128+
return self._phase
129+
108130
def joint_state_callback(self, msg):
109131
self._joint_state = msg

src/bitbots_rl_walk/nodes/rl_node.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@
3131
class RLNode(Node):
3232
"""Node to control the wolfgang humanoid."""
3333

34-
# TODO: _previous_action: np.ndarray = np.zeros(len(ORDERED_RELEVANT_JOINT_NAMES), dtype=np.float32)
35-
_phase: np.ndarray = np.array([0.0, np.pi], dtype=np.float32)
36-
_phase_dt: float
37-
3834
class PublisherParam(NamedTuple):
3935
msg_type: int
4036
topic: str
@@ -63,9 +59,7 @@ def __init__(self, path_to_model):
6359
for out in self._onnx_model.graph.output:
6460
self._onnx_output_name.append(out)
6561

66-
# Phase time
67-
self._phase_dt = 2 * np.pi * GAIT_FREQUENCY * CONTROL_DT
68-
62+
# TODO: Move timer to child class
6963
self._timer = self.create_timer(CONTROL_DT, self._timer_callback)
7064

7165
self.load_phase()
@@ -76,6 +70,9 @@ def __init__(self, path_to_model):
7670
if type(value) is Subscription:
7771
self._subs.append(key)
7872

73+
self._obs = None # should be defined in subclass
74+
self._timer_phase_confg = None # Should be defined in subclass
75+
7976
# TODO: fix
8077
def _timer_callback(self):
8178
for subscription in self._subs:
@@ -90,17 +87,22 @@ def _timer_callback(self):
9087

9188
# TODO consider IMU mounting offset
9289

93-
self._obs_phase = np.array([np.cos(self._phase), np.sin(self._phase)], dtype=np.float32).flatten()
90+
self._timer_phase_confg.set_obs_phase(
91+
np.array(
92+
[np.cos(self._timer_phase_confg.get_phase()), np.sin(self._timer_phase_confg.get_phase())],
93+
dtype=np.float32,
94+
).flatten()
95+
)
9496

9597
# Run the ONNX model
96-
onnx_input = {self._onnx_input_name[0]: self.obs().reshape(1, -1)} # TODO: Improve input
98+
onnx_input = {self._onnx_input_name[0]: self._obs.reshape(1, -1)} # TODO: Improve input
9799
onnx_pred = self._onnx_session.run(self._onnx_output_name, onnx_input)[0][0]
98100
self._previous_action = onnx_pred
99101

100102
self.publisher(onnx_pred)
101103

102-
phase_tp1 = self._phase + self._phase_dt
103-
self._phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
104+
phase_tp1 = self._timer_phase_confg.get_phase() + self._timer_phase_confg.get_phase_dt()
105+
self._timer_phase_confg.set_phase(np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi)
104106

105107
def obs():
106108
# Should be defined in subclass

src/bitbots_rl_walk/nodes/walk_node.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sensor_msgs.msg import Imu, JointState
99

1010
from bitbots_msgs.msg import JointCommand
11+
from bitbots_rl_walk.confg.timer_phase_confg import TimerPhaseConfg
1112
from bitbots_rl_walk.handler.joint_handler import JointHandler
1213
from bitbots_rl_walk.nodes.rl_node import RLNode
1314

@@ -28,33 +29,34 @@ def __init__(self, walk_policy_path):
2829
self._joint_handler = JointHandler()
2930
self._command_handler = CommandHandler()
3031

31-
super().__init__(walk_policy_path)
32+
self._timer_phase_confg = TimerPhaseConfg()
3233

33-
def _imu_callback(self, msg):
34-
self._gyro_handler.imu_callback(msg)
35-
self._gravity_handler.imu_callback(msg)
34+
# TODO: timer is missing
3635

37-
def _joint_state_callback(self, msg):
38-
self._joint_handler.joint_state_callback(msg)
39-
40-
def _cmd_vel_callback(self, msg):
41-
self._command_handler.cmd_vel_callback(msg)
42-
43-
def obs(self):
44-
obs = np.hstack(
36+
self._obs = np.hstack(
4537
[
4638
self._gyro_handler.get_data(), # 3
4739
self._gravity_handler.get_data(), # 4
4840
self._command_handler.get_data(), # 3
4941
self._joint_handler.get_velocity_data(), # 18
5042
self._joint_handler.get_angle_data(), # 18
5143
# TODO: fix
52-
self._previous_action, # 18 # Previous action
53-
self._obs_phase, # 2
44+
self._joint_handler.get_previous_action(), # 18 # Previous action
45+
self._timer_phase_confg.get_obs_phase(), # 2
5446
]
5547
).astype(np.float32)
5648

57-
return obs
49+
super().__init__(walk_policy_path)
50+
51+
def _imu_callback(self, msg):
52+
self._gyro_handler.imu_callback(msg)
53+
self._gravity_handler.imu_callback(msg)
54+
55+
def _joint_state_callback(self, msg):
56+
self._joint_handler.joint_state_callback(msg)
57+
58+
def _cmd_vel_callback(self, msg):
59+
self._command_handler.cmd_vel_callback(msg)
5860

5961
def load_phase(self):
6062
walkready_command = self._joint_handler.get_walkready_joint_command()

0 commit comments

Comments
 (0)