Skip to content

Commit e380a03

Browse files
committed
using confg method to avoid initialization deadlocks
1 parent c2ee62a commit e380a03

3 files changed

Lines changed: 52 additions & 35 deletions

File tree

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class ConfgObject:
2+
def __init__():
3+
pass

src/bitbots_rl_walk/nodes/rl_node.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -59,50 +59,60 @@ def __init__(self, path_to_model):
5959
for out in self._onnx_model.graph.output:
6060
self._onnx_output_name.append(out)
6161

62-
# TODO: Move timer to child class
63-
self._timer = self.create_timer(CONTROL_DT, self._timer_callback)
62+
self._timer_phase_confg = None # Should be implemented in the subclass
6463

65-
self.load_phase()
66-
67-
self._subs = []
68-
69-
for key, value in self.__dict__.values():
70-
if type(value) is Subscription:
71-
self._subs.append(key)
64+
self._config = False
7265

7366
self._obs = None # should be defined in subclass
7467
self._timer_phase_confg = None # Should be defined in subclass
7568

7669
# TODO: fix
7770
def _timer_callback(self):
78-
for subscription in self._subs:
79-
if subscription is None:
80-
self.get_logger().warning("Waiting for all sensors to be available", throttle_duration_sec=1.0)
81-
82-
for subscription in self._subs:
83-
if subscription is None:
84-
self.get_logger().warning(f"Waiting for: {subscription} to be available", throttle_duration=1.0)
85-
86-
return
87-
88-
# TODO consider IMU mounting offset
89-
90-
self._timer_phase_confg.set_obs_phase(
91-
np.array(
92-
[np.cos(self._timer_phase_confg.get_phase()), np.sin(self._timer_phase_confg.get_phase())],
93-
dtype=np.float32,
94-
).flatten()
95-
)
71+
if self._config:
72+
for subscription in self._subs:
73+
if subscription is None:
74+
self.get_logger().warning("Waiting for all sensors to be available", throttle_duration_sec=1.0)
75+
76+
for subscription in self._subs:
77+
if subscription is None:
78+
self.get_logger().warning(
79+
f"Waiting for: {subscription} to be available", throttle_duration=1.0
80+
)
81+
82+
return
83+
84+
# TODO consider IMU mounting offset
85+
86+
self._timer_phase_confg.set_obs_phase(
87+
np.array(
88+
[np.cos(self._timer_phase_confg.get_phase()), np.sin(self._timer_phase_confg.get_phase())],
89+
dtype=np.float32,
90+
).flatten()
91+
)
92+
93+
# Run the ONNX model
94+
onnx_input = {self._onnx_input_name[0]: self._obs.reshape(1, -1)} # TODO: Improve input
95+
onnx_pred = self._onnx_session.run(self._onnx_output_name, onnx_input)[0][0]
96+
self._previous_action = onnx_pred
97+
98+
self.publisher(onnx_pred)
99+
100+
phase_tp1 = self._timer_phase_confg.get_phase() + self._timer_phase_confg.get_phase_dt()
101+
self._timer_phase_confg.set_phase(np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi)
102+
else:
103+
raise ConfigError("Configuration is missing! Try to run self.config() in init.")
104+
105+
def config(self):
106+
self._timer = self.create_timer(self._timer_phase_confg.get_control_dt(), self._timer_callback)
107+
self.load_phase()
96108

97-
# Run the ONNX model
98-
onnx_input = {self._onnx_input_name[0]: self._obs.reshape(1, -1)} # TODO: Improve input
99-
onnx_pred = self._onnx_session.run(self._onnx_output_name, onnx_input)[0][0]
100-
self._previous_action = onnx_pred
109+
self._subs = []
101110

102-
self.publisher(onnx_pred)
111+
for key, value in self.__dict__.values():
112+
if type(value) is Subscription:
113+
self._subs.append(key)
103114

104-
phase_tp1 = self._timer_phase_confg.get_phase() + self._timer_phase_confg.get_phase_dt()
105-
self._timer_phase_confg.set_phase(np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi)
115+
self._config = True
106116

107117
def obs(self):
108118
# Should be defined in subclass
@@ -115,3 +125,7 @@ def publisher(self):
115125
def load_phase(self):
116126
# Should be defined in subclass
117127
pass
128+
129+
130+
class ConfigError(Exception):
131+
pass

src/bitbots_rl_walk/nodes/walk_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, walk_policy_path):
3333

3434
self._timer_phase_confg = TimerPhaseConfg()
3535

36-
# TODO: timer is missing
36+
self.confg()
3737

3838
self._obs = np.hstack(
3939
[

0 commit comments

Comments
 (0)