Skip to content

Commit 63824e0

Browse files
committed
margin and gap redesign
1 parent 40ef9d6 commit 63824e0

10 files changed

Lines changed: 100 additions & 63 deletions

File tree

mujoco_warp/_src/collision_convex.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from mujoco_warp._src.types import MJ_MAX_EPAHORIZON
3636
from mujoco_warp._src.types import MJ_MAXCONPAIR
3737
from mujoco_warp._src.types import MJ_MAXVAL
38+
from mujoco_warp._src.types import NEW_GAP_SEMANTICS
3839
from mujoco_warp._src.types import Data
3940
from mujoco_warp._src.types import DisableBit
4041
from mujoco_warp._src.types import GeomType
@@ -773,7 +774,10 @@ def eval_ccd_write_contact(
773774
if is_collision_sensor:
774775
cutoff = 1.0e32
775776
else:
776-
cutoff = 0.0
777+
if wp.static(NEW_GAP_SEMANTICS):
778+
cutoff = gap
779+
else:
780+
cutoff = 0.0
777781
dist, ncollision, w1, w2, multiccd_idx = ccd(
778782
opt_ccd_tolerance[worldid % opt_ccd_tolerance.shape[0]],
779783
cutoff,
@@ -793,8 +797,12 @@ def eval_ccd_write_contact(
793797
epa_horizon_in[ccdid],
794798
)
795799

796-
if dist >= 0.0 and pairid[1] == -1:
797-
return 0
800+
if wp.static(NEW_GAP_SEMANTICS):
801+
if dist >= gap and pairid[1] == -1:
802+
return 0
803+
else:
804+
if dist >= 0.0 and pairid[1] == -1:
805+
return 0
798806

799807
# CCD operates on margin-inflated shapes (support() inflates each geom by
800808
# 0.5 * margin). The returned dist is therefore relative to the inflated

mujoco_warp/_src/collision_core.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from mujoco_warp._src.math import safe_div
2424
from mujoco_warp._src.types import MJ_MINMU
2525
from mujoco_warp._src.types import MJ_MINVAL
26+
from mujoco_warp._src.types import NEW_GAP_SEMANTICS
2627
from mujoco_warp._src.types import ContactType
2728
from mujoco_warp._src.types import GeomType
2829
from mujoco_warp._src.types import mat63
@@ -197,14 +198,15 @@ def write_contact(
197198
Returns 1 if the contact is active (dist < margin), 0 otherwise.
198199
"""
199200
active = dist_in < margin_in
201+
detected = dist_in < margin_in + gap_in
200202

201203
# skip contact and no collision sensor
202-
if (pairid_in[0] == -2 or not active) and pairid_in[1] == -1:
204+
if (pairid_in[0] == -2 or not detected) and pairid_in[1] == -1:
203205
return 0
204206

205207
contact_type = 0
206208

207-
if pairid_in[0] >= -1 and active:
209+
if pairid_in[0] >= -1 and detected:
208210
contact_type |= ContactType.CONSTRAINT
209211

210212
if pairid_in[1] >= 0:
@@ -217,7 +219,10 @@ def write_contact(
217219
contact_frame_out[cid] = frame_in
218220
contact_geom_out[cid] = geoms_in
219221
contact_worldid_out[cid] = worldid_in
220-
includemargin = margin_in - gap_in
222+
if wp.static(NEW_GAP_SEMANTICS):
223+
includemargin = margin_in
224+
else:
225+
includemargin = margin_in - gap_in
221226
contact_includemargin_out[cid] = includemargin
222227
contact_dim_out[cid] = condim_in
223228
contact_friction_out[cid] = friction_in

mujoco_warp/_src/collision_driver.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,14 @@ def _obb_filter(
271271
return True
272272

273273

274-
def _broadphase_filter(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int):
274+
def _broadphase_filter(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int, ngeom_gap: int):
275275
@wp.func
276276
def func(
277277
# Model:
278278
geom_aabb: wp.array3d[wp.vec3],
279279
geom_rbound: wp.array2d[float],
280280
geom_margin: wp.array2d[float],
281+
geom_gap: wp.array2d[float],
281282
# Data in:
282283
geom_xpos_in: wp.array2d[wp.vec3],
283284
geom_xmat_in: wp.array2d[wp.mat33],
@@ -299,21 +300,25 @@ def func(
299300
rbound1, rbound2 = geom_rbound[rbound_id, geom1], geom_rbound[rbound_id, geom2] # kernel_analyzer: ignore
300301
margin_id = worldid % ngeom_margin if wp.static(ngeom_margin > 1) else 0
301302
margin1, margin2 = geom_margin[margin_id, geom1], geom_margin[margin_id, geom2] # kernel_analyzer: ignore
303+
gap_id = worldid % ngeom_gap if wp.static(ngeom_gap > 1) else 0
304+
gap1, gap2 = geom_gap[gap_id, geom1], geom_gap[gap_id, geom2] # kernel_analyzer: ignore
305+
effective_margin1 = margin1 + gap1
306+
effective_margin2 = margin2 + gap2
302307
xpos1, xpos2 = geom_xpos_in[worldid, geom1], geom_xpos_in[worldid, geom2]
303308
xmat1, xmat2 = geom_xmat_in[worldid, geom1], geom_xmat_in[worldid, geom2]
304309

305310
if rbound1 == 0.0 or rbound2 == 0.0:
306311
if wp.static(opt_broadphase_filter & BroadphaseFilter.PLANE):
307-
return _plane_filter(rbound1, rbound2, margin1, margin2, xpos1, xpos2, xmat1, xmat2)
312+
return _plane_filter(rbound1, rbound2, effective_margin1, effective_margin2, xpos1, xpos2, xmat1, xmat2)
308313
else:
309314
if wp.static(opt_broadphase_filter & BroadphaseFilter.SPHERE):
310-
if not _sphere_filter(rbound1, rbound2, margin1, margin2, xpos1, xpos2):
315+
if not _sphere_filter(rbound1, rbound2, effective_margin1, effective_margin2, xpos1, xpos2):
311316
return False
312317
if wp.static(opt_broadphase_filter & BroadphaseFilter.AABB):
313-
if not _aabb_filter(center1, center2, size1, size2, margin1, margin2, xpos1, xpos2, xmat1, xmat2):
318+
if not _aabb_filter(center1, center2, size1, size2, effective_margin1, effective_margin2, xpos1, xpos2, xmat1, xmat2):
314319
return False
315320
if wp.static(opt_broadphase_filter & BroadphaseFilter.OBB):
316-
if not _obb_filter(center1, center2, size1, size2, margin1, margin2, xpos1, xpos2, xmat1, xmat2):
321+
if not _obb_filter(center1, center2, size1, size2, effective_margin1, effective_margin2, xpos1, xpos2, xmat1, xmat2):
317322
return False
318323

319324
return True
@@ -378,6 +383,7 @@ def sap_project(
378383
ngeom: int,
379384
geom_rbound: wp.array2d[float],
380385
geom_margin: wp.array2d[float],
386+
geom_gap: wp.array2d[float],
381387
# Data in:
382388
geom_xpos_in: wp.array2d[wp.vec3],
383389
nworld_in: int,
@@ -398,7 +404,7 @@ def sap_project(
398404
# geom is a plane
399405
rbound = MJ_MAXVAL
400406

401-
radius = rbound + geom_margin[worldid % geom_margin.shape[0], geomid]
407+
radius = rbound + geom_margin[worldid % geom_margin.shape[0], geomid] + geom_gap[worldid % geom_gap.shape[0], geomid]
402408
center = wp.dot(direction_in, xpos)
403409

404410
sort_index_out[worldid, geomid] = geomid
@@ -444,7 +450,7 @@ def _sap_range(
444450

445451

446452
@cache_kernel
447-
def _sap_broadphase(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int):
453+
def _sap_broadphase(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int, ngeom_gap: int):
448454
@wp.kernel(module="unique", enable_backward=False)
449455
def kernel(
450456
# Model:
@@ -453,6 +459,7 @@ def kernel(
453459
geom_aabb: wp.array3d[wp.vec3],
454460
geom_rbound: wp.array2d[float],
455461
geom_margin: wp.array2d[float],
462+
geom_gap: wp.array2d[float],
456463
nxn_pairid: wp.array[wp.vec2i],
457464
# Data in:
458465
geom_xpos_in: wp.array2d[wp.vec3],
@@ -503,8 +510,8 @@ def kernel(
503510
continue
504511

505512
if (
506-
wp.static(_broadphase_filter(opt_broadphase_filter, ngeom_aabb, ngeom_rbound, ngeom_margin))(
507-
geom_aabb, geom_rbound, geom_margin, geom_xpos_in, geom_xmat_in, geom1, geom2, worldid
513+
wp.static(_broadphase_filter(opt_broadphase_filter, ngeom_aabb, ngeom_rbound, ngeom_margin, ngeom_gap))(
514+
geom_aabb, geom_rbound, geom_margin, geom_gap, geom_xpos_in, geom_xmat_in, geom1, geom2, worldid
508515
)
509516
or pairid[1] >= 0
510517
):
@@ -588,7 +595,7 @@ def sap_broadphase(m: Model, d: Data, ctx: CollisionContext):
588595
wp.launch(
589596
kernel=_sap_project(m.opt.broadphase),
590597
dim=(d.nworld, m.ngeom),
591-
inputs=[m.ngeom, m.geom_rbound, m.geom_margin, d.geom_xpos, d.nworld, direction],
598+
inputs=[m.ngeom, m.geom_rbound, m.geom_margin, m.geom_gap, d.geom_xpos, d.nworld, direction],
592599
outputs=[
593600
projection_lower.reshape((-1, m.ngeom)),
594601
projection_upper,
@@ -624,14 +631,17 @@ def sap_broadphase(m: Model, d: Data, ctx: CollisionContext):
624631
# assumes each geom has 5 other geoms (batched over all worlds)
625632
nsweep = 5 * nworldgeom
626633
wp.launch(
627-
kernel=_sap_broadphase(m.opt.broadphase_filter, m.geom_aabb.shape[0], m.geom_rbound.shape[0], m.geom_margin.shape[0]),
634+
kernel=_sap_broadphase(
635+
m.opt.broadphase_filter, m.geom_aabb.shape[0], m.geom_rbound.shape[0], m.geom_margin.shape[0], m.geom_gap.shape[0]
636+
),
628637
dim=nsweep,
629638
inputs=[
630639
m.ngeom,
631640
m.geom_type,
632641
m.geom_aabb,
633642
m.geom_rbound,
634643
m.geom_margin,
644+
m.geom_gap,
635645
m.nxn_pairid,
636646
d.geom_xpos,
637647
d.geom_xmat,
@@ -646,14 +656,15 @@ def sap_broadphase(m: Model, d: Data, ctx: CollisionContext):
646656

647657

648658
@cache_kernel
649-
def _nxn_broadphase(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int):
659+
def _nxn_broadphase(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int, ngeom_gap: int):
650660
@wp.kernel(module="unique", enable_backward=False)
651661
def kernel(
652662
# Model:
653663
geom_type: wp.array[int],
654664
geom_aabb: wp.array3d[wp.vec3],
655665
geom_rbound: wp.array2d[float],
656666
geom_margin: wp.array2d[float],
667+
geom_gap: wp.array2d[float],
657668
nxn_geom_pair: wp.array[wp.vec2i],
658669
nxn_pairid: wp.array[wp.vec2i],
659670
# Data in:
@@ -674,8 +685,8 @@ def kernel(
674685
geom2 = geom[1]
675686

676687
if (
677-
wp.static(_broadphase_filter(opt_broadphase_filter, ngeom_aabb, ngeom_rbound, ngeom_margin))(
678-
geom_aabb, geom_rbound, geom_margin, geom_xpos_in, geom_xmat_in, geom1, geom2, worldid
688+
wp.static(_broadphase_filter(opt_broadphase_filter, ngeom_aabb, ngeom_rbound, ngeom_margin, ngeom_gap))(
689+
geom_aabb, geom_rbound, geom_margin, geom_gap, geom_xpos_in, geom_xmat_in, geom1, geom2, worldid
679690
)
680691
or nxn_pairid[elementid][1] >= 0
681692
):
@@ -711,13 +722,16 @@ def nxn_broadphase(m: Model, d: Data, ctx: CollisionContext):
711722
`contype`/`conaffinity`, parent-child relationships, and explicit `<exclude>` tags.
712723
"""
713724
wp.launch(
714-
_nxn_broadphase(m.opt.broadphase_filter, m.geom_aabb.shape[0], m.geom_rbound.shape[0], m.geom_margin.shape[0]),
725+
_nxn_broadphase(
726+
m.opt.broadphase_filter, m.geom_aabb.shape[0], m.geom_rbound.shape[0], m.geom_margin.shape[0], m.geom_gap.shape[0]
727+
),
715728
dim=(d.nworld, m.nxn_geom_pair_filtered.shape[0]),
716729
inputs=[
717730
m.geom_type,
718731
m.geom_aabb,
719732
m.geom_rbound,
720733
m.geom_margin,
734+
m.geom_gap,
721735
m.nxn_geom_pair_filtered,
722736
m.nxn_pairid_filtered,
723737
d.geom_xpos,

mujoco_warp/_src/collision_driver_test.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from mujoco_warp._src.collision_driver import MJ_COLLISION_TABLE
3131
from mujoco_warp._src.collision_primitive import plane_convex
3232
from mujoco_warp._src.math import upper_trid_index
33+
from mujoco_warp._src.types import NEW_GAP_SEMANTICS
3334
from mujoco_warp.test_data.collision_sdf.utils import register_sdf_plugins
3435

3536
_TOLERANCE = 5e-5
@@ -702,7 +703,7 @@ def test_contact_pair(self, broadphase):
702703

703704
# 1 pair
704705
_, _, m, d = test_data.fixture(
705-
xml="""
706+
xml=f"""
706707
<mujoco>
707708
<worldbody>
708709
<body>
@@ -715,7 +716,7 @@ def test_contact_pair(self, broadphase):
715716
</body>
716717
</worldbody>
717718
<contact>
718-
<pair geom1="geom1" geom2="geom2" margin="2" gap="3" condim="6" friction="5 4 3 2 1" solref="-.25 -.5" solreffriction="2 4" solimp=".1 .2 .3 .4 .5"/>
719+
<pair geom1="geom1" geom2="geom2" margin="{-1 if NEW_GAP_SEMANTICS else 2}" gap="3" condim="6" friction="5 4 3 2 1" solref="-.25 -.5" solreffriction="2 4" solimp=".1 .2 .3 .4 .5"/>
719720
</contact>
720721
</mujoco>
721722
"""
@@ -746,7 +747,7 @@ def test_contact_pair(self, broadphase):
746747

747748
# 1 pair: override contype and conaffinity
748749
_, _, m, d = test_data.fixture(
749-
xml="""
750+
xml=f"""
750751
<mujoco>
751752
<worldbody>
752753
<body name="body1">
@@ -759,7 +760,7 @@ def test_contact_pair(self, broadphase):
759760
</body>
760761
</worldbody>
761762
<contact>
762-
<pair geom1="geom1" geom2="geom2" margin="2" gap="3" condim="6" friction="5 4 3 2 1" solref="-.25 -.5" solreffriction="2 4" solimp=".1 .2 .3 .4 .5"/>
763+
<pair geom1="geom1" geom2="geom2" margin="{-1 if NEW_GAP_SEMANTICS else 2}" gap="3" condim="6" friction="5 4 3 2 1" solref="-.25 -.5" solreffriction="2 4" solimp=".1 .2 .3 .4 .5"/>
763764
</contact>
764765
</mujoco>
765766
"""
@@ -790,7 +791,7 @@ def test_contact_pair(self, broadphase):
790791

791792
# 1 pair: override exclude
792793
_, _, m, d = test_data.fixture(
793-
xml="""
794+
xml=f"""
794795
<mujoco>
795796
<worldbody>
796797
<body name="body1">
@@ -804,7 +805,7 @@ def test_contact_pair(self, broadphase):
804805
</worldbody>
805806
<contact>
806807
<exclude body1="body1" body2="body2"/>
807-
<pair geom1="geom1" geom2="geom2" margin="2" gap="3" condim="6" friction="5 4 3 2 1" solref="-.25 -.5" solreffriction="2 4" solimp=".1 .2 .3 .4 .5"/>
808+
<pair geom1="geom1" geom2="geom2" margin="{-1 if NEW_GAP_SEMANTICS else 2}" gap="3" condim="6" friction="5 4 3 2 1" solref="-.25 -.5" solreffriction="2 4" solimp=".1 .2 .3 .4 .5"/>
808809
</contact>
809810
</mujoco>
810811
"""
@@ -835,7 +836,7 @@ def test_contact_pair(self, broadphase):
835836

836837
# 1 pair 1 exclude
837838
_, _, m, d = test_data.fixture(
838-
xml="""
839+
xml=f"""
839840
<mujoco>
840841
<worldbody>
841842
<body name="body1">
@@ -853,7 +854,7 @@ def test_contact_pair(self, broadphase):
853854
</worldbody>
854855
<contact>
855856
<exclude body1="body1" body2="body2"/>
856-
<pair geom1="geom2" geom2="geom3" margin="2" gap="3" condim="6" friction="5 4 3 2 1" solref="-.25 -.5" solreffriction="2 4" solimp=".1 .2 .3 .4 .5"/>
857+
<pair geom1="geom2" geom2="geom3" margin="{-1 if NEW_GAP_SEMANTICS else 2}" gap="3" condim="6" friction="5 4 3 2 1" solref="-.25 -.5" solreffriction="2 4" solimp=".1 .2 .3 .4 .5"/>
857858
</contact>
858859
</mujoco>
859860
"""
@@ -1126,22 +1127,23 @@ def test_sdf_volume_collision(self, fixture):
11261127
def test_ccd_margin_dist(self):
11271128
"""Tests that CCD contact dist matches MuJoCo when margin > 0.
11281129
1129-
Two ellipsoids are placed 0.05 m apart (not touching). With margin=0.1 on
1130-
each geom the pair margin is 0.2, so contacts are detected within the
1130+
Two ellipsoids are placed 0.05 m apart (not touching). With margin=0.01
1131+
and gap=0.2 on each geom, the pair margin is 0.02 and pair gap is 0.4.
1132+
CCD inflates geometries by margin, detecting contacts within the
11311133
speculative envelope. The reported dist must equal the true geometric
11321134
separation (≈0.05), not the margin-biased value that the inflated
11331135
GJK/EPA would produce.
11341136
"""
1135-
xml = """
1137+
xml = f"""
11361138
<mujoco>
11371139
<worldbody>
11381140
<body pos="0 0 0">
11391141
<freejoint/>
1140-
<geom type="ellipsoid" size="0.15 0.15 0.25" margin="0.1" gap="0.1"/>
1142+
<geom type="ellipsoid" size="0.15 0.15 0.25" margin="{0.01 if NEW_GAP_SEMANTICS else 0.1}" gap="0.2"/>
11411143
</body>
11421144
<body pos="0 0 0.35">
11431145
<freejoint/>
1144-
<geom type="ellipsoid" size="0.1 0.1 0.05" margin="0.1" gap="0.1"/>
1146+
<geom type="ellipsoid" size="0.1 0.1 0.05" margin="{0.01 if NEW_GAP_SEMANTICS else 0.1}" gap="0.2"/>
11451147
</body>
11461148
</worldbody>
11471149
</mujoco>
@@ -1173,7 +1175,7 @@ def test_ccd_margin_dist(self):
11731175
break
11741176
self.assertTrue(found, f"MJ contact {i} dist={mj_dist:.4f} not matched in MJW")
11751177

1176-
# Verify no constraint forces are generated (includemargin=0, dist > 0)
1178+
# dist(≈0.05) > margin(0.02): contacts are in gap zone, no constraints
11771179
self.assertEqual(mjd.nefc, 0, "Classic MuJoCo should have no active constraints")
11781180
self.assertEqual(d.nefc.numpy()[0], 0, "MuJoCo Warp should have no active constraints")
11791181

mujoco_warp/_src/constraint_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import mujoco_warp as mjw
2727
from mujoco_warp import ConeType
2828
from mujoco_warp import test_data
29+
from mujoco_warp._src.types import NEW_GAP_SEMANTICS
2930

3031
# tolerance for difference between MuJoCo and MJWarp constraint calculations,
3132
# mostly due to float precision
@@ -270,15 +271,15 @@ def test_equality_tendon(self, jacobian):
270271
def test_efc_address_inactive_contacts(self):
271272
"""Test that efc_address is -1 for inactive contacts in the gap zone."""
272273
# Sphere at z=0.35 with radius 0.1: dist ~ 0.15 to ground plane.
273-
# margin=0.5, gap=0.4 => includemargin = 0.1.
274-
# dist(0.15) < margin(0.5) => contact is detected.
275-
# dist(0.15) >= includemargin(0.1) => contact is NOT active (in gap zone).
276-
xml = """
274+
# margin=0.1, gap=0.4 => detection at margin+gap=0.5, forces at margin=0.1.
275+
# dist(0.15) < margin+gap(0.5) => contact is detected.
276+
# dist(0.15) >= margin(0.1) => contact is NOT active (in gap zone).
277+
xml = f"""
277278
<mujoco>
278279
<worldbody>
279-
<geom type="plane" size="10 10 .001" margin="0.5" gap="0.4"/>
280+
<geom type="plane" size="10 10 .001" margin="{0.1 if NEW_GAP_SEMANTICS else 0.5}" gap="0.4"/>
280281
<body pos="0 0 0.35">
281-
<geom type="sphere" size=".1" margin="0.5" gap="0.4"/>
282+
<geom type="sphere" size=".1" margin="{0.1 if NEW_GAP_SEMANTICS else 0.5}" gap="0.4"/>
282283
<freejoint/>
283284
</body>
284285
</worldbody>

0 commit comments

Comments
 (0)