Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions mujoco_warp/_src/collision_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import mat23
from mujoco_warp._src.types import mat63
from mujoco_warp._src.types import vec5
from mujoco_warp._src.warp_util import event_scope

wp.set_module_options({"enable_backward": False})
Expand Down Expand Up @@ -748,6 +749,216 @@ def _narrowphase(m: Model, d: Data, ctx: CollisionContext):
flex_narrowphase(m, d)


# Maximum geomcollisionid packed into sort key. Primitive box<>box generates at
# most 8 contacts, so larger geomcollisionid values need not contribute to the
# deterministic ordering key.
_CONTACT_SORT_GCID_MAX = 8


@wp.kernel
def _compute_contact_sort_keys(
# Model:
ngeom: int,
# Data in:
contact_geom_in: wp.array[wp.vec2i],
contact_worldid_in: wp.array[int],
contact_geomcollisionid_in: wp.array[int],
nacon_in: wp.array[int],
# In:
gcid_max: int,
# Out:
sort_keys_out: wp.array[int],
sort_indices_out: wp.array[int],
):
"""Compute composite sort keys for deterministic contact ordering."""
cid = wp.tid()
sort_indices_out[cid] = cid
if cid >= nacon_in[0]:
sort_keys_out[cid] = 2147483647 # INT_MAX: inactive contacts sort to end
return
geom = contact_geom_in[cid]
wid = contact_worldid_in[cid]
gcid = wp.min(contact_geomcollisionid_in[cid], gcid_max - 1)
sort_keys_out[cid] = ((wid * ngeom + geom[0]) * ngeom + geom[1]) * gcid_max + gcid


@wp.kernel
def _permute_contacts(
# Data in:
nacon_in: wp.array[int],
# In:
perm_in: wp.array[int],
src_dist_in: wp.array[float],
src_pos_in: wp.array[wp.vec3],
src_frame_in: wp.array[wp.mat33],
src_includemargin_in: wp.array[float],
src_friction_in: wp.array[vec5],
src_solref_in: wp.array[wp.vec2],
src_solreffriction_in: wp.array[wp.vec2],
src_solimp_in: wp.array[vec5],
src_dim_in: wp.array[int],
src_geom_in: wp.array[wp.vec2i],
src_flex_in: wp.array[wp.vec2i],
src_vert_in: wp.array[wp.vec2i],
src_worldid_in: wp.array[int],
src_type_in: wp.array[int],
src_gcid_in: wp.array[int],
src_efc_in: wp.array2d[int],
# Out:
dst_dist_out: wp.array[float],
dst_pos_out: wp.array[wp.vec3],
dst_frame_out: wp.array[wp.mat33],
dst_includemargin_out: wp.array[float],
dst_friction_out: wp.array[vec5],
dst_solref_out: wp.array[wp.vec2],
dst_solreffriction_out: wp.array[wp.vec2],
dst_solimp_out: wp.array[vec5],
dst_dim_out: wp.array[int],
dst_geom_out: wp.array[wp.vec2i],
dst_flex_out: wp.array[wp.vec2i],
dst_vert_out: wp.array[wp.vec2i],
dst_worldid_out: wp.array[int],
dst_type_out: wp.array[int],
dst_gcid_out: wp.array[int],
dst_efc_out: wp.array2d[int],
):
"""Permute contact fields using sorted indices."""
cid = wp.tid()
if cid >= nacon_in[0]:
return
src = perm_in[cid]
dst_dist_out[cid] = src_dist_in[src]
dst_pos_out[cid] = src_pos_in[src]
dst_frame_out[cid] = src_frame_in[src]
dst_includemargin_out[cid] = src_includemargin_in[src]
dst_friction_out[cid] = src_friction_in[src]
dst_solref_out[cid] = src_solref_in[src]
dst_solreffriction_out[cid] = src_solreffriction_in[src]
dst_solimp_out[cid] = src_solimp_in[src]
dst_dim_out[cid] = src_dim_in[src]
dst_geom_out[cid] = src_geom_in[src]
dst_flex_out[cid] = src_flex_in[src]
dst_vert_out[cid] = src_vert_in[src]
dst_worldid_out[cid] = src_worldid_in[src]
dst_type_out[cid] = src_type_in[src]
dst_gcid_out[cid] = src_gcid_in[src]
for j in range(src_efc_in.shape[1]):
dst_efc_out[cid, j] = src_efc_in[src, j]


def _sort_contacts(m: Model, d: Data):
"""Sort contacts by (worldid, geom0, geom1, geomcollisionid) for determinism."""
if d.naconmax == 0:
return

