2929
3030from bitbots_msgs .msg import JointCommand
3131
32+ from soccer_vision_3d_msgs .msg import Ball , BallArray
33+
3234ONNX_MODEL = os .path .join (get_package_share_directory ("bitbots_rl_walk" ), "models" , "wolfgang_forward_kick_ppo.onnx" )
3335
3436WALKREADY_STATE = np .array (
@@ -88,7 +90,7 @@ class KickNode(Node):
8890 _imu_data : Optional [Imu ] = None
8991 _joint_state : Optional [JointState ] = None
9092 _cmd_vel : Optional [Twist ] = None
91- _ball_pose : Optional [PoseStamped ] = None
93+ _ball_pose : Optional [BallArray ] = None
9294 _phase : np .ndarray = np .array ([0.0 , np .pi ], dtype = np .float32 )
9395 _phase_dt : float
9496
@@ -108,7 +110,7 @@ def __init__(self):
108110 self ._imu_sub = self .create_subscription (Imu , "imu/data" , self ._imu_callback , 10 )
109111 self ._joint_state_sub = self .create_subscription (JointState , "joint_states" , self ._joint_state_callback , 10 )
110112 self ._cmd_vel_sub = self .create_subscription (Twist , "cmd_vel" , self ._cmd_vel_callback , 10 )
111- self ._goal_pose_sub = self .create_subscription (PoseStamped , "ball_pose " , self ._ball_pose_callback , 10 )
113+ self ._goal_pose_sub = self .create_subscription (BallArray , "balls_relative " , self ._ball_pose_callback , 10 )
112114
113115 self ._timer = self .create_timer (CONTROL_DT , self ._timer_callback )
114116
@@ -130,24 +132,23 @@ def _joint_state_callback(self, msg: JointState):
130132 def _cmd_vel_callback (self , msg : Twist ):
131133 self ._cmd_vel = msg
132134
133- def _ball_pose_callback (self , msg : PoseStamped ):
134- self ._ball_pose = msg
135+ def _ball_pose_callback (self , msg : BallArray ):
136+ if msg .balls :
137+ self ._ball_pose = msg .balls [0 ].center
135138
136139 def _imu_callback (self , msg : Imu ):
137140 self ._imu_data = msg
138141
139142 def _timer_callback (self ):
140143 """Timer callback to publish joint commands based on the ONNX policy."""
141- if self ._imu_data is None or self ._joint_state is None or self ._cmd_vel is None or self . _ball_pose is None :
144+ if self ._imu_data is None or self ._joint_state is None or self ._ball_pose is None :
142145 self .get_logger ().warning ("Waiting for all sensors to be available" , throttle_duration_sec = 1.0 )
143146
144147 # Print the sensor that we are still waiting for
145148 if self ._imu_data is None :
146149 self .get_logger ().warning ("Waiting for IMU data" , throttle_duration_sec = 1.0 )
147150 if self ._joint_state is None :
148151 self .get_logger ().warning ("Waiting for joint state data" , throttle_duration_sec = 1.0 )
149- if self ._cmd_vel is None :
150- self .get_logger ().warning ("Waiting for cmd_vel data" , throttle_duration_sec = 1.0 )
151152 # self._cmd_vel = Twist(x=0.3, y=0.0, z=0.0) # Testing purpose
152153 if self ._ball_pose is None :
153154 self .get_logger ().warning ("Waiting for ball pose data" , throttle_duration_sec = 1.0 )
@@ -177,7 +178,7 @@ def _timer_callback(self):
177178 self ._imu_data .orientation .z ,
178179 ]
179180 )
180- @ euler2mat (0 , - 0.0 , 0 )
181+ @ euler2mat (0 , - 0.1 , 0 )
181182 ).T @ np .array ([0 , 0 , - 1 ], dtype = np .float32 )
182183
183184 joint_angles = (
@@ -198,12 +199,12 @@ def _timer_callback(self):
198199
199200 phase = np .array ([np .cos (self ._phase ), np .sin (self ._phase )], dtype = np .float32 ).flatten ()
200201
201- command = np .array ([ self . _cmd_vel . linear . x , self . _cmd_vel . linear . y , self . _cmd_vel . angular . z ], dtype = np . float32 )
202+ command = np .zeros ( 3 )
202203
203204 rel_ball_pos = np .array (
204205 [
205- self ._ball_pose .pose . position . x ,
206- self ._ball_pose .pose . position . y ,
206+ self ._ball_pose .x ,
207+ self ._ball_pose .y ,
207208 ],
208209 dtype = np .float32 ,
209210 )
0 commit comments