Skip to content

Commit dd28ecb

Browse files
committed
pixi conf and transition to walk ready pos
1 parent 00f2925 commit dd28ecb

1 file changed

Lines changed: 315 additions & 0 deletions

File tree

  • src/bitbots_rl_walk/bitbots_rl_walk
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Deploy an MJX policy in ONNX format to C MuJoCo and play with it."""
16+
17+
import os
18+
import time
19+
from typing import Optional
20+
21+
import numpy as np
22+
import onnxruntime as rt
23+
from ament_index_python import get_package_share_directory
24+
from geometry_msgs.msg import Twist
25+
from rclpy.node import Node
26+
from sensor_msgs.msg import Imu, JointState
27+
from transforms3d.euler import euler2mat
28+
from transforms3d.quaternions import quat2mat
29+
30+
from bitbots_msgs.msg import JointCommand
31+
32+
ONNX_MODEL = os.path.join(get_package_share_directory("bitbots_rl_walk"), "models", "wolfgang_policy.onnx")
33+
34+
PREPARATION_STATE = np.array(
35+
[
36+
0, # RShoulderPitch
37+
0, # RShoulderRoll
38+
0, # RElbow (angewinkelt)
39+
0, # LShoulderPitch
40+
0, # LShoulderRoll
41+
0, # LElbow (angewinkelt)
42+
0, # RHipYaw
43+
0, # RHipRoll
44+
0, # RHipPitch
45+
0, # RKnee (durchgestreckt)
46+
-0.2, # RAnklePitch (leichte Verlagerung nach hinten)
47+
0, # RAnkleRoll
48+
0, # LHipYaw
49+
0, # LHipRoll
50+
0, # LHipPitch
51+
0, # LKnee (durchgestreckt)
52+
0.2, # LAnklePitch (leichte Verlagerung nach hinten)
53+
0, # LAnkleRoll
54+
],
55+
dtype=np.float32,
56+
)
57+
58+
PREPARATION_STATE2 = np.array(
59+
[
60+
0, # RShoulderPitch
61+
0, # RShoulderRoll
62+
0, # RElbow
63+
0, # LShoulderPitch
64+
0, # LShoulderRoll
65+
0, # LElbow
66+
0, # RHipYaw
67+
0, # RHipRoll
68+
0, # RHipPitch ← stark reduziert (oder probiere +0.10, falls negativ = vorne)
69+
-1.0, # RKnee ← deutlich weniger Beugung → sanfter
70+
0.60, # RAnklePitch ← stärker Zehen hoch → mehr Gegengewicht nach hinten
71+
0, # RAnkleRoll
72+
0, # LHipYaw
73+
0, # LHipRoll
74+
0, # LHipPitch ← symmetrisch – teste ggf. -0.10
75+
1.0, # LKnee ← symmetrisch weniger Beugung
76+
-0.60, # LAnklePitch ← stärker kompensierend
77+
0, # LAnkleRoll
78+
],
79+
dtype=np.float32,
80+
)
81+
82+
PREPARATION_STATE3 = np.array(
83+
[
84+
0, # RShoulderPitch
85+
0, # RShoulderRoll
86+
0, # RElbow
87+
0, # LShoulderPitch
88+
0, # LShoulderRoll
89+
0, # LElbow
90+
0.012, # RHipYaw ← leichter Übergang zu 0.024
91+
-0.05, # RHipRoll ← leichter zu -0.104
92+
-0.32, # RHipPitch ← Mittel zwischen -0.10 und -0.735 (weniger vorwärts)
93+
-1.06, # RKnee ← Mittel zwischen -0.80 und -1.323
94+
0.57, # RAnklePitch ← etwas stärker als in STATE2 (mehr nach hinten)
95+
-0.06, # RAnkleRoll ← leichter zu -0.129
96+
-0.008, # LHipYaw ← zu -0.016
97+
0.04, # LHipRoll ← zu 0.073
98+
0.32, # LHipPitch ← Mittel zwischen 0.10 und 0.742 (mirrored Sign)
99+
1.07, # LKnee ← Mittel zwischen 0.80 und 1.335
100+
-0.57, # LAnklePitch ← mirrored, stärker kompensierend
101+
0.04, # LAnkleRoll ← zu 0.074
102+
],
103+
dtype=np.float32,
104+
)
105+
106+
WALKREADY_STATE = np.array(
107+
[
108+
0,
109+
0,
110+
0,
111+
0,
112+
0,
113+
0,
114+
0.023628265148262724,
115+
-0.10401795710581162,
116+
-0.7352626990449959,
117+
-1.3228415184260092,
118+
0.5495038397740458,
119+
-0.12913515511895796,
120+
-0.016441795868928723,
121+
0.07253788412595062,
122+
0.7420808433462046,
123+
1.334527650998329,
124+
-0.5537397918567754,
125+
0.07437380704149316,
126+
],
127+
dtype=np.float32,
128+
)
129+
130+
CONTROL_DT = 0.02 # Control loop frequency in seconds
131+
132+
GAIT_FREQUENCY = 1.5 # Gait frequency in Hz
133+
134+
ORDERED_RELEVANT_JOINT_NAMES = [
135+
"RShoulderPitch",
136+
"RShoulderRoll",
137+
"RElbow",
138+
"LShoulderPitch",
139+
"LShoulderRoll",
140+
"LElbow",
141+
"RHipYaw",
142+
"RHipRoll",
143+
"RHipPitch",
144+
"RKnee",
145+
"RAnklePitch",
146+
"RAnkleRoll",
147+
"LHipYaw",
148+
"LHipRoll",
149+
"LHipPitch",
150+
"LKnee",
151+
"LAnklePitch",
152+
"LAnkleRoll",
153+
]
154+
155+
156+
class WalkNode(Node):
157+
"""Node to control the wolfgang humanoid."""
158+
159+
_previous_action: np.ndarray = np.zeros(len(ORDERED_RELEVANT_JOINT_NAMES), dtype=np.float32)
160+
_imu_data: Optional[Imu] = None
161+
_joint_state: Optional[JointState] = None
162+
_cmd_vel: Optional[Twist] = None
163+
_phase: np.ndarray = np.array([0.0, np.pi], dtype=np.float32)
164+
_phase_dt: float
165+
166+
def __init__(self):
167+
super().__init__("reinforcement_learning_walk_inference_node")
168+
169+
# Set sim time parameter to true
170+
# self.set_parameters([
171+
# Parameter('use_sim_time', Parameter.Type.BOOL, True), ])
172+
173+
self._phase_dt = 2 * np.pi * GAIT_FREQUENCY * CONTROL_DT
174+
175+
# Load the ONNX model
176+
self._onnx_session = rt.InferenceSession(ONNX_MODEL, providers=["CPUExecutionProvider"])
177+
178+
self._joint_command_pub = self.create_publisher(JointCommand, "DynamixelController/command", 10)
179+
self._imu_sub = self.create_subscription(Imu, "imu/data", self._imu_callback, 10)
180+
self._joint_state_sub = self.create_subscription(JointState, "joint_states", self._joint_state_callback, 10)
181+
self._cmd_vel_sub = self.create_subscription(Twist, "cmd_vel", self._cmd_vel_callback, 10)
182+
183+
self._timer = self.create_timer(CONTROL_DT, self._timer_callback)
184+
185+
# First send the walkready state to the robot for 100 iterations
186+
joint_command = JointCommand()
187+
joint_command.joint_names = ORDERED_RELEVANT_JOINT_NAMES
188+
joint_command.positions = PREPARATION_STATE
189+
joint_command.velocities = [0.2] * len(ORDERED_RELEVANT_JOINT_NAMES)
190+
joint_command.accelerations = [-1.0] * len(ORDERED_RELEVANT_JOINT_NAMES)
191+
joint_command.max_currents = [-1.0] * len(ORDERED_RELEVANT_JOINT_NAMES) # -1.0 means no limit
192+
joint_command.header.stamp = self.get_clock().now().to_msg()
193+
self._joint_command_pub.publish(joint_command)
194+
time.sleep(8)
195+
196+
joint_command.positions = PREPARATION_STATE2
197+
self._joint_command_pub.publish(joint_command)
198+
time.sleep(12)
199+
200+
joint_command.positions = PREPARATION_STATE3
201+
self._joint_command_pub.publish(joint_command)
202+
time.sleep(12)
203+
204+
joint_command.positions = WALKREADY_STATE
205+
self._joint_command_pub.publish(joint_command)
206+
time.sleep(20)
207+
208+
def _joint_state_callback(self, msg: JointState):
209+
self._joint_state = msg
210+
211+
def _cmd_vel_callback(self, msg: Twist):
212+
self._cmd_vel = msg
213+
214+
def _imu_callback(self, msg: Imu):
215+
self._imu_data = msg
216+
217+
def _timer_callback(self):
218+
"""Timer callback to publish joint commands based on the ONNX policy."""
219+
if self._imu_data is None or self._joint_state is None or self._cmd_vel is None:
220+
self.get_logger().warning("Waiting for all sensors to be available", throttle_duration_sec=1.0)
221+
222+
# Print the sensor that we are still waiting for
223+
if self._imu_data is None:
224+
self.get_logger().warning("Waiting for IMU data", throttle_duration_sec=1.0)
225+
if self._joint_state is None:
226+
self.get_logger().warning("Waiting for joint state data", throttle_duration_sec=1.0)
227+
if self._cmd_vel is None:
228+
self.get_logger().warning("Waiting for cmd_vel data", throttle_duration_sec=1.0)
229+
230+
return
231+
232+
# TODO consider IMU mounting offset
233+
234+
# Prepare the observation vector
235+
gyro = np.array(
236+
[
237+
self._imu_data.angular_velocity.x,
238+
self._imu_data.angular_velocity.y,
239+
self._imu_data.angular_velocity.z,
240+
],
241+
dtype=np.float32,
242+
)
243+
244+
gravity = (
245+
quat2mat(
246+
[
247+
self._imu_data.orientation.w,
248+
self._imu_data.orientation.x,
249+
self._imu_data.orientation.y,
250+
self._imu_data.orientation.z,
251+
]
252+
)
253+
@ euler2mat(0, -0.0, 0)
254+
).T @ np.array([0, 0, -1], dtype=np.float32)
255+
256+
joint_angles = (
257+
np.array(
258+
[
259+
self._joint_state.position[self._joint_state.name.index(name)]
260+
for name in ORDERED_RELEVANT_JOINT_NAMES
261+
],
262+
dtype=np.float32,
263+
)
264+
- WALKREADY_STATE
265+
)
266+
267+
joint_velocities = np.array(
268+
[self._joint_state.velocity[self._joint_state.name.index(name)] for name in ORDERED_RELEVANT_JOINT_NAMES],
269+
dtype=np.float32,
270+
)
271+
272+
phase = np.array([np.cos(self._phase), np.sin(self._phase)], dtype=np.float32).flatten()
273+
274+
command = np.array([self._cmd_vel.linear.x, self._cmd_vel.linear.y, self._cmd_vel.angular.z], dtype=np.float32)
275+
276+
obs = np.hstack(
277+
[
278+
gyro,
279+
gravity,
280+
command,
281+
joint_angles,
282+
joint_velocities,
283+
self._previous_action, # Previous action
284+
phase,
285+
]
286+
).astype(np.float32)
287+
288+
# Run the ONNX model
289+
onnx_input = {"obs": obs.reshape(1, -1)}
290+
onnx_pred = self._onnx_session.run(["continuous_actions"], onnx_input)[0][0]
291+
self._previous_action = onnx_pred
292+
293+
# Publish the joint commands
294+
joint_command = JointCommand()
295+
joint_command.header.stamp = self._joint_state.header.stamp
296+
joint_command.joint_names = ORDERED_RELEVANT_JOINT_NAMES
297+
joint_command.positions = onnx_pred * 0.5 + WALKREADY_STATE
298+
joint_command.velocities = [-1.0] * len(ORDERED_RELEVANT_JOINT_NAMES)
299+
joint_command.accelerations = [-1.0] * len(ORDERED_RELEVANT_JOINT_NAMES)
300+
joint_command.max_currents = [-1.0] * len(ORDERED_RELEVANT_JOINT_NAMES)
301+
302+
self._joint_command_pub.publish(joint_command)
303+
304+
phase_tp1 = self._phase + self._phase_dt
305+
self._phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
306+
307+
308+
def main():
309+
import rclpy
310+
311+
rclpy.init()
312+
node = WalkNode()
313+
rclpy.spin(node)
314+
node.destroy_node()
315+
rclpy.try_shutdown()

0 commit comments

Comments
 (0)