Skip to content

Commit c1e6e7e

Browse files
committed
refactoring obs
1 parent 5bba14b commit c1e6e7e

4 files changed

Lines changed: 36 additions & 42 deletions

File tree

src/bitbots_motion/bitbots_rl_motion/configs/wolfgang_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
models:
22
walk_model: wolfang_walk_ppo.onnx
3-
kick_model: wolfgang_kick_ppo.onnx
3+
kick_model: wolfgang_forward_kick_better_ball_ppo.onnx
44

55
joints:
66
ordered_relevant_joint_names: [

src/bitbots_motion/bitbots_rl_motion/nodes/kick_node.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
class KickNode(RLNode):
1414
def __init__(self, config_path: str):
15-
super().__init__(config_path, "kick_node")
15+
super().__init__(config_path, node_name="kick_node")
1616

1717
# publishers
1818
self._joint_command_pub = self.create_publisher(JointCommand, "walking_motor_goals", 10)
@@ -28,26 +28,11 @@ def __init__(self, config_path: str):
2828
self._joint_handler = JointHandler(self._config)
2929
self._ball_handler = BallHandler(self._config)
3030

31-
# observations
32-
33-
self._obs = np.hstack(
34-
[
35-
self._gyro_handler.get_gyro(),
36-
self._gravity_handler.get_gravity(),
37-
self._joint_handler.get_velocity_data(),
38-
self._joint_handler.get_angle_data(),
39-
self._joint_handler.get_previous_action(),
40-
self._phase.get_phase(),
41-
self._ball_handler.get_ball_pos(),
42-
]
43-
).astype(np.float32)
44-
45-
# loading model
31+
# loading model
4632
model = self._config["models"]["kick_model"]
4733
self.load_model(model)
4834

4935
# callback functions
50-
5136
def _imu_callback(self, msg):
5237
self._gyro_handler.imu_callback(msg)
5338
self._gravity_handler.imu_callback(msg)
@@ -58,13 +43,25 @@ def _joint_state_callback(self, msg):
5843
def _ball_pos_callback(self, msg):
5944
self._ball_handler.ball_pos_callback(msg)
6045

61-
# load phase function
46+
# observations
47+
def obs(self):
48+
return np.hstack(
49+
[
50+
self._gyro_handler.get_gyro(),
51+
self._gravity_handler.get_gravity(),
52+
self._joint_handler.get_velocity_data(),
53+
self._joint_handler.get_angle_data(),
54+
self._joint_handler.get_previous_action(),
55+
self._phase.get_phase(),
56+
self._ball_handler.get_ball_pos(),
57+
]
58+
).astype(np.float32)
6259

60+
# load phase function
6361
def load_phase(self):
6462
pass
6563

6664
# publisher function
67-
6865
def publisher(self, onnx_pred):
6966
joint_command = self._joint_handler.get_joint_commands(onnx_pred)
7067
self._joint_command_pub.publish(joint_command)

src/bitbots_motion/bitbots_rl_motion/nodes/rl_node.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import os
1818
from pathlib import Path
19-
from typing import Callable, NamedTuple
2019

2120
from abc import ABC, abstractmethod
2221
import numpy as np
@@ -26,7 +25,6 @@
2625
from ament_index_python import get_package_share_directory
2726
from bitbots_rl_motion.phase import PhaseObject
2827
from rclpy.node import Node
29-
from rclpy.qos import QoSProfile
3028
from rclpy.subscription import Subscription
3129

3230
from handlers.handler import Handler
@@ -40,7 +38,6 @@ def __init__(self, config_path: str, node_name: str):
4038

4139
self._config = self._load_config(config_path)
4240
self._phase = PhaseObject(self._config)
43-
self._obs = None # should be defined in subclass
4441

4542
def _load_config(self, path: str):
4643
with open(path) as f:
@@ -64,7 +61,7 @@ def _timer_callback(self):
6461
)
6562

6663
# Run the ONNX model
67-
onnx_input = {self._onnx_input_name[0]: self._obs.reshape(1, -1)} # TODO: Improve input
64+
onnx_input = {self._onnx_input_name[0]: self.obs().reshape(1, -1)} # TODO: Improve input
6865
onnx_pred = self._onnx_session.run(self._onnx_output_name, onnx_input)[0][0]
6966
self._previous_action = onnx_pred
7067

@@ -114,6 +111,9 @@ def publisher(self, action):
114111
def load_phase(self):
115112
pass
116113

114+
@abstractmethod
115+
def obs(self):
116+
pass
117117

118118
class ConfigError(Exception):
119119
pass

src/bitbots_motion/bitbots_rl_motion/nodes/walk_node.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
class WalkNode(RLNode):
1616
def __init__(self, config_path: str):
17-
super().__init__(config_path, "walk_node")
17+
super().__init__(config_path, node_name="walk_node")
1818

1919
# publishers
2020
self._joint_command_pub = self.create_publisher(JointCommand, "walking_motor_goals", 10)
@@ -30,26 +30,11 @@ def __init__(self, config_path: str):
3030
self._joint_handler = JointHandler(self._config)
3131
self._command_handler = CommandHandler(self._config)
3232

33-
# observations
34-
35-
self._obs = np.hstack(
36-
[
37-
self._gyro_handler.get_gyro(),
38-
self._gravity_handler.get_gravity(),
39-
self._command_handler.get_command(),
40-
self._joint_handler.get_velocity_data(),
41-
self._joint_handler.get_angle_data(),
42-
self._joint_handler.get_previous_action(),
43-
self._phase.get_phase(),
44-
]
45-
).astype(np.float32)
46-
4733
# loading model
4834
model = self._config["models"]["walk_model"]
4935
self.load_model(model)
5036

5137
# callback functions
52-
5338
def _imu_callback(self, msg):
5439
self._gyro_handler.imu_callback(msg)
5540
self._gravity_handler.imu_callback(msg)
@@ -60,15 +45,27 @@ def _joint_state_callback(self, msg):
6045
def _cmd_vel_callback(self, msg):
6146
self._command_handler.cmd_vel_callback(msg)
6247

63-
# load phase function
48+
# observations
49+
def obs(self):
50+
return np.hstack(
51+
[
52+
self._gyro_handler.get_gyro(),
53+
self._gravity_handler.get_gravity(),
54+
self._command_handler.get_command(),
55+
self._joint_handler.get_velocity_data(),
56+
self._joint_handler.get_angle_data(),
57+
self._joint_handler.get_previous_action(),
58+
self._phase.get_phase(),
59+
]
60+
).astype(np.float32)
6461

62+
# load phase function
6563
def load_phase(self):
6664
walkready_command = self._joint_handler.get_walkready_joint_command()
6765
self._joint_command_pub.publish(walkready_command)
6866
time.sleep(10)
6967

7068
# publisher function
71-
7269
def publisher(self, onnx_pred):
7370
joint_command = self._joint_handler.get_joint_commands(onnx_pred)
7471
self._joint_command_pub.publish(joint_command)

0 commit comments

Comments
 (0)