@@ -59,50 +59,60 @@ def __init__(self, path_to_model):
5959 for out in self ._onnx_model .graph .output :
6060 self ._onnx_output_name .append (out )
6161
62- # TODO: Move timer to child class
63- self ._timer = self .create_timer (CONTROL_DT , self ._timer_callback )
62+ self ._timer_phase_confg = None # Should be implemented in the subclass
6463
65- self .load_phase ()
66-
67- self ._subs = []
68-
69- for key , value in self .__dict__ .values ():
70- if type (value ) is Subscription :
71- self ._subs .append (key )
64+ self ._config = False
7265
7366 self ._obs = None # should be defined in subclass
7467 self ._timer_phase_confg = None # Should be defined in subclass
7568
7669 # TODO: fix
7770 def _timer_callback (self ):
78- for subscription in self ._subs :
79- if subscription is None :
80- self .get_logger ().warning ("Waiting for all sensors to be available" , throttle_duration_sec = 1.0 )
81-
82- for subscription in self ._subs :
83- if subscription is None :
84- self .get_logger ().warning (f"Waiting for: { subscription } to be available" , throttle_duration = 1.0 )
85-
86- return
87-
88- # TODO consider IMU mounting offset
89-
90- self ._timer_phase_confg .set_obs_phase (
91- np .array (
92- [np .cos (self ._timer_phase_confg .get_phase ()), np .sin (self ._timer_phase_confg .get_phase ())],
93- dtype = np .float32 ,
94- ).flatten ()
95- )
71+ if self ._config :
72+ for subscription in self ._subs :
73+ if subscription is None :
74+ self .get_logger ().warning ("Waiting for all sensors to be available" , throttle_duration_sec = 1.0 )
75+
76+ for subscription in self ._subs :
77+ if subscription is None :
78+ self .get_logger ().warning (
79+ f"Waiting for: { subscription } to be available" , throttle_duration = 1.0
80+ )
81+
82+ return
83+
84+ # TODO consider IMU mounting offset
85+
86+ self ._timer_phase_confg .set_obs_phase (
87+ np .array (
88+ [np .cos (self ._timer_phase_confg .get_phase ()), np .sin (self ._timer_phase_confg .get_phase ())],
89+ dtype = np .float32 ,
90+ ).flatten ()
91+ )
92+
93+ # Run the ONNX model
94+ onnx_input = {self ._onnx_input_name [0 ]: self ._obs .reshape (1 , - 1 )} # TODO: Improve input
95+ onnx_pred = self ._onnx_session .run (self ._onnx_output_name , onnx_input )[0 ][0 ]
96+ self ._previous_action = onnx_pred
97+
98+ self .publisher (onnx_pred )
99+
100+ phase_tp1 = self ._timer_phase_confg .get_phase () + self ._timer_phase_confg .get_phase_dt ()
101+ self ._timer_phase_confg .set_phase (np .fmod (phase_tp1 + np .pi , 2 * np .pi ) - np .pi )
102+ else :
103+ raise ConfigError ("Configuration is missing! Try to run self.config() in init." )
104+
105+ def config (self ):
106+ self ._timer = self .create_timer (self ._timer_phase_confg .get_control_dt (), self ._timer_callback )
107+ self .load_phase ()
96108
97- # Run the ONNX model
98- onnx_input = {self ._onnx_input_name [0 ]: self ._obs .reshape (1 , - 1 )} # TODO: Improve input
99- onnx_pred = self ._onnx_session .run (self ._onnx_output_name , onnx_input )[0 ][0 ]
100- self ._previous_action = onnx_pred
109+ self ._subs = []
101110
102- self .publisher (onnx_pred )
111+ for key , value in self .__dict__ .values ():
112+ if type (value ) is Subscription :
113+ self ._subs .append (key )
103114
104- phase_tp1 = self ._timer_phase_confg .get_phase () + self ._timer_phase_confg .get_phase_dt ()
105- self ._timer_phase_confg .set_phase (np .fmod (phase_tp1 + np .pi , 2 * np .pi ) - np .pi )
115+ self ._config = True
106116
107117 def obs (self ):
108118 # Should be defined in subclass
@@ -115,3 +125,7 @@ def publisher(self):
115125 def load_phase (self ):
116126 # Should be defined in subclass
117127 pass
128+
129+
130+ class ConfigError (Exception ):
131+ pass
0 commit comments