Skip to content

Commit 8043366

Browse files
committed
adjustments to wolfgang_walk_ppo model
1 parent 431d1bd commit 8043366

2 files changed

Lines changed: 38 additions & 11 deletions

File tree

src/bitbots_rl_walk/bitbots_rl_walk/walk.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from bitbots_msgs.msg import JointCommand
3131

32-
ONNX_MODEL = os.path.join(get_package_share_directory("bitbots_rl_walk"), "models", "wolfgang_policy.onnx")
32+
ONNX_MODEL = os.path.join(get_package_share_directory("bitbots_rl_walk"), "models", "wolfgang_walk_ppo.onnx")
3333

3434
PREPARATION_STATE = np.array(
3535
[
@@ -225,7 +225,8 @@ def _timer_callback(self):
225225
if self._joint_state is None:
226226
self.get_logger().warning("Waiting for joint state data", throttle_duration_sec=1.0)
227227
if self._cmd_vel is None:
228-
self.get_logger().warning("Waiting for cmd_vel data", throttle_duration_sec=1.0)
228+
# self.get_logger().warning("Waiting for cmd_vel data", throttle_duration_sec=1.0)
229+
self._cmd_vel = Twist(x=1.0, y=0.0, z=0.0) # Testing purpose
229230

230231
return
231232

@@ -275,19 +276,19 @@ def _timer_callback(self):
275276

276277
obs = np.hstack(
277278
[
278-
gyro,
279-
gravity,
280-
command,
281-
joint_angles,
282-
joint_velocities,
283-
self._previous_action, # Previous action
284-
phase,
279+
gyro, # 3
280+
gravity, # 4
281+
command, # 3
282+
joint_angles, # 18
283+
joint_velocities, # 18
284+
self._previous_action, # 18 # Previous action
285+
phase, # 2
285286
]
286287
).astype(np.float32)
287288

288289
# Run the ONNX model
289-
onnx_input = {"obs": obs.reshape(1, -1)}
290-
onnx_pred = self._onnx_session.run(["continuous_actions"], onnx_input)[0][0]
290+
onnx_input = {"in_0": obs.reshape(1, -1)}
291+
onnx_pred = self._onnx_session.run(["tanh_out_0"], onnx_input)[0][0]
291292
self._previous_action = onnx_pred
292293

293294
# Publish the joint commands

src/bitbots_rl_walk/setup.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from setuptools import find_packages, setup
2+
3+
package_name = "bitbots_rl_walk"
4+
5+
setup(
6+
name=package_name,
7+
version="0.0.0",
8+
packages=find_packages(exclude=["test"]),
9+
data_files=[
10+
("share/ament_index/resource_index/packages", ["resource/" + package_name]),
11+
("share/" + package_name, ["package.xml"]),
12+
("share/" + package_name + "/models", ["models/wolfgang_walk_ppo.onnx"]),
13+
],
14+
install_requires=["setuptools"],
15+
zip_safe=True,
16+
maintainer="florian",
17+
maintainer_email="git@flova.de",
18+
description="TODO: Package description",
19+
license="TODO: License declaration",
20+
tests_require=["pytest"],
21+
entry_points={
22+
"console_scripts": [
23+
"walk = bitbots_rl_walk.walk:main",
24+
],
25+
},
26+
)

0 commit comments

Comments
 (0)