# Check for sort-key overflow. Fall back to no-gcid key if needed.
gcid_max = _CONTACT_SORT_GCID_MAX
if d.nworld * m.ngeom * m.ngeom * gcid_max > 2**31 - 1:
gcid_max = 1

# Allocate sort buffers (radix_sort_pairs needs 2x capacity for internal use).
sort_keys = wp.empty(2 * d.naconmax, dtype=int)
sort_indices = wp.empty(2 * d.naconmax, dtype=int)

# Step 1: Compute sort keys and initialise indices to identity.
wp.launch(
_compute_contact_sort_keys,
dim=d.naconmax,
inputs=[
m.ngeom,
d.contact.geom,
d.contact.worldid,
d.contact.geomcollisionid,
d.nacon,
gcid_max,
],
outputs=[sort_keys, sort_indices],
)

# Step 2: Stable radix sort on keys, carrying indices.
wp.utils.radix_sort_pairs(sort_keys, sort_indices, d.naconmax)

# TODO(team): investigate a single kernel that copies all contact fields to scratch.
# Step 3: Copy contact fields to temporary buffers.
tmp_dist = wp.empty_like(d.contact.dist)
tmp_pos = wp.empty_like(d.contact.pos)
tmp_frame = wp.empty_like(d.contact.frame)
tmp_includemargin = wp.empty_like(d.contact.includemargin)
tmp_friction = wp.empty_like(d.contact.friction)
tmp_solref = wp.empty_like(d.contact.solref)
tmp_solreffriction = wp.empty_like(d.contact.solreffriction)
tmp_solimp = wp.empty_like(d.contact.solimp)
tmp_dim = wp.empty_like(d.contact.dim)
tmp_geom = wp.empty_like(d.contact.geom)
tmp_flex = wp.empty_like(d.contact.flex)
tmp_vert = wp.empty_like(d.contact.vert)
tmp_worldid = wp.empty_like(d.contact.worldid)
tmp_type = wp.empty_like(d.contact.type)
tmp_gcid = wp.empty_like(d.contact.geomcollisionid)
tmp_efc = wp.empty_like(d.contact.efc_address)

wp.copy(tmp_dist, d.contact.dist)
Comment thread
mar-yan24 marked this conversation as resolved.
wp.copy(tmp_pos, d.contact.pos)
wp.copy(tmp_frame, d.contact.frame)
wp.copy(tmp_includemargin, d.contact.includemargin)
wp.copy(tmp_friction, d.contact.friction)
wp.copy(tmp_solref, d.contact.solref)
wp.copy(tmp_solreffriction, d.contact.solreffriction)
wp.copy(tmp_solimp, d.contact.solimp)
wp.copy(tmp_dim, d.contact.dim)
wp.copy(tmp_geom, d.contact.geom)
wp.copy(tmp_flex, d.contact.flex)
wp.copy(tmp_vert, d.contact.vert)
wp.copy(tmp_worldid, d.contact.worldid)
wp.copy(tmp_type, d.contact.type)
wp.copy(tmp_gcid, d.contact.geomcollisionid)
wp.copy(tmp_efc, d.contact.efc_address)

# Step 4: Gather-permute from temp buffers back into contact arrays.
wp.launch(
Comment thread
mar-yan24 marked this conversation as resolved.
_permute_contacts,
dim=d.naconmax,
inputs=[
d.nacon,
sort_indices,
tmp_dist,
tmp_pos,
tmp_frame,
tmp_includemargin,
tmp_friction,
tmp_solref,
tmp_solreffriction,
tmp_solimp,
tmp_dim,
tmp_geom,
tmp_flex,
tmp_vert,
tmp_worldid,
tmp_type,
tmp_gcid,
tmp_efc,
],
outputs=[
d.contact.dist,
d.contact.pos,
d.contact.frame,
d.contact.includemargin,
d.contact.friction,
d.contact.solref,
d.contact.solreffriction,
d.contact.solimp,
d.contact.dim,
d.contact.geom,
d.contact.flex,
d.contact.vert,
d.contact.worldid,
d.contact.type,
d.contact.geomcollisionid,
d.contact.efc_address,
],
)


@event_scope
def collision(m: Model, d: Data):
"""Runs the full collision detection pipeline.
Expand Down Expand Up @@ -782,5 +993,8 @@ def collision(m: Model, d: Data):

_narrowphase(m, d, ctx)

if m.opt.deterministic:
_sort_contacts(m, d)

if m.callback.contactfilter:
m.callback.contactfilter(m, d)
Loading