Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions source/isaaclab/isaaclab/envs/mdp/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,22 +203,36 @@ def __init__(self, cfg: EventTermCfg, env: ManagerBasedEnv):
f" with type: '{type(self.asset)}'."
)

# check if the physics backend supports material property randomization
self._supported = hasattr(self.asset.root_view, "get_material_properties")
if not self._supported:

logging.getLogger(__name__).warning(
"randomize_rigid_body_material: skipping because the current physics backend"
" does not support material property randomization."
)
self.num_shapes_per_body = None
return

# obtain number of shapes per body (needed for indexing the material properties correctly)
# note: this is a workaround since the Articulation does not provide a direct way to obtain the number of shapes
# per body. We use the physics simulation view to obtain the number of shapes per body.
if isinstance(self.asset, BaseArticulation) and self.asset_cfg.body_ids != slice(None):
self.num_shapes_per_body = []
for link_path in self.asset.root_view.link_paths[0]:
link_physx_view = self.asset._physics_sim_view.create_rigid_body_view(link_path) # type: ignore
self.num_shapes_per_body.append(link_physx_view.max_shapes)
# ensure the parsing is correct
num_shapes = sum(self.num_shapes_per_body)
expected_shapes = self.asset.root_view.max_shapes
if num_shapes != expected_shapes:
raise ValueError(
"Randomization term 'randomize_rigid_body_material' failed to parse the number of shapes per body."
f" Expected total shapes: {expected_shapes}, but got: {num_shapes}."
)
# check if the asset directly provides num_shapes_per_body (e.g. future backends that expose it)
if hasattr(self.asset, "num_shapes_per_body"):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Medium] Missing validation: The PhysX path (lines 224-233) validates sum(num_shapes_per_body) == max_shapes. This Newton path trusts the backend property blindly. If it's ever wrong, materials will be assigned to incorrect shape indices.

Suggestion — add the same sanity check:

if hasattr(self.asset, "num_shapes_per_body"):
    self.num_shapes_per_body = self.asset.num_shapes_per_body
    # validate against max_shapes (same check as PhysX path)
    num_shapes = sum(self.num_shapes_per_body)
    expected_shapes = self.asset.root_view.max_shapes
    if num_shapes != expected_shapes:
        raise ValueError(
            "Randomization term 'randomize_rigid_body_material' failed to parse the number of shapes"
            f" per body. Expected total shapes: {expected_shapes}, but got: {num_shapes}."
        )

self.num_shapes_per_body = self.asset.num_shapes_per_body
else:
# PhysX workaround: use the simulation view to obtain the number of shapes per body
self.num_shapes_per_body = []
for link_path in self.asset.root_view.link_paths[0]:
link_physx_view = self.asset._physics_sim_view.create_rigid_body_view(link_path) # type: ignore
self.num_shapes_per_body.append(link_physx_view.max_shapes)
# ensure the parsing is correct
num_shapes = sum(self.num_shapes_per_body)
expected_shapes = self.asset.root_view.max_shapes
if num_shapes != expected_shapes:
raise ValueError(
"Randomization term 'randomize_rigid_body_material' failed to parse the number of shapes"
f" per body. Expected total shapes: {expected_shapes}, but got: {num_shapes}."
)
else:
# in this case, we don't need to do special indexing
self.num_shapes_per_body = None
Expand Down Expand Up @@ -252,6 +266,10 @@ def __call__(
asset_cfg: SceneEntityCfg,
make_consistent: bool = False,
):
# skip if backend doesn't support material randomization
if not self._supported:
return

# resolve environment ids
if env_ids is None:
env_ids = torch.arange(env.scene.num_envs, device="cpu", dtype=torch.int32)
Expand Down