Skip to content

Commit 3c81712

Browse files
committed
Fix multi flex indexing
1 parent c290bdb commit 3c81712

3 files changed

Lines changed: 13 additions & 5 deletions

File tree

mujoco_warp/_src/collision_flex.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ def _flex_narrowphase_dim2(
397397
flex_dim: wp.array(dtype=int),
398398
flex_vertadr: wp.array(dtype=int),
399399
flex_elemadr: wp.array(dtype=int),
400+
flex_elemdataadr: wp.array(dtype=int),
400401
flex_elemnum: wp.array(dtype=int),
401402
flex_elem: wp.array(dtype=int),
402403
flex_radius: wp.array(dtype=float),
@@ -443,7 +444,7 @@ def _flex_narrowphase_dim2(
443444
tri_radius = flex_radius[flexid]
444445
tri_margin = flex_margin[flexid]
445446

446-
elem_data_idx = elemid * 3
447+
elem_data_idx = flex_elemdataadr[flexid] + (elemid - flex_elemadr[flexid]) * 3
447448
v0_local = flex_elem[elem_data_idx]
448449
v1_local = flex_elem[elem_data_idx + 1]
449450
v2_local = flex_elem[elem_data_idx + 2]
@@ -708,6 +709,7 @@ def flex_narrowphase(m: Model, d: Data):
708709
m.flex_dim,
709710
m.flex_vertadr,
710711
m.flex_elemadr,
712+
m.flex_elemdataadr,
711713
m.flex_elemnum,
712714
m.flex_elem,
713715
m.flex_radius,

mujoco_warp/_src/passive.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,7 @@ def _flex_elasticity(
573573
flex_vertadr: wp.array(dtype=int),
574574
flex_edgeadr: wp.array(dtype=int),
575575
flex_elemadr: wp.array(dtype=int),
576+
flex_elemdataadr: wp.array(dtype=int),
576577
flex_elemnum: wp.array(dtype=int),
577578
flex_elemedgeadr: wp.array(dtype=int),
578579
flex_vertbodyid: wp.array(dtype=int),
@@ -593,12 +594,14 @@ def _flex_elasticity(
593594
worldid, elemid = wp.tid()
594595
timestep = opt_timestep[worldid % opt_timestep.shape[0]]
595596

597+
f = int(0)
596598
for i in range(nflex):
597599
locid = elemid - flex_elemadr[i]
598600
if locid >= 0 and locid < flex_elemnum[i]:
599601
f = i
600602
break
601603

604+
local_elemid = elemid - flex_elemadr[f]
602605
dim = flex_dim[f]
603606
nvert = dim + 1
604607
nedge = nvert * (nvert - 1) / 2
@@ -612,10 +615,11 @@ def _flex_elasticity(
612615
else:
613616
kD = 0.0
614617

618+
elem_data_adr = flex_elemdataadr[f] + local_elemid * (dim + 1)
615619
gradient = wp.matrix(0.0, shape=(6, 6))
616620
for e in range(nedge):
617-
vert0 = flex_elem[(dim + 1) * elemid + edges[e, 0]]
618-
vert1 = flex_elem[(dim + 1) * elemid + edges[e, 1]]
621+
vert0 = flex_elem[elem_data_adr + edges[e, 0]]
622+
vert1 = flex_elem[elem_data_adr + edges[e, 1]]
619623
xpos0 = flexvert_xpos_in[worldid, vert0]
620624
xpos1 = flexvert_xpos_in[worldid, vert1]
621625
for i in range(3):
@@ -624,7 +628,7 @@ def _flex_elasticity(
624628

625629
elongation = wp.spatial_vectorf(0.0)
626630
for e in range(nedge):
627-
idx = flex_elemedge[elemid * nedge + e]
631+
idx = flex_elemedge[flex_elemedgeadr[f] + local_elemid * int(nedge) + e]
628632
vel = flexedge_velocity_in[worldid, flex_edgeadr[f] + idx]
629633
deformed = flexedge_length_in[worldid, flex_edgeadr[f] + idx]
630634
reference = flexedge_length0[flex_edgeadr[f] + idx]
@@ -647,7 +651,7 @@ def _flex_elasticity(
647651
force[edges[ed2, i], x] -= elongation[ed1] * gradient[ed2, 3 * i + x] * metric[ed1, ed2]
648652

649653
for v in range(nvert):
650-
vert = flex_elem[(dim + 1) * elemid + v]
654+
vert = flex_elem[elem_data_adr + v]
651655
bodyid = flex_vertbodyid[flex_vertadr[f] + vert]
652656
for x in range(3):
653657
wp.atomic_add(qfrc_spring_out, worldid, body_dofadr[bodyid] + x, force[v, x])
@@ -783,6 +787,7 @@ def passive(m: Model, d: Data):
783787
m.flex_vertadr,
784788
m.flex_edgeadr,
785789
m.flex_elemadr,
790+
m.flex_elemdataadr,
786791
m.flex_elemnum,
787792
m.flex_elemedgeadr,
788793
m.flex_vertbodyid,

mujoco_warp/_src/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,7 @@ class Model:
13651365
flex_edgeadr: array("nflex", int)
13661366
flex_edgenum: array("nflex", int)
13671367
flex_elemadr: array("nflex", int)
1368+
flex_elemdataadr: array("nflex", int)
13681369
flex_elemnum: array("nflex", int)
13691370
flex_elemedgeadr: array("nflex", int)
13701371
flex_shellnum: array("nflex", int)

0 commit comments

Comments
 (0)