Skip to content

Commit 1a92e74

Browse files
committed
small bug fixes
1 parent f066501 commit 1a92e74

11 files changed

Lines changed: 129 additions & 110 deletions

File tree

src/bitbots_motion/bitbots_rl_motion/bitbots_rl_motion/phase.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ class PhaseObject:
88
_phase_dt: float
99

1010
def __init__(self, config):
11-
super().__init__(config)
1211

13-
self._control_dt = self._config["phase"]["control_dt"]
14-
self._gait_frequency = self._config["phase"]["gait_frequency"]
12+
self._control_dt = config["phase"]["control_dt"]
13+
self._gait_frequency = config["phase"]["gait_frequency"]
1514
self._phase_dt = 2 * np.pi * self._gait_frequency * self._control_dt
1615

1716
def set_phase(self, new_phase):

src/bitbots_motion/bitbots_rl_motion/bitbots_rl_motion/policy_nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
def main():
99
rclpy.init()
1010

11-
wolfgang_config = os.path.join(get_package_share_directory("bitbots_rl_motion"), "config", "wolfgang_config.yaml")
11+
wolfgang_config = os.path.join(get_package_share_directory("bitbots_rl_motion"), "configs", "wolfgang_config.yaml")
1212

1313
# walk_node = WalkNode(wolfgang_config)
1414
kick_node = KickNode(wolfgang_config)

src/bitbots_motion/bitbots_rl_motion/handlers/ball_handler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from bitbots_rl_motion.handlers.handler import Handler
1+
from handlers.handler import Handler
22

33

44
class BallHandler(Handler):
@@ -9,5 +9,9 @@ def __init__(self, config):
99
def ball_pos_callback(self, msg):
1010
self.ball_pos = msg
1111

12+
1213
def get_ball_pos(self):
13-
return self.ball_pos
14+
try:
15+
return self.ball_pos
16+
except:
17+
return None
Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from bitbots_rl_motion.handlers.handler import Handler
2+
from handlers.handler import Handler
33

44

55
class CommandHandler(Handler):
@@ -9,8 +9,12 @@ def __init__(self, config):
99
self._cmd_vel = None
1010

1111
def get_command(self):
12-
command = np.array([self._cmd_vel.linear.x, self._cmd_vel.linear.y, self._cmd_vel.angular.z], dtype=np.float32)
13-
return command
12+
try:
13+
command = np.array([self._cmd_vel.linear.x, self._cmd_vel.linear.y, self._cmd_vel.angular.z], dtype=np.float32)
14+
return command
15+
except:
16+
return None
1417

1518
def cmd_vel_callback(self, msg):
1619
self._cmd_vel = msg
20+

src/bitbots_motion/bitbots_rl_motion/handlers/gravity_handler.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from bitbots_rl_motion.handlers.handler import Handler
2+
from handlers.handler import Handler
33
from transforms3d.euler import euler2mat
44
from transforms3d.quaternions import quat2mat
55

@@ -15,17 +15,21 @@ def __init__(self, config):
1515
def imu_callback(self, msg):
1616
self._imu_data = msg
1717

18+
1819
def get_gravity(self):
19-
gravity = (
20-
quat2mat(
21-
[
22-
self._imu_data.orientation.w,
23-
self._imu_data.orientation.x,
24-
self._imu_data.orientation.y,
25-
self._imu_data.orientation.z,
26-
]
27-
)
28-
@ euler2mat(0, -0.0, 0)
29-
).T @ np.array([0, 0, -1], dtype=np.float32)
20+
try:
21+
gravity = (
22+
quat2mat(
23+
[
24+
self._imu_data.orientation.w,
25+
self._imu_data.orientation.x,
26+
self._imu_data.orientation.y,
27+
self._imu_data.orientation.z,
28+
]
29+
)
30+
@ euler2mat(0, -0.0, 0)
31+
).T @ np.array([0, 0, -1], dtype=np.float32)
32+
return gravity
33+
except:
34+
return None
3035

31-
return gravity
Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from bitbots_rl_motion.handlers.handler import Handler
2+
from handlers.handler import Handler
33

44

55
class GyroHandler(Handler):
@@ -10,15 +10,19 @@ def __init__(self, config):
1010
# Callables
1111
def imu_callback(self, msg):
1212
self._imu_data = msg
13+
1314

1415
def get_gyro(self):
15-
gyro = np.array(
16-
[
17-
self._imu_data.angular_velocity.x,
18-
self._imu_data.angular_velocity.y,
19-
self._imu_data.angular_velocity.z,
20-
],
21-
dtype=np.float32,
22-
)
16+
try:
17+
gyro = np.array(
18+
[
19+
self._imu_data.angular_velocity.x,
20+
self._imu_data.angular_velocity.y,
21+
self._imu_data.angular_velocity.z,
22+
],
23+
dtype=np.float32,
24+
)
2325

