Skip to content

Commit 1f77486

Browse files
committed
return tensor for all bbox
1 parent 11ba646 commit 1f77486

10 files changed

Lines changed: 155 additions & 161 deletions

File tree

isaaclab_arena/assets/object_set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def get_bounding_box(self) -> AxisAlignedBoundingBox:
8888
Returns the bounding box with the greatest z-extent among all objects in the set.
8989
This is a heuristic to avoid objects spawning inside their support surfaces.
9090
"""
91-
return max(self.objects, key=lambda obj: obj.get_bounding_box().size[2]).get_bounding_box()
91+
return max(self.objects, key=lambda obj: obj.get_bounding_box().size[0, 2].item()).get_bounding_box()
9292

9393
def get_contact_sensor_cfg(self, contact_against_prim_paths: list[str] | None = None) -> ContactSensorCfg:
9494
# We override this function from the parent class because in some assets, the rigid body

isaaclab_arena/relations/object_placer.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -143,24 +143,12 @@ def _get_init_bounds(self, anchor_object: Object | ObjectReference) -> AxisAlign
143143

144144
# Create bounds centered on anchor's world bounding box center
145145
anchor_bbox = anchor_object.get_world_bounding_box()
146-
center = anchor_bbox.center
147-
half_size = (
148-
self.params.init_bounds_size[0] / 2,
149-
self.params.init_bounds_size[1] / 2,
150-
self.params.init_bounds_size[2] / 2,
151-
)
146+
center = anchor_bbox.center[0] # (3,) tensor
147+
half_size = torch.tensor(self.params.init_bounds_size) / 2
152148

153149
return AxisAlignedBoundingBox(
154-
min_point=(
155-
center[0] - half_size[0],
156-
center[1] - half_size[1],
157-
center[2] - half_size[2],
158-
),
159-
max_point=(
160-
center[0] + half_size[0],
161-
center[1] + half_size[1],
162-
center[2] + half_size[2],
163-
),
150+
min_point=center - half_size,
151+
max_point=center + half_size,
164152
)
165153

166154
def _generate_initial_positions(
@@ -206,17 +194,17 @@ def _validate_on_relations(
206194
parent_world = parent.get_bounding_box().translated(positions[parent])
207195
# 1 & 2: Same as OnLossStrategy X/Y band (child's footprint within parent).
208196
if (
209-
child_world.min_point[0] < parent_world.min_point[0]
210-
or child_world.max_point[0] > parent_world.max_point[0]
211-
or child_world.min_point[1] < parent_world.min_point[1]
212-
or child_world.max_point[1] > parent_world.max_point[1]
197+
child_world.min_point[0, 0] < parent_world.min_point[0, 0]
198+
or child_world.max_point[0, 0] > parent_world.max_point[0, 0]
199+
or child_world.min_point[0, 1] < parent_world.min_point[0, 1]
200+
or child_world.max_point[0, 1] > parent_world.max_point[0, 1]
213201
):
214202
if self.params.verbose:
215203
print(f" On relation: '{obj.name}' XY outside parent (retrying)")
216204
return False
217205
# 3. Z: same as OnLossStrategy; child_bottom in (parent_top, parent_top+clearance_m], within on_relation_z_tolerance_m.
218-
parent_local_top_z: float = parent.get_bounding_box().max_point[2]
219-
child_local_bottom_z: float = obj.get_bounding_box().min_point[2]
206+
parent_local_top_z: float = parent.get_bounding_box().max_point[0, 2].item()
207+
child_local_bottom_z: float = obj.get_bounding_box().min_point[0, 2].item()
220208
parent_top_z = parent_local_top_z + positions[parent][2]
221209
clearance_m = rel.clearance_m
222210
child_bottom_z = child_local_bottom_z + positions[obj][2]
@@ -240,7 +228,7 @@ def _validate_no_overlap(
240228
a_world = a.get_bounding_box().translated(positions[a])
241229
b_world = b.get_bounding_box().translated(positions[b])
242230

243-
if a_world.overlaps(b_world, margin=self.params.min_separation_m):
231+
if a_world.overlaps(b_world, margin=self.params.min_separation_m).item():
244232
if self.params.verbose:
245233
print(f" Overlap between '{a.name}' and '{b.name}'")
246234
return False

isaaclab_arena/relations/relation_loss_strategies.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,12 @@ def compute_loss(
158158

159159
# Parent world extents from the world bounding box
160160
if cfg.direction == Direction.POSITIVE:
161-
parent_edge = parent_world_bbox.max_point[cfg.primary_axis]
162-
child_offset = child_bbox.min_point[cfg.primary_axis]
161+
parent_edge = parent_world_bbox.max_point[0, cfg.primary_axis]
162+
child_offset = child_bbox.min_point[0, cfg.primary_axis]
163163
penalty_side = "less"
164164
else:
165-
parent_edge = parent_world_bbox.min_point[cfg.primary_axis]
166-
child_offset = child_bbox.max_point[cfg.primary_axis]
165+
parent_edge = parent_world_bbox.min_point[0, cfg.primary_axis]
166+
child_offset = child_bbox.max_point[0, cfg.primary_axis]
167167
penalty_side = "greater"
168168

169169
# 1. Half-plane loss: child must be on correct side of parent edge
@@ -175,10 +175,10 @@ def compute_loss(
175175
)
176176

177177
# 2. Band position loss: child placed at target position within parent's perpendicular extent
178-
parent_band_min = parent_world_bbox.min_point[cfg.band_axis]
179-
parent_band_max = parent_world_bbox.max_point[cfg.band_axis]
180-
valid_band_min = parent_band_min - child_bbox.min_point[cfg.band_axis]
181-
valid_band_max = parent_band_max - child_bbox.max_point[cfg.band_axis]
178+
parent_band_min = parent_world_bbox.min_point[0, cfg.band_axis]
179+
parent_band_max = parent_world_bbox.max_point[0, cfg.band_axis]
180+
valid_band_min = parent_band_min - child_bbox.min_point[0, cfg.band_axis]
181+
valid_band_max = parent_band_max - child_bbox.max_point[0, cfg.band_axis]
182182
# Convert cross_position_ratio [-1, 1] to interpolation factor [0, 1]: -1 = min, 0 = center, 1 = max
183183
t = (relation.cross_position_ratio + 1.0) / 2.0
184184
target_band_pos = valid_band_min + t * (valid_band_max - valid_band_min)
@@ -199,19 +199,19 @@ def compute_loss(
199199
band_axis_name = cfg.band_axis.name
200200
print(
201201
f" [NextTo] {relation.side.value}: child_{axis_name.lower()}="
202-
f"{child_pos[cfg.primary_axis].item():.4f}, parent_edge={parent_edge:.4f},"
202+
f"{child_pos[cfg.primary_axis].item():.4f}, parent_edge={parent_edge.item():.4f},"
203203
f" loss={half_plane_loss.item():.6f}"
204204
)
205205
print(
206206
f" [NextTo] {band_axis_name} band: child_{band_axis_name.lower()}="
207-
f"{child_pos[cfg.band_axis].item():.4f}, target={target_band_pos:.4f}"
207+
f"{child_pos[cfg.band_axis].item():.4f}, target={target_band_pos.item():.4f}"
208208
f" (cross_position_ratio={relation.cross_position_ratio:.2f},"
209-
f" range=[{valid_band_min:.4f}, {valid_band_max:.4f}]),"
209+
f" range=[{valid_band_min.item():.4f}, {valid_band_max.item():.4f}]),"
210210
f" loss={band_loss.item():.6f}"
211211
)
212212
print(
213213
f" [NextTo] Distance: child_{axis_name.lower()}="
214-
f"{child_pos[cfg.primary_axis].item():.4f}, target={target_pos:.4f},"
214+
f"{child_pos[cfg.primary_axis].item():.4f}, target={target_pos.item():.4f},"
215215
f" loss={distance_loss.item():.6f}"
216216
)
217217

@@ -257,19 +257,19 @@ def compute_loss(
257257
Weighted loss tensor.
258258
"""
259259
# Parent world-space extents from the world bounding box
260-
parent_x_min = parent_world_bbox.min_point[0]
261-
parent_x_max = parent_world_bbox.max_point[0]
262-
parent_y_min = parent_world_bbox.min_point[1]
263-
parent_y_max = parent_world_bbox.max_point[1]
264-
parent_z_max = parent_world_bbox.max_point[2] # Top surface
260+
parent_x_min = parent_world_bbox.min_point[0, 0]
261+
parent_x_max = parent_world_bbox.max_point[0, 0]
262+
parent_y_min = parent_world_bbox.min_point[0, 1]
263+
parent_y_max = parent_world_bbox.max_point[0, 1]
264+
parent_z_max = parent_world_bbox.max_point[0, 2] # Top surface
265265

