Skip to content

Commit 5bba14b

Browse files
committed
refactoring (+ abstraction)
1 parent e7da5e6 commit 5bba14b

3 files changed

Lines changed: 25 additions & 53 deletions

File tree

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
class Handler:
1+
from abc import ABC, abstractmethod
2+
3+
class Handler(ABC):
24
def __init__(self, config):
35
self._config = config
46

7+
@abstractmethod
58
def has_data(self):
69
pass

src/bitbots_motion/bitbots_rl_motion/nodes/rl_node.py

Lines changed: 21 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pathlib import Path
1919
from typing import Callable, NamedTuple
2020

21+
from abc import ABC, abstractmethod
2122
import numpy as np
2223
import onnx
2324
import onnxruntime as rt
@@ -31,20 +32,9 @@
3132
from handlers.handler import Handler
3233

3334

34-
class RLNode(Node):
35+
class RLNode(Node, ABC):
3536
"""Node to control the wolfgang humanoid."""
3637

37-
class PublisherParam(NamedTuple):
38-
msg_type: int
39-
topic: str
40-
qos_profile: int | QoSProfile
41-
42-
class SubscriptionParam(NamedTuple):
43-
msg_type: int
44-
topic: str
45-
callback: Callable
46-
qos_profile: int | QoSProfile
47-
4838
def __init__(self, config_path: str, node_name: str):
4939
super().__init__(f"{node_name}")
5040

@@ -56,49 +46,33 @@ def _load_config(self, path: str):
5646
with open(path) as f:
5747
return yaml.safe_load(f)
5848

59-
# TODO: fix
6049
def _timer_callback(self):
6150
if not self._config:
6251
raise ConfigError("Configuration is missing!")
6352

6453
# Prüfen ob alle Subscriber schon mindestens eine Nachricht hatten
6554
if not self._all_sensors_ready():
6655
self.get_logger().warning("Waiting for all sensors to be available", throttle_duration_sec=1.0)
67-
return
68-
69-
if self._config:
70-
for subscription in self._subs:
71-
if subscription is None:
72-
self.get_logger().warning("Waiting for all sensors to be available", throttle_duration_sec=1.0)
7356

74-
for subscription in self._subs:
75-
if subscription is None:
76-
self.get_logger().warning(
77-
f"Waiting for: {subscription} to be available", throttle_duration=1.0
78-
)
57+
# TODO consider IMU mounting offset
7958

80-
return
59+
self._phase.set_phase(
60+
np.array(
61+
[np.cos(self._phase.get_phase()), np.sin(self._phase.get_phase())],
62+
dtype=np.float32,
63+
).flatten()
64+
)
8165

82-
# TODO consider IMU mounting offset
66+
# Run the ONNX model
67+
onnx_input = {self._onnx_input_name[0]: self._obs.reshape(1, -1)} # TODO: Improve input
68+
onnx_pred = self._onnx_session.run(self._onnx_output_name, onnx_input)[0][0]
69+
self._previous_action = onnx_pred
8370

84-
self._phase.set_phase(
85-
np.array(
86-
[np.cos(self._phase.get_phase()), np.sin(self._phase.get_phase())],
87-
dtype=np.float32,
88-
).flatten()
89-
)
71+
self.publisher(onnx_pred)
9072

91-
# Run the ONNX model
92-
onnx_input = {self._onnx_input_name[0]: self._obs.reshape(1, -1)} # TODO: Improve input
93-
onnx_pred = self._onnx_session.run(self._onnx_output_name, onnx_input)[0][0]
94-
self._previous_action = onnx_pred
73+
phase_tp1 = self._phase.get_phase() + self._phase.get_phase_dt()
74+
self._phase.set_phase(np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi)
9575

96-
self.publisher(onnx_pred)
97-
98-
phase_tp1 = self._phase.get_phase() + self._phase.get_phase_dt()
99-
self._phase.set_phase(np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi)
100-
else:
101-
raise ConfigError("Configuration is missing! Try to run self.config() in init.")
10276

10377
def _all_sensors_ready(self):
10478
for handler in self._handlers:
@@ -116,13 +90,8 @@ def load_model(self, model):
11690
self._onnx_session = rt.InferenceSession(self._onnx_model_path, providers=self._config["providers"])
11791
self._onnx_model = onnx.load(self._onnx_model_path)
11892

119-
self._onnx_input_name = []
120-
for inp in self._onnx_model.graph.input:
121-
self._onnx_input_name.append(inp)
122-
123-
self._onnx_output_name = []
124-
for out in self._onnx_model.graph.output:
125-
self._onnx_output_name.append(out)
93+
self._onnx_input_name = [inp.name for inp in self._onnx_model.graph.input]
94+
self._onnx_output_name = [out.name for out in self._onnx_model.graph.output]
12695

12796
self._subs = []
12897
self._handlers = []
@@ -137,12 +106,12 @@ def load_model(self, model):
137106

138107
self.load_phase()
139108

140-
def publisher(self):
141-
# Should be defined in subclass
109+
@abstractmethod
110+
def publisher(self, action):
142111
pass
143112

113+
@abstractmethod
144114
def load_phase(self):
145-
# Should be defined in subclass
146115
pass
147116

148117

src/bitbots_motion/bitbots_rl_motion/bitbots_rl_motion/forward_kick.py renamed to src/bitbots_motion/bitbots_rl_motion/policies/forward_kick.py

File renamed without changes.

0 commit comments

Comments
 (0)