Skip to content

Commit a92aebc

Browse files
committed
kick node
1 parent bbcb1f1 commit a92aebc

9 files changed

Lines changed: 104 additions & 20 deletions

File tree

src/bitbots_motion/bitbots_rl_motion/bitbots_rl_motion/phase.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ def __init__(self, config):
1414
self._gait_frequency = self._config["phase"]["gait_frequency"]
1515
self._phase_dt = 2 * np.pi * self._gait_frequency * self._control_dt
1616

17+
def set_phase(self, new_phase):
18+
self._phase = new_phase
19+
1720
def get_phase(self):
1821
return self._phase
1922

src/bitbots_motion/bitbots_rl_motion/bitbots_rl_motion/policy_nodes.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@
22

33
import rclpy
44
from ament_index_python import get_package_share_directory
5-
from nodes.walk_node import WalkNode
5+
from nodes.kick_node import KickNode
66

77

88
def main():
99
rclpy.init()
1010

1111
wolfgang_config = os.path.join(get_package_share_directory("bitbots_rl_motion"), "config", "wolfgang_config.yaml")
12-
# kick_policy_path = os.path.join(get_package_share_directory("bitbots_rl_motion"), "models", "wolfgang_kick_ppo.onnx")
1312

14-
walk_node = WalkNode(wolfgang_config)
15-
# kick_node = RLNode(kick_policy_path)
13+
# walk_node = WalkNode(wolfgang_config)
14+
kick_node = KickNode(wolfgang_config)
1615

17-
rclpy.spin(walk_node)
18-
walk_node.destroy()
16+
rclpy.spin(kick_node)
17+
kick_node.destroy()
1918

2019
rclpy.try_shutdown()

src/bitbots_motion/bitbots_rl_motion/configs/wolfgang_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
models:
22
walk_model: wolfang_walk_ppo.onnx
3+
kick_model: wolfgang_kick_ppo.onnx
34

