Skip to content

Commit 341cb6a

Browse files
committed
extra loading function + fixes
1 parent d22dd9d commit 341cb6a

8 files changed

Lines changed: 56 additions & 51 deletions

File tree

src/bitbots_rl_walk/handlers/command_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55

66
class CommandHandler(Handler):
7-
def __init__(self, config_file: str):
8-
super().__init__(config_file=config_file)
7+
def __init__(self, config):
8+
super().__init__(config)
99

1010
self._cmd_vel = None
1111

src/bitbots_rl_walk/handlers/gravity_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77

88
class GravityHandler(Handler):
9-
def __init__(self, config_file: str):
10-
super().__init__(config_file=config_file)
9+
def __init__(self, config):
10+
super().__init__(config)
1111

1212
self._imu_data = None
1313
self._gravity = None

src/bitbots_rl_walk/handlers/gyro_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55

66
class GyroHandler(Handler):
7-
def __init__(self, config_file: str):
8-
super().__init__(config_file=config_file)
7+
def __init__(self, config):
8+
super().__init__(config)
99
self._imu_data = None
1010

1111
# Callables
Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
1-
import yaml
2-
3-
41
class Handler:
5-
def __init__(self, config_file: str):
6-
self._config = self._load_config(self, config_file)
7-
8-
def _load_config(self, path: str):
9-
with open(path) as f:
10-
return yaml.safe_load(f)
2+
def __init__(self, config):
3+
self._config = config
114

125
def get_data(self):
136
pass

src/bitbots_rl_walk/handlers/joint_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66

77
class JointHandler(Handler):
8-
def __init__(self, config_file: str):
9-
super().__init__(config_file=config_file)
8+
def __init__(self, config):
9+
super().__init__(config)
1010

1111
self._ordered_relevant_joint_names = self._config["joints"]["ordered_relevant_joint_names"]
1212
self._walkready_state = self._config["joints"]["walkready_state"]

src/bitbots_rl_walk/handlers/phase_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ class PhaseHandler(Handler):
77
_phase: np.ndarray = np.array([0.0, np.pi], dtype=np.float32)
88
_phase_dt: float
99

10-
def __init__(self, config_file: str):
11-
super().__init__(config_file=config_file)
10+
def __init__(self, config):
11+
super().__init__(config)
1212

1313
self._control_dt = self._config["phase"]["control_dt"]
1414
self._gait_frequency = self._config["phase"]["gait_frequency"]

src/bitbots_rl_walk/nodes/rl_node.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
# ==============================================================================
1515
"""Deploy an MJX policy in ONNX format to C MuJoCo and play with it."""
1616

17+
import os
1718
from pathlib import Path
1819
from typing import Callable, NamedTuple
1920

2021
import numpy as np
2122
import onnx
2223
import onnxruntime as rt
24+
import yaml
25+
from ament_index_python import get_package_share_directory
2326
from rclpy.node import Node
2427
from rclpy.qos import QoSProfile
2528
from rclpy.subscription import Subscription
@@ -42,29 +45,14 @@ class SubscriptionParam(NamedTuple):
4245
callback: Callable
4346
qos_profile: int | QoSProfile
4447

45-
def __init__(self, path_to_model):
46-
self._onnx_model_path = Path(path_to_model)
47-
model_name = self._onnx_model_path.stem
48-
super().__init__(f"{model_name}")
49-
50-
# Load the ONNX model
51-
self._onnx_session = rt.InferenceSession(self._onnx_model_path, providers=["CPUExecutionProvider"])
52-
self._onnx_model = onnx.load(self._onnx_model_path)
53-
54-
self._onnx_input_name = []
55-
for inp in self._onnx_model.graph.input:
56-
self._onnx_input_name.append(inp)
57-
58-
self._onnx_output_name = []
59-
for out in self._onnx_model.graph.output:
60-
self._onnx_output_name.append(out)
61-
62-
self._timer_phase_config = None # Should be implemented in the subclass
63-
64-
self._config = False
65-
48+
def __init__(self, config_path: str):
49+
self._config = self._load_config(config_path)
6650
self._obs = None # should be defined in subclass
67-
self._timer_phase_config = None # Should be defined in subclass
51+
self._phase_handler = None # shoul be defined in subclass
52+
53+
def _load_config(self, path: str):
54+
with open(path) as f:
55+
return yaml.safe_load(f)
6856

