Skip to content

Commit ae8b277

Browse files
committed
using central config file
1 parent 03ae24c commit ae8b277

12 files changed

Lines changed: 116 additions & 103 deletions

File tree

src/bitbots_rl_walk/bitbots_rl_walk/policy_nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
def main():
99
rclpy.init()
1010

11-
walk_policy_path = os.path.join(get_package_share_directory("bitbots_rl_walk"), "models", "wolfgang_walk_ppo.onnx")
11+
wolfgang_config = os.path.join(get_package_share_directory("bitbots_rl_walk"), "config", "wolfgang_config.yaml")
1212
# kick_policy_path = os.path.join(get_package_share_directory("bitbots_rl_walk"), "models", "wolfgang_kick_ppo.onnx")
1313

14-
walk_node = WalkNode(walk_policy_path)
14+
walk_node = WalkNode(wolfgang_config)
1515
# kick_node = RLNode(kick_policy_path)
1616

1717
rclpy.spin(walk_node)

src/bitbots_rl_walk/config/config_obj.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/bitbots_rl_walk/config/timer_phase_config.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
models:
2+
walk_model: wolfang_walk_ppo.onnx
3+
4+
joints:
5+
ordered_relevant_joint_names: [
6+
"RShoulderPitch",
7+
"RShoulderRoll",
8+
"RElbow",
9+
"LShoulderPitch",
10+
"LShoulderRoll",
11+
"LElbow",
12+
"RHipYaw",
13+
"RHipRoll",
14+
"RHipPitch",
15+
"RKnee",
16+
"RAnklePitch",
17+
"RAnkleRoll",
18+
"LHipYaw",
19+
"LHipRoll",
20+
"LHipPitch",
21+
"LKnee",
22+
"LAnklePitch",
23+
"LAnkleRoll",
24+
]
25+
26+
walkready_state: [
27+
0,
28+
0,
29+
0,
30+
0,
31+
0,
32+
0,
33+
0.023628265148262724,
34+
-0.10401795710581162,
35+
-0.7352626990449959,
36+
-1.3228415184260092,
37+
0.5495038397740458,
38+
-0.12913515511895796,
39+
-0.016441795868928723,
40+
0.07253788412595062,
41+
0.7420808433462046,
42+
1.334527650998329,
43+
-0.5537397918567754,
44+
0.07437380704149316,
45+
]
46+
47+
phase:
48+
control_dt: 0.02
49+
gait_frequency: 1.5

src/bitbots_rl_walk/handler/handler.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

src/bitbots_rl_walk/handler/command_handler.py renamed to src/bitbots_rl_walk/handlers/command_handler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import numpy as np
22

3-
from bitbots_rl_walk.handler.handler import Handler
3+
from bitbots_rl_walk.handlers.handler import Handler
44

55

66
class CommandHandler(Handler):
7-
def __init__(self):
7+
def __init__(self, config_file: str):
8+
super().__init__(config_file=config_file)
9+
810
self._cmd_vel = None
911

1012
def get_data(self):

src/bitbots_rl_walk/handler/gravity_handler.py renamed to src/bitbots_rl_walk/handlers/gravity_handler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from transforms3d.euler import euler2mat
33
from transforms3d.quaternions import quat2mat
44

5-
from bitbots_rl_walk.handler.handler import Handler
5+
from bitbots_rl_walk.handlers.handler import Handler
66

77

88
class GravityHandler(Handler):
9-
def __init__(self):
9+
def __init__(self, config_file: str):
10+
super().__init__(config_file=config_file)
11+
1012
self._imu_data = None
1113
self._gravity = None
1214

src/bitbots_rl_walk/handler/gyro_handler.py renamed to src/bitbots_rl_walk/handlers/gyro_handler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import numpy as np
22

3-
from bitbots_rl_walk.handler.handler import Handler
3+
from bitbots_rl_walk.handlers.handler import Handler
44

55

66
class GyroHandler(Handler):
7-
def __init__(self):
7+
def __init__(self, config_file: str):
8+
super().__init__(config_file=config_file)
89
self._imu_data = None
910

1011
# Callables
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import yaml
2+
3+
4+
class Handler:
5+
def __init__(self, config_file: str):
6+
self._config = self._load_config(self, config_file)
7+
8+
def _load_config(self, path: str):
9+
with open(path) as f:
10+
return yaml.safe_load(f)
11+
12+
def get_data(self):
13+
pass

src/bitbots_rl_walk/handler/joint_handler.py renamed to src/bitbots_rl_walk/handlers/joint_handler.py

Lines changed: 9 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,23 @@
11
import numpy as np
22

33
from bitbots_msgs.msg import JointCommand
4-
from bitbots_rl_walk.handler.handler import Handler
5-
6-
ORDERED_RELEVANT_JOINT_NAMES = [
7-
"RShoulderPitch",
8-
"RShoulderRoll",
9-
"RElbow",
10-
"LShoulderPitch",
11-
"LShoulderRoll",
12-
"LElbow",
13-
"RHipYaw",
14-
"RHipRoll",
15-
"RHipPitch",
16-
"RKnee",
17-
"RAnklePitch",
18-
"RAnkleRoll",
19-
"LHipYaw",
20-
"LHipRoll",
21-
"LHipPitch",
22-
"LKnee",
23-
"LAnklePitch",
24-
"LAnkleRoll",
25-
]
26-
27-
WALKREADY_STATE = np.array(
28-
[
29-
0,
30-
0,
31-
0,
32-
0,
33-
0,
34-
0,
35-
0.023628265148262724,
36-
-0.10401795710581162,
37-
-0.7352626990449959,
38-
-1.3228415184260092,
39-
0.5495038397740458,
40-
-0.12913515511895796,
41-
-0.016441795868928723,
42-
0.07253788412595062,
43-
0.7420808433462046,
44-
1.334527650998329,
45-
-0.5537397918567754,
46-
0.07437380704149316,
47-
],
48-
dtype=np.float32,
49-
)
4+
from bitbots_rl_walk.handlers.handler import Handler
505

516

527
class JointHandler(Handler):
53-
def __init__(self, ordered_relevant_joint_names=ORDERED_RELEVANT_JOINT_NAMES, walkready_state=WALKREADY_STATE):
54-
self._ordered_relevant_joint_names = ordered_relevant_joint_names
55-
self._walkready_state = walkready_state
8+
def __init__(self, config_file: str):
9+
super().__init__(config_file=config_file)
10+
11+
self._ordered_relevant_joint_names = self._config["joints"]["ordered_relevant_joint_names"]
12+
self._walkready_state = self._config["joints"]["walkready_state"]
5613
self._previous_action: np.ndarray = np.zeros(len(self._ordered_relevant_joint_names), dtype=np.float32)
5714
self._joint_state = None
5815
self._obs_phase = None
5916
self._phase = None
6017

18+
def joint_state_callback(self, msg):
19+
self._joint_state = msg
20+
6121
def set_obs_phase(self, phase):
6222
self._obs_phase = phase
6323

@@ -126,6 +86,3 @@ def get_obs_phase(self):
12686

12787
def get_phase(self):
12888
return self._phase
129-
130-
def joint_state_callback(self, msg):
131-
self._joint_state = msg

0 commit comments

Comments
 (0)