45
joints:
56
ordered_relevant_joint_names: [
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from bitbots_rl_motion.handlers.handler import Handler
2+
3+
4+
class BallHandler(Handler):
5+
def __init__(self, config):
6+
super().__init__(config)
7+
self.ball_pos = None
8+
9+
def ball_pos_callback(self, msg):
10+
self.ball_pos = msg
11+
12+
def get_ball_pos(self):
13+
return self.ball_pos
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import numpy as np
2+
from bitbots_rl_motion.handlers.ball_handler import BallHandler
3+
from bitbots_rl_motion.handlers.joint_handler import JointHandler
4+
from bitbots_rl_motion.nodes.rl_node import RLNode
5+
from geometry_msgs.msg import PoseStamped
6+
from handlers.gravity_handler import GravityHandler
7+
from handlers.gyro_handler import GyroHandler
8+
from sensor_msgs.msg import Imu, JointState
9+
10+
from bitbots_msgs.msg import JointCommand
11+
12+
13+
class KickNode(RLNode):
14+
def __init__(self, config_path: str):
15+
super().__init__(config_path)
16+
17+
# loading model
18+
model = self._config["models"]["kick_model"]
19+
self.load_model(model)
20+
21+
# publishers
22+
self._joint_command_pub = self.create_publisher(JointCommand, "walking_motor_goals", 10)
23+
24+
# subscribers
25+
self._imu_sub = self.create_subscription(Imu, "imu/data", self._imu_callback, 10)
26+
self._joint_state_sub = self.create_subscription(JointState, "joint_states", self._joint_state_callback, 10)
27+
self._ball_pos_sub = self.create_subscription(PoseStamped, "ball_pos", self._ball_pos_callback, 10)
28+
29+
# handlers
30+
self._gyro_handler = GyroHandler(self._config)
31+
self._gravity_handler = GravityHandler(self._config)
32+
self._joint_handler = JointHandler(self._config)
33+
self._ball_handler = BallHandler(self._config)
34+
35+
# observations
36+
37+
self._obs = np.hstack(
38+
[
39+
self._gyro_handler.get_gyro(),
40+
self._gravity_handler.get_gravity(),
41+
self._joint_handler.get_velocity_data(),
42+
self._joint_handler.get_angle_data(),
43+
self._joint_handler.get_previous_action(),
44+
self._phase.get_phase(),
45+
self._ball_handler.get_ball_pos(),
46+
]
47+
).astype(np.float32)
48+
49+
# callback functions
50+
51+
def _imu_callback(self, msg):
52+
self._gyro_handler.imu_callback(msg)
53+
self._gravity_handler.imu_callback(msg)
54+
55+
def _joint_state_callback(self, msg):
56+
self._joint_handler.joint_state_callback(msg)
57+
58+
def _ball_pos_callback(self, msg):
59+
self._ball_handler.ball_pos_callback(msg)
60+
61+
# load phase function
62+
63+
def load_phase(self):
64+
pass
65+
66+
# publisher function
67+
68+
def publisher(self, onnx_pred):
69+
joint_command = self._joint_handler.get_joint_commands(onnx_pred)
70+
self._joint_command_pub.publish(joint_command)

src/bitbots_motion/bitbots_rl_motion/nodes/rl_node.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import onnxruntime as rt
2424
import yaml
2525
from ament_index_python import get_package_share_directory
26+
from bitbots_rl_motion.bitbots_rl_motion.phase import PhaseObject
2627
from rclpy.node import Node
2728
from rclpy.qos import QoSProfile
2829
from rclpy.subscription import Subscription
@@ -44,8 +45,8 @@ class SubscriptionParam(NamedTuple):
4445

4546
def __init__(self, config_path: str):
4647
self._config = self._load_config(config_path)
48+
self._phase = PhaseObject(self._config)
4749
self._obs = None # should be defined in subclass
48-
self._phase_handler = None # should be defined in subclass
4950

5051
self._timer = self.create_timer(self._config["phase"]["control_dt"], self._timer_callback)
5152
self.load_phase()
@@ -77,9 +78,9 @@ def _timer_callback(self):
7778

7879
# TODO consider IMU mounting offset
7980

80-
self._phase_handler.set_obs_phase(
81+
self._phase.set_phase(
8182
np.array(
82-
[np.cos(self._phase_handler.get_phase()), np.sin(self._phase_handler.get_phase())],
83+
[np.cos(self._phase.get_phase()), np.sin(self._phase.get_phase())],
8384
dtype=np.float32,
8485
).flatten()
8586
)
@@ -91,8 +92,8 @@ def _timer_callback(self):
9192

9293
self.publisher(onnx_pred)
9394

94-
phase_tp1 = self._phase_handler.get_phase() + self._phase_handler.get_phase_dt()
95-
self._phase_handler.set_phase(np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi)
95+
phase_tp1 = self._phase.get_phase() + self._phase.get_phase_dt()
96+
self._phase.set_phase(np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi)
9697
else:
9798
raise ConfigError("Configuration is missing! Try to run self.config() in init.")
9899

src/bitbots_motion/bitbots_rl_motion/nodes/walk_node.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44
from bitbots_rl_motion.handlers.joint_handler import JointHandler
5-
from bitbots_rl_motion.handlers.phase_handler import PhaseHandler
65
from bitbots_rl_motion.nodes.rl_node import RLNode
76
from geometry_msgs.msg import Twist
87
from handlers.command_handler import CommandHandler
@@ -34,7 +33,6 @@ def __init__(self, config_path: str):
3433
self._gravity_handler = GravityHandler(self._config)
3534
self._joint_handler = JointHandler(self._config)
3635
self._command_handler = CommandHandler(self._config)
37-
self._phase_handler = PhaseHandler(self._config)
3836

3937
# observations
4038

@@ -46,7 +44,7 @@ def __init__(self, config_path: str):
4644
self._joint_handler.get_velocity_data(),
4745
self._joint_handler.get_angle_data(),
4846
self._joint_handler.get_previous_action(),
49-
self._phase_handler.get_phase(), # TODO: Check whether correct
47+
self._phase.get_phase(),
5048
]
5149
).astype(np.float32)
5250

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
<launch>
2+

src/bitbots_motion/bitbots_rl_motion/setup.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,14 @@
1818
],
1919
install_requires=["setuptools"],
2020
zip_safe=True,
21-
maintainer="florian",
22-
maintainer_email="git@flova.de",
21+
maintainer="mark oliver",
22+
maintainer_email="git@sWintermoor.de",
2323
description="TODO: Package description",
2424
license="TODO: License declaration",
2525
tests_require=["pytest"],
2626
entry_points={
2727
"console_scripts": [
28-
"walk = bitbots_rl_motion.walk:main",
29-
"kick = bitbots_rl_motion.kick:main",
30-
"walk_kick = bitbots_rl_motion.walk_kick:main",
31-
"forward_kick = bitbots_rl_motion.forward_kick:main",
28+
"run_policies = bitbots_rl_motion.policy_nodes:main",
3229
],
3330
},
3431
)

0 commit comments

Comments
 (0)