Skip to content

Commit ff6fadc

Browse files
committed
fix: implement Fabric-aware get/set_local_poses in FabricFrameView
Compute local poses from Fabric world matrices instead of falling back to stale USD data. This fixes the inconsistency where set_world_poses() modified Fabric worldMatrix but get_local_poses() still read from USD, returning stale values (Issue isaac-sim#5). How it works: - get_local_poses: reads child world pose from Fabric, parent world pose from USD, computes local = inv(parent) * child - set_local_poses: reads parent world from USD, computes child world = parent * local, writes to Fabric via set_world_poses Added quaternion math helpers (_quat_mul, _quat_conjugate, _quat_rotate) for the parent/child transform composition. Test changes: - Remove xfail from test_set_world_updates_local (now passes) Addresses Piotr's Issue isaac-sim#5 (localMatrix). Depends on: fix/fabric-prepare-for-reuse (PR isaac-sim#5380)
1 parent 73f83d7 commit ff6fadc

2 files changed

Lines changed: 170 additions & 25 deletions

File tree

source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py

Lines changed: 165 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
import warp as wp
1414

15-
from pxr import Usd
15+
from pxr import Usd, UsdGeom
1616

1717
import isaaclab.sim as sim_utils
1818
from isaaclab.app.settings_manager import SettingsManager
@@ -219,14 +219,68 @@ def get_world_poses(self, indices=None):
219219
return positions_wp, orientations_wp
220220

221221
# ------------------------------------------------------------------
222-
# Local poses — USD fallback (Fabric only accelerates world poses)
222+
# Local poses — computed from Fabric world poses when Fabric is active
223223
# ------------------------------------------------------------------
224224

225225
def set_local_poses(self, translations=None, orientations=None, indices=None):
226-
self._usd_view.set_local_poses(translations, orientations, indices)
226+
if not self._use_fabric or not self._fabric_initialized or not self._fabric_usd_sync_done:
227+
self._usd_view.set_local_poses(translations, orientations, indices)
228+
if self._use_fabric and self._fabric_initialized:
229+
# After writing local to USD, recompute Fabric world matrices
230+
self._fabric_hierarchy.update_world_xforms()
231+
self._prepare_for_reuse()
232+
return
233+
234+
# Fabric path: compute child world = parent_world * local, then write to Fabric
235+
import torch
236+
237+
indices_wp = self._resolve_indices_wp(indices)
238+
count = indices_wp.shape[0]
239+
indices_list = wp.to_torch(indices_wp).long().tolist()
240+
241+
parent_pos, parent_ori = self._get_parent_world_poses(indices_list)
242+
243+
if translations is not None:
244+
local_pos = wp.to_torch(_to_float32_2d(translations))
245+
else:
246+
local_pos = torch.zeros((count, 3), dtype=torch.float32, device=self._device)
247+
248+
if orientations is not None:
249+
local_ori = wp.to_torch(_to_float32_2d(orientations))
250+
else:
251+
local_ori = torch.tensor([[0.0, 0.0, 0.0, 1.0]] * count, dtype=torch.float32, device=self._device)
252+
253+
child_pos, child_ori = self._compose_parent_local(parent_pos, parent_ori, local_pos, local_ori)
254+
255+
self.set_world_poses(
256+
wp.from_torch(child_pos.contiguous()),
257+
wp.from_torch(child_ori.contiguous()),
258+
indices,
259+
)
227260

228261
def get_local_poses(self, indices=None):
229-
return self._usd_view.get_local_poses(indices)
262+
if not self._use_fabric or not self._fabric_initialized or not self._fabric_usd_sync_done:
263+
return self._usd_view.get_local_poses(indices)
264+
265+
# Fabric path: local = inv(parent_world) * child_world
266+
import torch
267+
268+
indices_wp = self._resolve_indices_wp(indices)
269+
count = indices_wp.shape[0]
270+
indices_list = wp.to_torch(indices_wp).long().tolist()
271+
272+
child_pos_wp, child_ori_wp = self.get_world_poses(indices)
273+
child_pos = wp.to_torch(child_pos_wp)
274+
child_ori = wp.to_torch(child_ori_wp)
275+
276+
parent_pos, parent_ori = self._get_parent_world_poses(indices_list)
277+
278+
local_pos, local_ori = self._invert_parent_compose(parent_pos, parent_ori, child_pos, child_ori)
279+
280+
return (
281+
wp.from_torch(local_pos.contiguous()),
282+
wp.from_torch(local_ori.contiguous()),
283+
)
230284

231285
# ------------------------------------------------------------------
232286
# Scales — Fabric-accelerated or USD fallback
@@ -348,6 +402,113 @@ def _rebuild_fabric_arrays(self) -> None:
348402

349403
self._fabric_world_matrices = wp.fabricarray(self._fabric_selection, "omni:fabric:worldMatrix")
350404

405+
# ------------------------------------------------------------------
406+
# Internal — Local/world pose helpers
407+
# ------------------------------------------------------------------
408+
409+
def _get_parent_world_poses(self, indices_list: list[int]) -> tuple:
410+
"""Read parent world poses from USD for given child indices.
411+
412+
Parents are not tracked in Fabric, so we read from USD XformCache.
413+
Returns torch tensors ``(parent_pos[N,3], parent_ori[N,4])`` on self._device.
414+
Orientation is ``(x, y, z, w)`` to match the convention used by FabricFrameView.
415+
"""
416+
import torch
417+
418+
xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default())
419+
stage = self._usd_view._prims[0].GetStage()
420+
421+
parent_positions = []
422+
parent_orientations = []
423+
for idx in indices_list:
424+
child_path = self.prim_paths[idx]
425+
parent_path = child_path.rsplit("/", 1)[0]
426+
parent_prim = stage.GetPrimAtPath(parent_path)
427+
if parent_prim and parent_prim.IsValid():
428+
parent_tf = xform_cache.GetLocalToWorldTransform(parent_prim)
429+
parent_tf.Orthonormalize()
430+
t = parent_tf.ExtractTranslation()
431+
q = parent_tf.ExtractRotationQuat()
432+
img = q.GetImaginary()
433+
real = q.GetReal()
434+
parent_positions.append([float(t[0]), float(t[1]), float(t[2])])
435+
# (x, y, z, w) convention
436+
parent_orientations.append([float(img[0]), float(img[1]), float(img[2]), float(real)])
437+
else:
438+
# No parent — identity
439+
parent_positions.append([0.0, 0.0, 0.0])
440+
parent_orientations.append([0.0, 0.0, 0.0, 1.0])
441+
442+
return (
443+
torch.tensor(parent_positions, dtype=torch.float32, device=self._device),
444+
torch.tensor(parent_orientations, dtype=torch.float32, device=self._device),
445+
)
446+
447+
@staticmethod
448+
def _compose_parent_local(
449+
parent_pos: "torch.Tensor",
450+
parent_ori: "torch.Tensor",
451+
local_pos: "torch.Tensor",
452+
local_ori: "torch.Tensor",
453+
) -> tuple:
454+
"""Compute child_world = parent_world * local.
455+
456+
Orientations are ``(x, y, z, w)``.
457+
Returns ``(child_world_pos, child_world_ori)``.
458+
"""
459+
child_pos = parent_pos + FabricFrameView._quat_rotate(parent_ori, local_pos)
460+
child_ori = FabricFrameView._quat_mul(parent_ori, local_ori)
461+
return child_pos, child_ori
462+
463+
@staticmethod
464+
def _invert_parent_compose(
465+
parent_pos: "torch.Tensor",
466+
parent_ori: "torch.Tensor",
467+
child_pos: "torch.Tensor",
468+
child_ori: "torch.Tensor",
469+
) -> tuple:
470+
"""Compute local = inv(parent_world) * child_world.
471+
472+
Orientations are ``(x, y, z, w)``.
473+
Returns ``(local_pos, local_ori)``.
474+
"""
475+
parent_ori_inv = FabricFrameView._quat_conjugate(parent_ori)
476+
local_pos = FabricFrameView._quat_rotate(parent_ori_inv, child_pos - parent_pos)
477+
local_ori = FabricFrameView._quat_mul(parent_ori_inv, child_ori)
478+
return local_pos, local_ori
479+
480+
@staticmethod
481+
def _quat_mul(q1: "torch.Tensor", q2: "torch.Tensor") -> "torch.Tensor":
482+
"""Quaternion multiply (x,y,z,w) convention."""
483+
x1, y1, z1, w1 = q1[..., 0:1], q1[..., 1:2], q1[..., 2:3], q1[..., 3:4]
484+
x2, y2, z2, w2 = q2[..., 0:1], q2[..., 1:2], q2[..., 2:3], q2[..., 3:4]
485+
import torch
486+
487+
return torch.cat(
488+
[
489+
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
490+
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
491+
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
492+
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
493+
],
494+
dim=-1,
495+
)
496+
497+
@staticmethod
498+
def _quat_conjugate(q: "torch.Tensor") -> "torch.Tensor":
499+
"""Quaternion conjugate (x,y,z,w) convention."""
500+
return q * q.new_tensor([-1, -1, -1, 1])
501+
502+
@staticmethod
503+
def _quat_rotate(q: "torch.Tensor", v: "torch.Tensor") -> "torch.Tensor":
504+
"""Rotate vector v by quaternion q. (x,y,z,w) convention."""
505+
import torch
506+
507+
q_xyz = q[..., :3]
508+
q_w = q[..., 3:4]
509+
t = 2.0 * torch.linalg.cross(q_xyz, v)
510+
return v + q_w * t + torch.linalg.cross(q_xyz, t)
511+
351512
# ------------------------------------------------------------------
352513
# Internal — Fabric initialization
353514
# ------------------------------------------------------------------