24-
return gyro
26+
return gyro
27+
except:
28+
return None
Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from bitbots_rl_motion.handlers.handler import Handler
2+
from handlers.handler import Handler
33

44
from bitbots_msgs.msg import JointCommand
55

@@ -12,77 +12,77 @@ def __init__(self, config):
1212
self._walkready_state = self._config["joints"]["walkready_state"]
1313
self._previous_action: np.ndarray = np.zeros(len(self._ordered_relevant_joint_names), dtype=np.float32)
1414
self._joint_state = None
15-
self._obs_phase = None
16-
self._phase = None
15+
1716

1817
def joint_state_callback(self, msg):
1918
self._joint_state = msg
2019

21-
def set_obs_phase(self, phase):
22-
self._obs_phase = phase
20+
def get_angle_data(self):
21+
try:
22+
joint_angles = (
23+
np.array(
24+
[
25+
self._joint_state.position[self._joint_state.name.index(name)]
26+
for name in self._ordered_relevant_joint_names
27+
],
28+
dtype=np.float32,
29+
)
30+
- self._walkready_state
31+
)
2332

24-
def set_phase(self, phase):
25-
self._phase = phase
33+
return joint_angles
34+
except:
35+
return None
2636

27-
def get_angle_data(self):
28-
joint_angles = (
29-
np.array(
37+
def get_velocity_data(self):
38+
try:
39+
joint_velocities = np.array(
3040
[
31-
self._joint_state.position[self._joint_state.name.index(name)]
41+
self._joint_state.velocity[self._joint_state.name.index(name)]
3242
for name in self._ordered_relevant_joint_names
3343
],
3444
dtype=np.float32,
3545
)
36-
- self._walkready_state
37-
)
3846

39-
return joint_angles
40-
41-
def get_velocity_data(self):
42-
joint_velocities = np.array(
43-
[
44-
self._joint_state.velocity[self._joint_state.name.index(name)]
45-
for name in self._ordered_relevant_joint_names
46-
],
47-
dtype=np.float32,
48-
)
49-
50-
return joint_velocities
47+
return joint_velocities
48+
except:
49+
return None
5150

5251
def get_data(self):
5352
return self.get_angle_data(), self.get_velocity_data()
5453

5554
def get_walkready_joint_command(self, timestamp):
56-
joint_command = JointCommand()
57-
joint_command.joint_names = self._ordered_relevant_joint_names
58-
joint_command.velocities = [0.2] * len(self._ordered_relevant_joint_names)
59-
joint_command.accelerations = [-1.0] * len(self._ordered_relevant_joint_names)
60-
joint_command.max_currents = [-1.0] * len(self._ordered_relevant_joint_names) # -1.0 means no limit
61-
joint_command.header.stamp = timestamp.to_msg()
62-
joint_command.positions = self._walkready_state
55+
try:
56+
joint_command = JointCommand()
57+
joint_command.joint_names = self._ordered_relevant_joint_names
58+
joint_command.velocities = [0.2] * len(self._ordered_relevant_joint_names)
59+
joint_command.accelerations = [-1.0] * len(self._ordered_relevant_joint_names)
60+
joint_command.max_currents = [-1.0] * len(self._ordered_relevant_joint_names) # -1.0 means no limit
61+
joint_command.header.stamp = timestamp.to_msg()
62+
joint_command.positions = self._walkready_state
6363

64-
self._previous_action = joint_command
64+
self._previous_action = joint_command
6565

66-
return joint_command
66+
return joint_command
67+
except:
68+
return None
6769

6870
def get_joint_commands(self, onnx_pred):
69-
joint_command = JointCommand()
70-
joint_command.header.stamp = self._joint_state.header.stamp
71-
joint_command.joint_names = self._ordered_relevant_joint_names
72-
joint_command.positions = onnx_pred * 0.5 + self._walkready_state
73-
joint_command.velocities = [-1.0] * len(self._ordered_relevant_joint_names)
74-
joint_command.accelerations = [-1.0] * len(self._ordered_relevant_joint_names)
75-
joint_command.max_currents = [-1.0] * len(self._ordered_relevant_joint_names)
71+
try:
72+
joint_command = JointCommand()
73+
joint_command.header.stamp = self._joint_state.header.stamp
74+
joint_command.joint_names = self._ordered_relevant_joint_names
75+
joint_command.positions = onnx_pred * 0.5 + self._walkready_state
76+
joint_command.velocities = [-1.0] * len(self._ordered_relevant_joint_names)
77+
joint_command.accelerations = [-1.0] * len(self._ordered_relevant_joint_names)
78+
joint_command.max_currents = [-1.0] * len(self._ordered_relevant_joint_names)
7679

77-
self._previous_action = joint_command
80+
self._previous_action = joint_command
7881

79-
return joint_command
82+
return joint_command
83+
except:
84+
return None
8085

8186
def get_previous_action(self):
8287
return self._previous_action
8388

84-
def get_obs_phase(self):
85-
return self._obs_phase
86-
87-
def get_phase(self):
88-
return self._phase

src/bitbots_motion/bitbots_rl_motion/nodes/kick_node.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
2-
from bitbots_rl_motion.nodes.rl_node import RLNode
3-
from bitbots_rl_motion.handlers.joint_handler import JointHandler
4-
from bitbots_rl_motion.handlers.ball_handler import BallHandler
2+
from nodes.rl_node import RLNode
3+
from handlers.joint_handler import JointHandler
4+
from handlers.ball_handler import BallHandler
55
from geometry_msgs.msg import PoseStamped
66
from handlers.gravity_handler import GravityHandler
77
from handlers.gyro_handler import GyroHandler
@@ -12,11 +12,7 @@
1212

1313
class KickNode(RLNode):
1414
def __init__(self, config_path: str):
15-
super().__init__(config_path)
16-
17-
# loading model
18-
model = self._config["models"]["kick_model"]
19-
self.load_model(model)
15+
super().__init__(config_path, "kick_node")
2016

2117
# publishers
2218
self._joint_command_pub = self.create_publisher(JointCommand, "walking_motor_goals", 10)
@@ -46,6 +42,10 @@ def __init__(self, config_path: str):
4642
]
4743
).astype(np.float32)
4844

