@@ -573,6 +573,7 @@ def _flex_elasticity(
573573 flex_vertadr : wp .array (dtype = int ),
574574 flex_edgeadr : wp .array (dtype = int ),
575575 flex_elemadr : wp .array (dtype = int ),
576+ flex_elemdataadr : wp .array (dtype = int ),
576577 flex_elemnum : wp .array (dtype = int ),
577578 flex_elemedgeadr : wp .array (dtype = int ),
578579 flex_vertbodyid : wp .array (dtype = int ),
@@ -593,12 +594,14 @@ def _flex_elasticity(
593594 worldid , elemid = wp .tid ()
594595 timestep = opt_timestep [worldid % opt_timestep .shape [0 ]]
595596
597+ f = int (0 )
596598 for i in range (nflex ):
597599 locid = elemid - flex_elemadr [i ]
598600 if locid >= 0 and locid < flex_elemnum [i ]:
599601 f = i
600602 break
601603
604+ local_elemid = elemid - flex_elemadr [f ]
602605 dim = flex_dim [f ]
603606 nvert = dim + 1
604607 nedge = nvert * (nvert - 1 ) / 2
@@ -612,10 +615,11 @@ def _flex_elasticity(
612615 else :
613616 kD = 0.0
614617
618+ elem_data_adr = flex_elemdataadr [f ] + local_elemid * (dim + 1 )
615619 gradient = wp .matrix (0.0 , shape = (6 , 6 ))
616620 for e in range (nedge ):
617- vert0 = flex_elem [( dim + 1 ) * elemid + edges [e , 0 ]]
618- vert1 = flex_elem [( dim + 1 ) * elemid + edges [e , 1 ]]
621+ vert0 = flex_elem [elem_data_adr + edges [e , 0 ]]
622+ vert1 = flex_elem [elem_data_adr + edges [e , 1 ]]
619623 xpos0 = flexvert_xpos_in [worldid , vert0 ]
620624 xpos1 = flexvert_xpos_in [worldid , vert1 ]
621625 for i in range (3 ):
@@ -624,7 +628,7 @@ def _flex_elasticity(
624628
625629 elongation = wp .spatial_vectorf (0.0 )
626630 for e in range (nedge ):
627- idx = flex_elemedge [elemid * nedge + e ]
631+ idx = flex_elemedge [flex_elemedgeadr [ f ] + local_elemid * int ( nedge ) + e ]
628632 vel = flexedge_velocity_in [worldid , flex_edgeadr [f ] + idx ]
629633 deformed = flexedge_length_in [worldid , flex_edgeadr [f ] + idx ]
630634 reference = flexedge_length0 [flex_edgeadr [f ] + idx ]
@@ -647,7 +651,7 @@ def _flex_elasticity(
647651 force [edges [ed2 , i ], x ] -= elongation [ed1 ] * gradient [ed2 , 3 * i + x ] * metric [ed1 , ed2 ]
648652
649653 for v in range (nvert ):
650- vert = flex_elem [( dim + 1 ) * elemid + v ]
654+ vert = flex_elem [elem_data_adr + v ]
651655 bodyid = flex_vertbodyid [flex_vertadr [f ] + vert ]
652656 for x in range (3 ):
653657 wp .atomic_add (qfrc_spring_out , worldid , body_dofadr [bodyid ] + x , force [v , x ])
@@ -783,6 +787,7 @@ def passive(m: Model, d: Data):
783787 m .flex_vertadr ,
784788 m .flex_edgeadr ,
785789 m .flex_elemadr ,
790+ m .flex_elemdataadr ,
786791 m .flex_elemnum ,
787792 m .flex_elemedgeadr ,
788793 m .flex_vertbodyid ,
0 commit comments