|
12 | 12 | import torch |
13 | 13 | import warp as wp |
14 | 14 |
|
15 | | -from pxr import Usd |
| 15 | +from pxr import Usd, UsdGeom |
16 | 16 |
|
17 | 17 | import isaaclab.sim as sim_utils |
18 | 18 | from isaaclab.app.settings_manager import SettingsManager |
@@ -219,14 +219,68 @@ def get_world_poses(self, indices=None): |
219 | 219 | return positions_wp, orientations_wp |
220 | 220 |
|
221 | 221 | # ------------------------------------------------------------------ |
222 | | - # Local poses — USD fallback (Fabric only accelerates world poses) |
| 222 | + # Local poses — computed from Fabric world poses when Fabric is active |
223 | 223 | # ------------------------------------------------------------------ |
224 | 224 |
|
225 | 225 | def set_local_poses(self, translations=None, orientations=None, indices=None): |
226 | | - self._usd_view.set_local_poses(translations, orientations, indices) |
| 226 | + if not self._use_fabric or not self._fabric_initialized or not self._fabric_usd_sync_done: |
| 227 | + self._usd_view.set_local_poses(translations, orientations, indices) |
| 228 | + if self._use_fabric and self._fabric_initialized: |
| 229 | + # After writing local to USD, recompute Fabric world matrices |
| 230 | + self._fabric_hierarchy.update_world_xforms() |
| 231 | + self._prepare_for_reuse() |
| 232 | + return |
| 233 | + |
| 234 | + # Fabric path: compute child world = parent_world * local, then write to Fabric |
| 235 | + import torch |
| 236 | + |
| 237 | + indices_wp = self._resolve_indices_wp(indices) |
| 238 | + count = indices_wp.shape[0] |
| 239 | + indices_list = wp.to_torch(indices_wp).long().tolist() |
| 240 | + |
| 241 | + parent_pos, parent_ori = self._get_parent_world_poses(indices_list) |
| 242 | + |
| 243 | + if translations is not None: |
| 244 | + local_pos = wp.to_torch(_to_float32_2d(translations)) |
| 245 | + else: |
| 246 | + local_pos = torch.zeros((count, 3), dtype=torch.float32, device=self._device) |
| 247 | + |
| 248 | + if orientations is not None: |
| 249 | + local_ori = wp.to_torch(_to_float32_2d(orientations)) |
| 250 | + else: |
| 251 | + local_ori = torch.tensor([[0.0, 0.0, 0.0, 1.0]] * count, dtype=torch.float32, device=self._device) |
| 252 | + |
| 253 | + child_pos, child_ori = self._compose_parent_local(parent_pos, parent_ori, local_pos, local_ori) |
| 254 | + |
| 255 | + self.set_world_poses( |
| 256 | + wp.from_torch(child_pos.contiguous()), |
| 257 | + wp.from_torch(child_ori.contiguous()), |
| 258 | + indices, |
| 259 | + ) |
227 | 260 |
|
228 | 261 | def get_local_poses(self, indices=None): |
229 | | - return self._usd_view.get_local_poses(indices) |
| 262 | + if not self._use_fabric or not self._fabric_initialized or not self._fabric_usd_sync_done: |
| 263 | + return self._usd_view.get_local_poses(indices) |
| 264 | + |
| 265 | + # Fabric path: local = inv(parent_world) * child_world |
| 266 | + import torch |
| 267 | + |
| 268 | + indices_wp = self._resolve_indices_wp(indices) |
| 269 | + count = indices_wp.shape[0] |
| 270 | + indices_list = wp.to_torch(indices_wp).long().tolist() |
| 271 | + |
| 272 | + child_pos_wp, child_ori_wp = self.get_world_poses(indices) |
| 273 | + child_pos = wp.to_torch(child_pos_wp) |
| 274 | + child_ori = wp.to_torch(child_ori_wp) |
| 275 | + |
| 276 | + parent_pos, parent_ori = self._get_parent_world_poses(indices_list) |
| 277 | + |
| 278 | + local_pos, local_ori = self._invert_parent_compose(parent_pos, parent_ori, child_pos, child_ori) |
| 279 | + |
| 280 | + return ( |
| 281 | + wp.from_torch(local_pos.contiguous()), |
| 282 | + wp.from_torch(local_ori.contiguous()), |
| 283 | + ) |
230 | 284 |
|
231 | 285 | # ------------------------------------------------------------------ |
232 | 286 | # Scales — Fabric-accelerated or USD fallback |
@@ -348,6 +402,113 @@ def _rebuild_fabric_arrays(self) -> None: |
348 | 402 |
|
349 | 403 | self._fabric_world_matrices = wp.fabricarray(self._fabric_selection, "omni:fabric:worldMatrix") |
350 | 404 |
|
| 405 | + # ------------------------------------------------------------------ |
| 406 | + # Internal — Local/world pose helpers |
| 407 | + # ------------------------------------------------------------------ |
| 408 | + |
| 409 | + def _get_parent_world_poses(self, indices_list: list[int]) -> tuple: |
| 410 | + """Read parent world poses from USD for given child indices. |
| 411 | +
|
| 412 | + Parents are not tracked in Fabric, so we read from USD XformCache. |
| 413 | + Returns torch tensors ``(parent_pos[N,3], parent_ori[N,4])`` on self._device. |
| 414 | + Orientation is ``(x, y, z, w)`` to match the convention used by FabricFrameView. |
| 415 | + """ |
| 416 | + import torch |
| 417 | + |
| 418 | + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) |
| 419 | + stage = self._usd_view._prims[0].GetStage() |
| 420 | + |
| 421 | + parent_positions = [] |
| 422 | + parent_orientations = [] |
| 423 | + for idx in indices_list: |
| 424 | + child_path = self.prim_paths[idx] |
| 425 | + parent_path = child_path.rsplit("/", 1)[0] |
| 426 | + parent_prim = stage.GetPrimAtPath(parent_path) |
| 427 | + if parent_prim and parent_prim.IsValid(): |
| 428 | + parent_tf = xform_cache.GetLocalToWorldTransform(parent_prim) |
| 429 | + parent_tf.Orthonormalize() |
| 430 | + t = parent_tf.ExtractTranslation() |
| 431 | + q = parent_tf.ExtractRotationQuat() |
| 432 | + img = q.GetImaginary() |
| 433 | + real = q.GetReal() |
| 434 | + parent_positions.append([float(t[0]), float(t[1]), float(t[2])]) |
| 435 | + # (x, y, z, w) convention |
| 436 | + parent_orientations.append([float(img[0]), float(img[1]), float(img[2]), float(real)]) |
| 437 | + else: |
| 438 | + # No parent — identity |
| 439 | + parent_positions.append([0.0, 0.0, 0.0]) |
| 440 | + parent_orientations.append([0.0, 0.0, 0.0, 1.0]) |
| 441 | + |
| 442 | + return ( |
| 443 | + torch.tensor(parent_positions, dtype=torch.float32, device=self._device), |
| 444 | + torch.tensor(parent_orientations, dtype=torch.float32, device=self._device), |
| 445 | + ) |
| 446 | + |
| 447 | + @staticmethod |
| 448 | + def _compose_parent_local( |
| 449 | + parent_pos: "torch.Tensor", |
| 450 | + parent_ori: "torch.Tensor", |
| 451 | + local_pos: "torch.Tensor", |
| 452 | + local_ori: "torch.Tensor", |
| 453 | + ) -> tuple: |
| 454 | + """Compute child_world = parent_world * local. |
| 455 | +
|
| 456 | + Orientations are ``(x, y, z, w)``. |
| 457 | + Returns ``(child_world_pos, child_world_ori)``. |
| 458 | + """ |
| 459 | + child_pos = parent_pos + FabricFrameView._quat_rotate(parent_ori, local_pos) |
| 460 | + child_ori = FabricFrameView._quat_mul(parent_ori, local_ori) |
| 461 | + return child_pos, child_ori |
| 462 | + |
| 463 | + @staticmethod |
| 464 | + def _invert_parent_compose( |
| 465 | + parent_pos: "torch.Tensor", |
| 466 | + parent_ori: "torch.Tensor", |
| 467 | + child_pos: "torch.Tensor", |
| 468 | + child_ori: "torch.Tensor", |
| 469 | + ) -> tuple: |
| 470 | + """Compute local = inv(parent_world) * child_world. |
| 471 | +
|
| 472 | + Orientations are ``(x, y, z, w)``. |
| 473 | + Returns ``(local_pos, local_ori)``. |
| 474 | + """ |
| 475 | + parent_ori_inv = FabricFrameView._quat_conjugate(parent_ori) |
| 476 | + local_pos = FabricFrameView._quat_rotate(parent_ori_inv, child_pos - parent_pos) |
| 477 | + local_ori = FabricFrameView._quat_mul(parent_ori_inv, child_ori) |
| 478 | + return local_pos, local_ori |
| 479 | + |
| 480 | + @staticmethod |
| 481 | + def _quat_mul(q1: "torch.Tensor", q2: "torch.Tensor") -> "torch.Tensor": |
| 482 | + """Quaternion multiply (x,y,z,w) convention.""" |
| 483 | + x1, y1, z1, w1 = q1[..., 0:1], q1[..., 1:2], q1[..., 2:3], q1[..., 3:4] |
| 484 | + x2, y2, z2, w2 = q2[..., 0:1], q2[..., 1:2], q2[..., 2:3], q2[..., 3:4] |
| 485 | + import torch |
| 486 | + |
| 487 | + return torch.cat( |
| 488 | + [ |
| 489 | + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, |
| 490 | + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, |
| 491 | + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, |
| 492 | + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, |
| 493 | + ], |
| 494 | + dim=-1, |
| 495 | + ) |
| 496 | + |
| 497 | + @staticmethod |
| 498 | + def _quat_conjugate(q: "torch.Tensor") -> "torch.Tensor": |
| 499 | + """Quaternion conjugate (x,y,z,w) convention.""" |
| 500 | + return q * q.new_tensor([-1, -1, -1, 1]) |
| 501 | + |
| 502 | + @staticmethod |
| 503 | + def _quat_rotate(q: "torch.Tensor", v: "torch.Tensor") -> "torch.Tensor": |
| 504 | + """Rotate vector v by quaternion q. (x,y,z,w) convention.""" |
| 505 | + import torch |
| 506 | + |
| 507 | + q_xyz = q[..., :3] |
| 508 | + q_w = q[..., 3:4] |
| 509 | + t = 2.0 * torch.linalg.cross(q_xyz, v) |
| 510 | + return v + q_w * t + torch.linalg.cross(q_xyz, t) |
| 511 | + |
351 | 512 | # ------------------------------------------------------------------ |
352 | 513 | # Internal — Fabric initialization |
353 | 514 | # ------------------------------------------------------------------ |
|
0 commit comments