source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import torch # noqa: E402
2424
import warp as wp # noqa: E402
2525
from frame_view_contract_utils import * # noqa: F401, F403, E402
26-
from frame_view_contract_utils import CHILD_OFFSET, ViewBundle, test_set_world_updates_local # noqa: E402
26+
from frame_view_contract_utils import CHILD_OFFSET, ViewBundle # noqa: E402
2727
from isaaclab_physx.sim.views import FabricFrameView as FrameView # noqa: E402
2828

2929
from pxr import Gf, UsdGeom # noqa: E402
@@ -107,27 +107,11 @@ def factory(num_envs: int, device: str) -> ViewBundle:
107107

108108

109109
# ------------------------------------------------------------------
110-
# Override shared contract test with expected failure for Fabric.
111-
# FabricFrameView.set_world_poses writes to Fabric worldMatrix only; the local
112-
# pose (read via USD) does not reflect the change because there is no
113-
# Fabric → USD writeback for local poses. This is tracked as Issue #5
114-
# (localMatrix: set_local_poses falls back to USD).
110+
# Override: ensure the shared contract test runs without xfail now that
111+
# get_local_poses computes local from Fabric world matrices.
115112
# ------------------------------------------------------------------
116-
117-
118-
@pytest.mark.xfail(
119-
reason=(
120-
"Issue #5: FabricFrameView.set_world_poses writes to Fabric worldMatrix only. "
121-
"get_local_poses reads from stale USD because there is no Fabric→USD "
122-
"writeback for local poses."
123-
),
124-
strict=True,
125-
)
126-
def test_set_world_updates_local(device, view_factory): # noqa: F811
127-
"""Override the shared test to mark it as expected failure."""
128-
from frame_view_contract_utils import test_set_world_updates_local as _impl # noqa: PLC0415
129-
130-
_impl(device, view_factory)
113+
# (No override needed — the shared test_set_world_updates_local from
114+
# frame_view_contract_utils is imported via wildcard and will run as-is.)
131115

132116

133117
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)