1818from pathlib import Path
1919from typing import Callable , NamedTuple
2020
21+ from abc import ABC , abstractmethod
2122import numpy as np
2223import onnx
2324import onnxruntime as rt
3132from handlers .handler import Handler
3233
3334
34- class RLNode (Node ):
35+ class RLNode (Node , ABC ):
3536 """Node to control the wolfgang humanoid."""
3637
37- class PublisherParam (NamedTuple ):
38- msg_type : int
39- topic : str
40- qos_profile : int | QoSProfile
41-
42- class SubscriptionParam (NamedTuple ):
43- msg_type : int
44- topic : str
45- callback : Callable
46- qos_profile : int | QoSProfile
47-
4838 def __init__ (self , config_path : str , node_name : str ):
4939 super ().__init__ (f"{ node_name } " )
5040
@@ -56,49 +46,33 @@ def _load_config(self, path: str):
5646 with open (path ) as f :
5747 return yaml .safe_load (f )
5848
59- # TODO: fix
6049 def _timer_callback (self ):
6150 if not self ._config :
6251 raise ConfigError ("Configuration is missing!" )
6352
6453 # Prüfen ob alle Subscriber schon mindestens eine Nachricht hatten
6554 if not self ._all_sensors_ready ():
6655 self .get_logger ().warning ("Waiting for all sensors to be available" , throttle_duration_sec = 1.0 )
67- return
68-
69- if self ._config :
70- for subscription in self ._subs :
71- if subscription is None :
72- self .get_logger ().warning ("Waiting for all sensors to be available" , throttle_duration_sec = 1.0 )
7356
74- for subscription in self ._subs :
75- if subscription is None :
76- self .get_logger ().warning (
77- f"Waiting for: { subscription } to be available" , throttle_duration = 1.0
78- )
57+ # TODO consider IMU mounting offset
7958
80- return
59+ self ._phase .set_phase (
60+ np .array (
61+ [np .cos (self ._phase .get_phase ()), np .sin (self ._phase .get_phase ())],
62+ dtype = np .float32 ,
63+ ).flatten ()
64+ )
8165
82- # TODO consider IMU mounting offset
66+ # Run the ONNX model
67+ onnx_input = {self ._onnx_input_name [0 ]: self ._obs .reshape (1 , - 1 )} # TODO: Improve input
68+ onnx_pred = self ._onnx_session .run (self ._onnx_output_name , onnx_input )[0 ][0 ]
69+ self ._previous_action = onnx_pred
8370
84- self ._phase .set_phase (
85- np .array (
86- [np .cos (self ._phase .get_phase ()), np .sin (self ._phase .get_phase ())],
87- dtype = np .float32 ,
88- ).flatten ()
89- )
71+ self .publisher (onnx_pred )
9072
91- # Run the ONNX model
92- onnx_input = {self ._onnx_input_name [0 ]: self ._obs .reshape (1 , - 1 )} # TODO: Improve input
93- onnx_pred = self ._onnx_session .run (self ._onnx_output_name , onnx_input )[0 ][0 ]
94- self ._previous_action = onnx_pred
73+ phase_tp1 = self ._phase .get_phase () + self ._phase .get_phase_dt ()
74+ self ._phase .set_phase (np .fmod (phase_tp1 + np .pi , 2 * np .pi ) - np .pi )
9575
96- self .publisher (onnx_pred )
97-
98- phase_tp1 = self ._phase .get_phase () + self ._phase .get_phase_dt ()
99- self ._phase .set_phase (np .fmod (phase_tp1 + np .pi , 2 * np .pi ) - np .pi )
100- else :
101- raise ConfigError ("Configuration is missing! Try to run self.config() in init." )
10276
10377 def _all_sensors_ready (self ):
10478 for handler in self ._handlers :
@@ -116,13 +90,8 @@ def load_model(self, model):
11690 self ._onnx_session = rt .InferenceSession (self ._onnx_model_path , providers = self ._config ["providers" ])
11791 self ._onnx_model = onnx .load (self ._onnx_model_path )
11892
119- self ._onnx_input_name = []
120- for inp in self ._onnx_model .graph .input :
121- self ._onnx_input_name .append (inp )
122-
123- self ._onnx_output_name = []
124- for out in self ._onnx_model .graph .output :
125- self ._onnx_output_name .append (out )
93+ self ._onnx_input_name = [inp .name for inp in self ._onnx_model .graph .input ]
94+ self ._onnx_output_name = [out .name for out in self ._onnx_model .graph .output ]
12695
12796 self ._subs = []
12897 self ._handlers = []
@@ -137,12 +106,12 @@ def load_model(self, model):
137106
138107 self .load_phase ()
139108
140- def publisher ( self ):
141- # Should be defined in subclass
109+ @ abstractmethod
110+ def publisher ( self , action ):
142111 pass
143112
113+ @abstractmethod
144114 def load_phase (self ):
145- # Should be defined in subclass
146115 pass
147116
148117
0 commit comments