|
10 | 10 | from typing import TYPE_CHECKING |
11 | 11 |
|
12 | 12 | from isaaclab_arena.relations.loss_primitives import ( |
| 13 | + interval_overlap_axis_loss, |
13 | 14 | linear_band_loss, |
14 | 15 | single_boundary_linear_loss, |
15 | 16 | single_point_linear_loss, |
16 | 17 | ) |
17 | 18 | from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox |
18 | 19 |
|
19 | 20 | if TYPE_CHECKING: |
20 | | - from isaaclab_arena.relations.relations import AtPosition, NextTo, On, Relation |
| 21 | + from isaaclab_arena.relations.relations import AtPosition, NextTo, On, Relation, NoCollision |
21 | 22 |
|
22 | 23 | from isaaclab_arena.relations.relations import Side |
23 | 24 |
|
@@ -305,6 +306,84 @@ def compute_loss( |
305 | 306 | return relation.relation_loss_weight * total_loss |
306 | 307 |
|
307 | 308 |
|
| 309 | +class NoCollisionLossStrategy(RelationLossStrategy): |
| 310 | + """Loss strategy for NoCollision relations. |
| 311 | +
|
| 312 | + Computes loss based on: |
| 313 | + 1. X overlap: zero when child and parent are separated along X; else overlap length |
| 314 | + 2. Y overlap: zero when separated along Y; else overlap length |
| 315 | + 3. Z overlap: zero when separated along Z; else overlap length |
| 316 | + 4. Volume loss: slope * (overlap_x * overlap_y * overlap_z) |
| 317 | + """ |
| 318 | + |
| 319 | + def __init__(self, slope: float = 10.0, debug: bool = False): |
| 320 | + """ |
| 321 | + Args: |
| 322 | + slope: Gradient magnitude for overlap volume loss (default: 10.0). |
| 323 | + Loss scales with slope times overlap volume. |
| 324 | + debug: If True, print detailed loss component breakdown. |
| 325 | + """ |
| 326 | + self.slope = slope |
| 327 | + self.debug = debug |
| 328 | + |
| 329 | + def compute_loss( |
| 330 | + self, |
| 331 | + relation: "NoCollision", |
| 332 | + child_pos: torch.Tensor, |
| 333 | + child_bbox: AxisAlignedBoundingBox, |
| 334 | + parent_world_bbox: AxisAlignedBoundingBox, |
| 335 | + ) -> torch.Tensor: |
| 336 | + """Compute loss for NoCollision relation. |
| 337 | +
|
| 338 | + Args: |
| 339 | + relation: NoCollision relation with relation_loss_weight. |
| 340 | + child_pos: Child object position tensor (x, y, z) in world coords. |
| 341 | + child_bbox: Child object local bounding box. |
| 342 | + parent_world_bbox: Parent bounding box in world coordinates. |
| 343 | +
|
| 344 | + Returns: |
| 345 | + Weighted loss tensor. |
| 346 | + """ |
| 347 | + # Parent world extents from the world bounding box, expanded by clearance_m |
| 348 | + c = relation.clearance_m |
| 349 | + parent_x_min = parent_world_bbox.min_point[0] - c |
| 350 | + parent_x_max = parent_world_bbox.max_point[0] + c |
| 351 | + parent_y_min = parent_world_bbox.min_point[1] - c |
| 352 | + parent_y_max = parent_world_bbox.max_point[1] + c |
| 353 | + parent_z_min = parent_world_bbox.min_point[2] - c |
| 354 | + parent_z_max = parent_world_bbox.max_point[2] + c |
| 355 | + |
| 356 | + # Child world extents |
| 357 | + child_world_min = child_pos + torch.tensor(child_bbox.min_point, dtype=child_pos.dtype, device=child_pos.device) |
| 358 | + child_world_max = child_pos + torch.tensor(child_bbox.max_point, dtype=child_pos.dtype, device=child_pos.device) |
| 359 | + |
| 360 | + # 1. Per-axis overlap: zero when separated; else overlap length (default slope 1.0 gives length in m) |
| 361 | + overlap_x = interval_overlap_axis_loss(child_world_min[0], child_world_max[0], parent_x_min, parent_x_max) |
| 362 | + overlap_y = interval_overlap_axis_loss(child_world_min[1], child_world_max[1], parent_y_min, parent_y_max) |
| 363 | + overlap_z = interval_overlap_axis_loss(child_world_min[2], child_world_max[2], parent_z_min, parent_z_max) |
| 364 | + |
| 365 | + # 2. Volume loss: slope * product of per-axis overlap lengths (overlap volume when slope 1.0) |
| 366 | + overlap_volume = overlap_x * overlap_y * overlap_z |
| 367 | + total_loss = self.slope * overlap_volume |
| 368 | + |
| 369 | + if self.debug: |
| 370 | + print( |
| 371 | + f" [NoCollision] X: overlap={overlap_x.item():.6f} (child_x=[{child_world_min[0].item():.4f}," |
| 372 | + f" {child_world_max[0].item():.4f}], parent_x=[{parent_x_min:.4f}, {parent_x_max:.4f}])" |
| 373 | + ) |
| 374 | + print( |
| 375 | + f" [NoCollision] Y: overlap={overlap_y.item():.6f} (child_y=[{child_world_min[1].item():.4f}," |
| 376 | + f" {child_world_max[1].item():.4f}], parent_y=[{parent_y_min:.4f}, {parent_y_max:.4f}])" |
| 377 | + ) |
| 378 | + print( |
| 379 | + f" [NoCollision] Z: overlap={overlap_z.item():.6f} (child_z=[{child_world_min[2].item():.4f}," |
| 380 | + f" {child_world_max[2].item():.4f}], parent_z=[{parent_z_min:.4f}, {parent_z_max:.4f}])" |
| 381 | + ) |
| 382 | + print(f" [NoCollision] volume={overlap_volume.item():.6f}, loss={total_loss.item():.6f}") |
| 383 | + |
| 384 | + return relation.relation_loss_weight * total_loss |
| 385 | + |
| 386 | + |
308 | 387 | class AtPositionLossStrategy(UnaryRelationLossStrategy): |
309 | 388 | """Loss strategy for AtPosition relations. |
310 | 389 |
|
|
0 commit comments