1414# ==============================================================================
1515"""Deploy an MJX policy in ONNX format to C MuJoCo and play with it."""
1616
17+ import os
1718from pathlib import Path
1819from typing import Callable , NamedTuple
1920
2021import numpy as np
2122import onnx
2223import onnxruntime as rt
24+ import yaml
25+ from ament_index_python import get_package_share_directory
2326from rclpy .node import Node
2427from rclpy .qos import QoSProfile
2528from rclpy .subscription import Subscription
@@ -42,29 +45,14 @@ class SubscriptionParam(NamedTuple):
4245 callback : Callable
4346 qos_profile : int | QoSProfile
4447
45- def __init__ (self , path_to_model ):
46- self ._onnx_model_path = Path (path_to_model )
47- model_name = self ._onnx_model_path .stem
48- super ().__init__ (f"{ model_name } " )
49-
50- # Load the ONNX model
51- self ._onnx_session = rt .InferenceSession (self ._onnx_model_path , providers = ["CPUExecutionProvider" ])
52- self ._onnx_model = onnx .load (self ._onnx_model_path )
53-
54- self ._onnx_input_name = []
55- for inp in self ._onnx_model .graph .input :
56- self ._onnx_input_name .append (inp )
57-
58- self ._onnx_output_name = []
59- for out in self ._onnx_model .graph .output :
60- self ._onnx_output_name .append (out )
61-
62- self ._timer_phase_config = None # Should be implemented in the subclass
63-
64- self ._config = False
65-
48+ def __init__ (self , config_path : str ):
49+ self ._config = self ._load_config (config_path )
6650 self ._obs = None # should be defined in subclass
67- self ._timer_phase_config = None # Should be defined in subclass
51+ self ._phase_handler = None # shoul be defined in subclass
52+
53+ def _load_config (self , path : str ):
54+ with open (path ) as f :
55+ return yaml .safe_load (f )
6856
6957 # TODO: fix
7058 def _timer_callback (self ):
@@ -85,7 +73,7 @@ def _timer_callback(self):
8573
8674 self ._timer_phase_config .set_obs_phase (
8775 np .array (
88- [np .cos (self ._timer_phase_config .get_phase ()), np .sin (self ._timer_phase_config .get_phase ())],
76+ [np .cos (self ._phase_handler .get_phase ()), np .sin (self ._phase_handler .get_phase ())],
8977 dtype = np .float32 ,
9078 ).flatten ()
9179 )
@@ -97,11 +85,31 @@ def _timer_callback(self):
9785
9886 self .publisher (onnx_pred )
9987
100- phase_tp1 = self ._timer_phase_config .get_phase () + self ._timer_phase_config .get_phase_dt ()
101- self ._timer_phase_config .set_phase (np .fmod (phase_tp1 + np .pi , 2 * np .pi ) - np .pi )
88+ phase_tp1 = self ._phase_handler .get_phase () + self ._phase_handler .get_phase_dt ()
89+ self ._phase_handler .set_phase (np .fmod (phase_tp1 + np .pi , 2 * np .pi ) - np .pi )
10290 else :
10391 raise ConfigError ("Configuration is missing! Try to run self.config() in init." )
10492
93+ def load_model (self , model ):
94+ path_to_model = os .path .join (get_package_share_directory ("bitbots_rl_walk" ), "models" , model )
95+
96+ self ._onnx_model_path = Path (path_to_model )
97+ model_name = self ._onnx_model_path .stem
98+
99+ super ().__init__ (f"{ model_name } " )
100+
101+ # Load the ONNX model
102+ self ._onnx_session = rt .InferenceSession (self ._onnx_model_path , providers = ["CPUExecutionProvider" ])
103+ self ._onnx_model = onnx .load (self ._onnx_model_path )
104+
105+ self ._onnx_input_name = []
106+ for inp in self ._onnx_model .graph .input :
107+ self ._onnx_input_name .append (inp )
108+
109+ self ._onnx_output_name = []
110+ for out in self ._onnx_model .graph .output :
111+ self ._onnx_output_name .append (out )
112+
105113 def config (self ):
106114 self ._timer = self .create_timer (self ._timer_phase_config .get_control_dt (), self ._timer_callback )
107115 self .load_phase ()
0 commit comments