6957
# TODO: fix
7058
def _timer_callback(self):
@@ -85,7 +73,7 @@ def _timer_callback(self):
8573

8674
self._timer_phase_config.set_obs_phase(
8775
np.array(
88-
[np.cos(self._timer_phase_config.get_phase()), np.sin(self._timer_phase_config.get_phase())],
76+
[np.cos(self._phase_handler.get_phase()), np.sin(self._phase_handler.get_phase())],
8977
dtype=np.float32,
9078
).flatten()
9179
)
@@ -97,11 +85,31 @@ def _timer_callback(self):
9785

9886
self.publisher(onnx_pred)
9987

100-
phase_tp1 = self._timer_phase_config.get_phase() + self._timer_phase_config.get_phase_dt()
101-
self._timer_phase_config.set_phase(np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi)
88+
phase_tp1 = self._phase_handler.get_phase() + self._phase_handler.get_phase_dt()
89+
self._phase_handler.set_phase(np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi)
10290
else:
10391
raise ConfigError("Configuration is missing! Try to run self.config() in init.")
10492

93+
def load_model(self, model):
94+
path_to_model = os.path.join(get_package_share_directory("bitbots_rl_walk"), "models", model)
95+
96+
self._onnx_model_path = Path(path_to_model)
97+
model_name = self._onnx_model_path.stem
98+
99+
super().__init__(f"{model_name}")
100+
101+
# Load the ONNX model
102+
self._onnx_session = rt.InferenceSession(self._onnx_model_path, providers=["CPUExecutionProvider"])
103+
self._onnx_model = onnx.load(self._onnx_model_path)
104+
105+
self._onnx_input_name = []
106+
for inp in self._onnx_model.graph.input:
107+
self._onnx_input_name.append(inp)
108+
109+
self._onnx_output_name = []
110+
for out in self._onnx_model.graph.output:
111+
self._onnx_output_name.append(out)
112+
105113
def config(self):
106114
self._timer = self.create_timer(self._timer_phase_config.get_control_dt(), self._timer_callback)
107115
self.load_phase()

src/bitbots_rl_walk/nodes/walk_node.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515

1616
class WalkNode(RLNode):
17-
def __init__(self, config):
18-
super().__init__(config)
17+
def __init__(self, config_path: str):
18+
super().__init__(config_path)
19+
20+
# loading model
21+
model = self._config["models"]["walk_model"]
22+
self.load_model(model)
1923

2024
# publishers
2125
self._joint_command_pub = self.create_publisher(JointCommand, "walking_motor_goals", 10)
@@ -26,11 +30,11 @@ def __init__(self, config):
2630
self._cmd_vel_sub = self.create_subscription(Twist, "cmd_vel", self._cmd_vel_callback, 10)
2731

2832
# handlers
29-
self._gyro_handler = GyroHandler()
30-
self._gravity_handler = GravityHandler()
31-
self._joint_handler = JointHandler()
32-
self._command_handler = CommandHandler()
33-
self._phase_handler = PhaseHandler()
33+
self._gyro_handler = GyroHandler(self._config)
34+
self._gravity_handler = GravityHandler(self._config)
35+
self._joint_handler = JointHandler(self._config)
36+
self._command_handler = CommandHandler(self._config)
37+
self._phase_handler = PhaseHandler(self._config)
3438

3539
self.config()
3640

0 commit comments

Comments
 (0)