45+
# loading model
46+
model = self._config["models"]["kick_model"]
47+
self.load_model(model)
48+
4949
# callback functions
5050

5151
def _imu_callback(self, msg):

src/bitbots_motion/bitbots_rl_motion/nodes/rl_node.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import onnxruntime as rt
2424
import yaml
2525
from ament_index_python import get_package_share_directory
26-
from bitbots_rl_motion.bitbots_rl_motion.phase import PhaseObject
26+
from bitbots_rl_motion.phase import PhaseObject
2727
from rclpy.node import Node
2828
from rclpy.qos import QoSProfile
2929
from rclpy.subscription import Subscription
@@ -43,20 +43,13 @@ class SubscriptionParam(NamedTuple):
4343
callback: Callable
4444
qos_profile: int | QoSProfile
4545

46-
def __init__(self, config_path: str):
46+
def __init__(self, config_path: str, node_name: str):
47+
super().__init__(f"{node_name}")
48+
4749
self._config = self._load_config(config_path)
4850
self._phase = PhaseObject(self._config)
4951
self._obs = None # should be defined in subclass
5052

51-
self._timer = self.create_timer(self._config["phase"]["control_dt"], self._timer_callback)
52-
self.load_phase()
53-
54-
self._subs = []
55-
56-
for key, value in self.__dict__.values():
57-
if type(value) is Subscription:
58-
self._subs.append(key)
59-
6053
def _load_config(self, path: str):
6154
with open(path) as f:
6255
return yaml.safe_load(f)
@@ -101,12 +94,9 @@ def load_model(self, model):
10194
path_to_model = os.path.join(get_package_share_directory("bitbots_rl_motion"), "models", model)
10295

10396
self._onnx_model_path = Path(path_to_model)
104-
model_name = self._onnx_model_path.stem
105-
106-
super().__init__(f"{model_name}")
10797

10898
# Load the ONNX model
109-
self._onnx_session = rt.InferenceSession(self._onnx_model_path, self._config["providers"])
99+
self._onnx_session = rt.InferenceSession(self._onnx_model_path, providers=self._config["providers"])
110100
self._onnx_model = onnx.load(self._onnx_model_path)
111101

112102
self._onnx_input_name = []
@@ -117,6 +107,16 @@ def load_model(self, model):
117107
for out in self._onnx_model.graph.output:
118108
self._onnx_output_name.append(out)
119109

110+
self._timer = self.create_timer(self._config["phase"]["control_dt"], self._timer_callback)
111+
112+
self._subs = []
113+
114+
for (key, value) in self.__dict__.values():
115+
if type(value) is Subscription:
116+
self._subs.append(key)
117+
118+
self.load_phase()
119+
120120
def obs(self):
121121
# Should be defined in subclass
122122
pass

0 commit comments

Comments
 (0)