|
29 | 29 |
|
30 | 30 | from bitbots_msgs.msg import JointCommand |
31 | 31 |
|
32 | | -ONNX_MODEL = os.path.join(get_package_share_directory("bitbots_rl_walk"), "models", "wolfgang_policy.onnx") |
| 32 | +ONNX_MODEL = os.path.join(get_package_share_directory("bitbots_rl_walk"), "models", "wolfgang_walk_ppo.onnx") |
33 | 33 |
|
34 | 34 | PREPARATION_STATE = np.array( |
35 | 35 | [ |
@@ -225,7 +225,8 @@ def _timer_callback(self): |
225 | 225 | if self._joint_state is None: |
226 | 226 | self.get_logger().warning("Waiting for joint state data", throttle_duration_sec=1.0) |
227 | 227 | if self._cmd_vel is None: |
228 | | - self.get_logger().warning("Waiting for cmd_vel data", throttle_duration_sec=1.0) |
| 228 | + # self.get_logger().warning("Waiting for cmd_vel data", throttle_duration_sec=1.0) |
| 229 | + self._cmd_vel = Twist(x=1.0, y=0.0, z=0.0) # Testing purpose |
229 | 230 |
|
230 | 231 | return |
231 | 232 |
|
@@ -275,19 +276,19 @@ def _timer_callback(self): |
275 | 276 |
|
276 | 277 | obs = np.hstack( |
277 | 278 | [ |
278 | | - gyro, |
279 | | - gravity, |
280 | | - command, |
281 | | - joint_angles, |
282 | | - joint_velocities, |
283 | | - self._previous_action, # Previous action |
284 | | - phase, |
| 279 | + gyro, # 3 |
| 280 | + gravity, # 4 |
| 281 | + command, # 3 |
| 282 | + joint_angles, # 18 |
| 283 | + joint_velocities, # 18 |
| 284 | + self._previous_action, # 18 # Previous action |
| 285 | + phase, # 2 |
285 | 286 | ] |
286 | 287 | ).astype(np.float32) |
287 | 288 |
|
288 | 289 | # Run the ONNX model |
289 | | - onnx_input = {"obs": obs.reshape(1, -1)} |
290 | | - onnx_pred = self._onnx_session.run(["continuous_actions"], onnx_input)[0][0] |
| 290 | + onnx_input = {"in_0": obs.reshape(1, -1)} |
| 291 | + onnx_pred = self._onnx_session.run(["tanh_out_0"], onnx_input)[0][0] |
291 | 292 | self._previous_action = onnx_pred |
292 | 293 |
|
293 | 294 | # Publish the joint commands |
|
0 commit comments