File tree Expand file tree Collapse file tree
src/bitbots_motion/bitbots_rl_motion Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -8,16 +8,22 @@ class PhaseObject:
88 _phase_dt : float
99
1010 def __init__ (self , config ):
11-
1211 self ._control_dt = config ["phase" ]["control_dt" ]
1312 self ._gait_frequency = config ["phase" ]["gait_frequency" ]
1413 self ._phase_dt = 2 * np .pi * self ._gait_frequency * self ._control_dt
14+ self ._obs_phase = None
1515
1616 def set_phase (self , new_phase ):
1717 self ._phase = new_phase
1818
19+ def set_obs_phase (self , new_obs_phase ):
20+ self ._obs_phase = new_obs_phase
21+
1922 def get_phase (self ):
2023 return self ._phase
2124
2225 def get_phase_dt (self ):
2326 return self ._phase_dt
27+
28+ def get_obs_phase (self ):
29+ return self ._obs_phase
Original file line number Diff line number Diff line change @@ -46,18 +46,20 @@ def _ball_pos_callback(self, msg):
4646
4747 # observations
4848 def obs (self ):
49- return np .hstack (
49+ observation = np .hstack (
5050 [
5151 self ._gyro_handler .get_gyro (),
5252 self ._gravity_handler .get_gravity (),
5353 self ._joint_handler .get_velocity_data (),
5454 self ._joint_handler .get_angle_data (),
5555 self ._previous_action .get_previous_action (),
56- self ._phase .get_phase (),
56+ self ._phase .get_obs_phase (),
5757 self ._ball_handler .get_ball_pos (),
5858 ]
5959 ).astype (np .float32 )
6060
61+ return observation
62+
6163 # load phase function
6264 def load_phase (self ):
6365 pass
Original file line number Diff line number Diff line change @@ -62,15 +62,17 @@ def _timer_callback(self):
6262
6363 # TODO consider IMU mounting offset
6464
65- self ._phase .set_phase (
65+ self ._phase .set_obs_phase (
6666 np .array (
6767 [np .cos (self ._phase .get_phase ()), np .sin (self ._phase .get_phase ())],
6868 dtype = np .float32 ,
6969 ).flatten ()
7070 )
7171
72+ observation = self .obs ()
73+
7274 # Run the ONNX model
73- onnx_input = {self ._onnx_input_name [0 ]: self . obs () .reshape (1 , - 1 )} # TODO: Improve input
75+ onnx_input = {self ._onnx_input_name [0 ]: observation .reshape (1 , - 1 )} # TODO: Improve input
7476 onnx_pred = self ._onnx_session .run (self ._onnx_output_name , onnx_input )[0 ][0 ]
7577 self ._previous_action .set_previous_action (onnx_pred )
7678
Original file line number Diff line number Diff line change @@ -56,7 +56,7 @@ def obs(self):
5656 self ._joint_handler .get_velocity_data (),
5757 self ._joint_handler .get_angle_data (),
5858 self ._previous_action .get_previous_action (),
59- self ._phase .get_phase (),
59+ self ._phase .get_obs_phase (),
6060 ]
6161 ).astype (np .float32 )
6262
You can’t perform that action at this time.
0 commit comments