266266
# Compute valid position ranges such that child's entire footprint is within parent
267-
# Child left edge = child_pos[0] + child_bbox.min_point[0], must be >= parent_x_min
268-
# Child right edge = child_pos[0] + child_bbox.max_point[0], must be <= parent_x_max
269-
valid_x_min = parent_x_min - child_bbox.min_point[0] # child's left at parent's left
270-
valid_x_max = parent_x_max - child_bbox.max_point[0] # child's right at parent's right
271-
valid_y_min = parent_y_min - child_bbox.min_point[1]
272-
valid_y_max = parent_y_max - child_bbox.max_point[1]
267+
# Child left edge = child_pos[0] + child_bbox.min_point[0, 0], must be >= parent_x_min
268+
# Child right edge = child_pos[0] + child_bbox.max_point[0, 0], must be <= parent_x_max
269+
valid_x_min = parent_x_min - child_bbox.min_point[0, 0] # child's left at parent's left
270+
valid_x_max = parent_x_max - child_bbox.max_point[0, 0] # child's right at parent's right
271+
valid_y_min = parent_y_min - child_bbox.min_point[0, 1]
272+
valid_y_max = parent_y_max - child_bbox.max_point[0, 1]
273273

274274
# 1. X band loss: child's footprint entirely within parent's X extent
275275
x_band_loss = linear_band_loss(
@@ -288,19 +288,22 @@ def compute_loss(
288288
)
289289

290290
# 3. Z point loss: child bottom = parent top + clearance
291-
target_z = parent_z_max + relation.clearance_m - child_bbox.min_point[2]
291+
target_z = parent_z_max + relation.clearance_m - child_bbox.min_point[0, 2]
292292
z_loss = single_point_linear_loss(child_pos[2], target_z, slope=self.slope)
293293

294294
if self.debug:
295295
print(
296-
f" [On] X: child_pos={child_pos[0].item():.4f}, valid_range=[{valid_x_min:.4f},"
297-
f" {valid_x_max:.4f}], loss={x_band_loss.item():.6f}"
296+
f" [On] X: child_pos={child_pos[0].item():.4f}, valid_range=[{valid_x_min.item():.4f},"
297+
f" {valid_x_max.item():.4f}], loss={x_band_loss.item():.6f}"
298298
)
299299
print(
300-
f" [On] Y: child_pos={child_pos[1].item():.4f}, valid_range=[{valid_y_min:.4f},"
301-
f" {valid_y_max:.4f}], loss={y_band_loss.item():.6f}"
300+
f" [On] Y: child_pos={child_pos[1].item():.4f}, valid_range=[{valid_y_min.item():.4f},"
301+
f" {valid_y_max.item():.4f}], loss={y_band_loss.item():.6f}"
302+
)
303+
print(
304+
f" [On] Z: child_pos={child_pos[2].item():.4f}, target={target_z.item():.4f},"
305+
f" loss={z_loss.item():.6f}"
302306
)
303-
print(f" [On] Z: child_pos={child_pos[2].item():.4f}, target={target_z:.4f}, loss={z_loss.item():.6f}")
304307

305308
total_loss = x_band_loss + y_band_loss + z_loss
306309
return relation.relation_loss_weight * total_loss
@@ -346,16 +349,16 @@ def compute_loss(
346349
"""
347350
# Parent world extents from the world bounding box, expanded by clearance_m
348351
c = relation.clearance_m
349-
parent_x_min = parent_world_bbox.min_point[0] - c
350-
parent_x_max = parent_world_bbox.max_point[0] + c
351-
parent_y_min = parent_world_bbox.min_point[1] - c
352-
parent_y_max = parent_world_bbox.max_point[1] + c
353-
parent_z_min = parent_world_bbox.min_point[2] - c
354-
parent_z_max = parent_world_bbox.max_point[2] + c
352+
parent_x_min = parent_world_bbox.min_point[0, 0] - c
353+
parent_x_max = parent_world_bbox.max_point[0, 0] + c
354+
parent_y_min = parent_world_bbox.min_point[0, 1] - c
355+
parent_y_max = parent_world_bbox.max_point[0, 1] + c
356+
parent_z_min = parent_world_bbox.min_point[0, 2] - c
357+
parent_z_max = parent_world_bbox.max_point[0, 2] + c
355358

356359
# Child world extents
357-
child_world_min = child_pos + torch.tensor(child_bbox.min_point, dtype=child_pos.dtype, device=child_pos.device)
358-
child_world_max = child_pos + torch.tensor(child_bbox.max_point, dtype=child_pos.dtype, device=child_pos.device)
360+
child_world_min = child_pos + child_bbox.min_point[0]
361+
child_world_max = child_pos + child_bbox.max_point[0]
359362

360363
# 1. Per-axis overlap: zero when separated; else overlap length (default slope 1.0 gives length in m)
361364
overlap_x = interval_overlap_axis_loss(child_world_min[0], child_world_max[0], parent_x_min, parent_x_max)
@@ -369,15 +372,18 @@ def compute_loss(
369372
if self.debug:
370373
print(
371374
f" [NoCollision] X: overlap={overlap_x.item():.6f} (child_x=[{child_world_min[0].item():.4f},"
372-
f" {child_world_max[0].item():.4f}], parent_x=[{parent_x_min:.4f}, {parent_x_max:.4f}])"
375+
f" {child_world_max[0].item():.4f}], parent_x=[{parent_x_min.item():.4f},"
376+
f" {parent_x_max.item():.4f}])"
373377
)
374378
print(
375379
f" [NoCollision] Y: overlap={overlap_y.item():.6f} (child_y=[{child_world_min[1].item():.4f},"
376-
f" {child_world_max[1].item():.4f}], parent_y=[{parent_y_min:.4f}, {parent_y_max:.4f}])"
380+
f" {child_world_max[1].item():.4f}], parent_y=[{parent_y_min.item():.4f},"
381+
f" {parent_y_max.item():.4f}])"
377382
)
378383
print(
379384
f" [NoCollision] Z: overlap={overlap_z.item():.6f} (child_z=[{child_world_min[2].item():.4f},"
380-
f" {child_world_max[2].item():.4f}], parent_z=[{parent_z_min:.4f}, {parent_z_max:.4f}])"
385+
f" {child_world_max[2].item():.4f}], parent_z=[{parent_z_min.item():.4f},"
386+
f" {parent_z_max.item():.4f}])"
381387
)
382388
print(f" [NoCollision] volume={overlap_volume.item():.6f}, loss={total_loss.item():.6f}")
383389

isaaclab_arena/relations/relation_solver.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,21 +238,34 @@ def _print_relation_debug(
238238

239239
print(f"\n=== {obj.name} -> {type(relation).__name__}({relation.parent.name}) ===")
240240
print(f" Child pos: ({child_pos[0].item():.4f}, {child_pos[1].item():.4f}, {child_pos[2].item():.4f})")
241-
print(f" Child bbox: min={child_bbox.min_point}, max={child_bbox.max_point}, size={child_bbox.size}")
241+
print(f" Child bbox: min={child_bbox.min_point[0].tolist()}, max={child_bbox.max_point[0].tolist()},"
242+
f" size={child_bbox.size[0].tolist()}")
242243
print(f" Parent pos: ({parent_pos[0].item():.4f}, {parent_pos[1].item():.4f}, {parent_pos[2].item():.4f})")
243244
print(
244-
f" Parent world bbox: min={parent_world_bbox.min_point}, max={parent_world_bbox.max_point},"
245-
f" size={parent_world_bbox.size}"
245+
f" Parent world bbox: min={parent_world_bbox.min_point[0].tolist()},"
246+
f" max={parent_world_bbox.max_point[0].tolist()}, size={parent_world_bbox.size[0].tolist()}"
246247
)
247248

248249
# Child world extents
249-
child_x_range = (child_pos[0].item() + child_bbox.min_point[0], child_pos[0].item() + child_bbox.max_point[0])
250-
child_y_range = (child_pos[1].item() + child_bbox.min_point[1], child_pos[1].item() + child_bbox.max_point[1])
250+
child_x_range = (
251+
child_pos[0].item() + child_bbox.min_point[0, 0].item(),
252+
child_pos[0].item() + child_bbox.max_point[0, 0].item(),
253+
)
254+
child_y_range = (
255+
child_pos[1].item() + child_bbox.min_point[0, 1].item(),
256+
child_pos[1].item() + child_bbox.max_point[0, 1].item(),
257+
)
251258

252259
print(f" Child world X: [{child_x_range[0]:.4f}, {child_x_range[1]:.4f}]")
253260
print(f" Child world Y: [{child_y_range[0]:.4f}, {child_y_range[1]:.4f}]")
254-
print(f" Parent world X: [{parent_world_bbox.min_point[0]:.4f}, {parent_world_bbox.max_point[0]:.4f}]")
255-
print(f" Parent world Y: [{parent_world_bbox.min_point[1]:.4f}, {parent_world_bbox.max_point[1]:.4f}]")
261+
print(
262+
f" Parent world X: [{parent_world_bbox.min_point[0, 0].item():.4f},"
263+
f" {parent_world_bbox.max_point[0, 0].item():.4f}]"
264+
)
265+
print(
266+
f" Parent world Y: [{parent_world_bbox.min_point[0, 1].item():.4f},"
267+
f" {parent_world_bbox.max_point[0, 1].item():.4f}]"
268+
)
256269
print(f" Loss: {loss.item():.6f}")
257270

258271

@@ -270,5 +283,6 @@ def _print_unary_relation_debug(
270283
)
271284
print(f"\n=== {obj.name} -> {type(relation).__name__}({target_str}) ===")
272285
print(f" Child pos: ({child_pos[0].item():.4f}, {child_pos[1].item():.4f}, {child_pos[2].item():.4f})")
273-
print(f" Child bbox: min={child_bbox.min_point}, max={child_bbox.max_point}, size={child_bbox.size}")
286+
print(f" Child bbox: min={child_bbox.min_point[0].tolist()}, max={child_bbox.max_point[0].tolist()},"
287+
f" size={child_bbox.size[0].tolist()}")
274288
print(f" Loss: {loss.item():.6f}")

0 commit comments

Comments
 (0)