@@ -56,34 +56,89 @@ def inner_getter() -> torch.Tensor:
5656 context_manager .add_component (onnx_input )
5757
5858
59- def add_body_pos_and_quat (
59+ def add_body_pos (
6060 articulations : dict [str , Articulation ],
6161 context_manager : ContextManager ,
6262):
63- """Add body position and orientation inputs for all articulations.
63+ """Add body position inputs for all articulations.
6464
65- For each articulation, this function adds inputs for the position and quaternion
66- of each body in the world frame .
65+ For each articulation, this function adds inputs for the position
66+ of each body belonging to that articulation .
6767
6868 Args:
6969 articulations: Dictionary mapping object names to Articulation instances.
7070 context_manager: The context manager to add body pose inputs to.
7171 """
72- # Add inputs for all body positions and quaternions in world frame
72+
73+ # Add inputs for all body positions in world frame
7374 for obj_name , articulation in articulations .items ():
74- for i , body_name in enumerate (articulation .data .body_names ):
75+ for body_name in articulation .data .body_names :
76+ body_ids , _ = articulation .find_bodies (body_name )
77+ assert len (body_ids ) == 1 , (
78+ f"Body name { body_name } is not unique in articulation { obj_name } . "
79+ f"Found body IDs: { body_ids } "
80+ )
81+ body_idx = body_ids [0 ]
7582 pos_b_rt_w_in_w = Input (
7683 name = f"{ OBJ_PREFIX } .{ obj_name } .{ body_name } .pos_b_rt_w_in_w" ,
77- get_from_env_cb = lambda art = articulation , idx = i : art .data .body_pos_w [:, idx ],
84+ get_from_env_cb = lambda art = articulation , idx = body_idx : art .data .body_pos_w [:, idx ],
85+ )
86+ context_manager .add_component (pos_b_rt_w_in_w )
87+
88+
89+ def add_body_quat (
90+ articulations : dict [str , Articulation ],
91+ context_manager : ContextManager ,
92+ ):
93+ """Add body orientation inputs for all articulations.
94+
95+ For each articulation, this function adds inputs for the quaternion
96+ of each body belonging to that articulation.
97+
98+ Args:
99+ articulations: Dictionary mapping object names to Articulation instances.
100+ context_manager: The context manager to add body pose inputs to.
101+ """
102+ # Add inputs for all body quaternions in world frame
103+ for obj_name , articulation in articulations .items ():
104+ for body_name in articulation .data .body_names :
105+ body_ids , _ = articulation .find_bodies (body_name )
106+ assert len (body_ids ) == 1 , (
107+ f"Body name { body_name } is not unique in articulation { obj_name } . "
108+ f"Found body IDs: { body_ids } "
78109 )
110+ body_idx = body_ids [0 ]
79111 w_Q_b = Input (
80112 name = f"{ OBJ_PREFIX } .{ obj_name } .{ body_name } .w_Q_b" ,
81- get_from_env_cb = lambda art = articulation , idx = i : art .data .body_quat_w [:, idx ],
113+ get_from_env_cb = lambda art = articulation , idx = body_idx : art .data .body_quat_w [:, idx ],
82114 )
83- context_manager .add_component (pos_b_rt_w_in_w )
84115 context_manager .add_component (w_Q_b )
85116
86117
118+ def add_body_pos_and_quat (
119+ articulations : dict [str , Articulation ],
120+ context_manager : ContextManager ,
121+ ):
122+ """Add body position and orientation inputs for all articulations.
123+
124+ For each articulation, this function adds inputs for the position and quaternion
125+ of each body belonging to that articulation.
126+
127+ Args:
128+ articulations: Dictionary mapping object names to Articulation instances.
129+ context_manager: The context manager to add body pose inputs to.
130+ """
131+ # Add inputs for all body positions and quaternions in world frame
132+ add_body_pos (
133+ articulations = articulations ,
134+ context_manager = context_manager ,
135+ )
136+ add_body_quat (
137+ articulations = articulations ,
138+ context_manager = context_manager ,
139+ )
140+
141+
87142def add_base_vel (
88143 articulations : dict [str , Articulation ],
89144 context_manager : ContextManager ,
0 commit comments