@@ -52,12 +52,14 @@ def id_to_node_vel(self, id: int):
5252 return self .nodes .velocity [node_ids ]
5353
5454 def tree_flatten (self ):
55- children = (self .nodes ,)
55+ children = (self .nodes , self . volume )
5656 aux_data = (
5757 self .nelements ,
58+ self .total_elements ,
5859 self .el_len ,
5960 self .constraints ,
6061 self .concentrated_nodal_forces ,
62+ self .initialized ,
6163 )
6264 return children , aux_data
6365
@@ -67,8 +69,11 @@ def tree_unflatten(cls, aux_data, children):
6769 aux_data [0 ],
6870 aux_data [1 ],
6971 aux_data [2 ],
72+ aux_data [3 ],
7073 nodes = children [0 ],
71- concentrated_nodal_forces = aux_data [3 ],
74+ concentrated_nodal_forces = aux_data [4 ],
75+ initialized = aux_data [5 ],
76+ volume = children [1 ],
7277 )
7378
7479 @abc .abstractmethod
@@ -142,7 +147,7 @@ def _step(pid, args):
142147 )
143148 _ , self .nodes .momentum , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
144149 self .nodes .momentum = jnp .where (
145- self .nodes .momentum < 1e-12 ,
150+ jnp . abs ( self .nodes .momentum ) < 1e-12 ,
146151 jnp .zeros_like (self .nodes .momentum ),
147152 self .nodes .momentum ,
148153 )
@@ -154,7 +159,7 @@ def compute_velocity(self, particles):
154159 self .nodes .momentum / self .nodes .mass ,
155160 )
156161 self .nodes .velocity = jnp .where (
157- self .nodes .velocity < 1e-12 ,
162+ jnp . abs ( self .nodes .velocity ) < 1e-12 ,
158163 jnp .zeros_like (self .nodes .velocity ),
159164 self .nodes .velocity ,
160165 )
@@ -296,12 +301,12 @@ def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
296301 self .nodes .mass * self .nodes .velocity
297302 )
298303 self .nodes .velocity = jnp .where (
299- self .nodes .velocity < 1e-12 ,
304+ jnp . abs ( self .nodes .velocity ) < 1e-12 ,
300305 jnp .zeros_like (self .nodes .velocity ),
301306 self .nodes .velocity ,
302307 )
303308 self .nodes .acceleration = jnp .where (
304- self .nodes .acceleration < 1e-12 ,
309+ jnp . abs ( self .nodes .acceleration ) < 1e-12 ,
305310 jnp .zeros_like (self .nodes .acceleration ),
306311 self .nodes .acceleration ,
307312 )
@@ -350,7 +355,7 @@ def __init__(
350355 IDs of nodes that are supposed to be fixed (boundary).
351356 """
352357 self .nelements = nelements
353- self .ids = jnp . arange ( nelements )
358+ self .total_elements = nelements
354359 self .el_len = el_len
355360 if nodes is None :
356361 self .nodes = Nodes (
@@ -495,7 +500,7 @@ def f(x):
495500
496501 def compute_volume (self ):
497502 vol = jnp .ediff1d (self .nodes .loc )
498- self .volume = jnp .ones ((len ( self .ids ) , 1 , 1 )) * vol
503+ self .volume = jnp .ones ((self .total_elements , 1 , 1 )) * vol
499504
500505
501506@register_pytree_node_class
@@ -523,10 +528,13 @@ class Quadrilateral4Node(_Element):
523528 def __init__ (
524529 self ,
525530 nelements : Tuple [int , int ],
531+ total_elements : int ,
526532 el_len : Tuple [float , float ],
527533 constraints : List [Tuple [jnp .ndarray , Constraint ]],
528534 nodes : Nodes = None ,
529535 concentrated_nodal_forces = [],
536+ initialized : bool = None ,
537+ volume : jnp .ndarray = None ,
530538 ):
531539 """Initialize Quadrilateral4Node.
532540
@@ -539,8 +547,7 @@ def __init__(
539547 """
540548 self .nelements = jnp .asarray (nelements )
541549 self .el_len = jnp .asarray (el_len )
542- total_elements = jnp .product (self .nelements )
543- self .ids = jnp .arange (total_elements )
550+ self .total_elements = total_elements
544551
545552 if nodes is None :
546553 total_nodes = jnp .product (self .nelements + 1 )
@@ -561,6 +568,11 @@ def __init__(
561568
562569 self .constraints = constraints
563570 self .concentrated_nodal_forces = concentrated_nodal_forces
571+ if initialized is None :
572+ self .volume = jnp .ones ((self .total_elements , 1 , 1 ))
573+ else :
574+ self .volume = volume
575+ self .initialized = True
564576
565577 def id_to_node_ids (self , id : int ):
566578 """
@@ -773,12 +785,12 @@ def _step(pid, args):
773785 )
774786 self .nodes .f_int , _ , _ , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
775787
776- def compute_volume (self ):
788+ def compute_volume (self , * args ):
777789 a = c = self .el_len [1 ]
778790 b = d = self .el_len [0 ]
779791 p = q = jnp .sqrt (a ** 2 + b ** 2 )
780792 vol = 0.25 * jnp .sqrt (4 * p * p * q * q - (a * a + c * c - b * b - d * d ) ** 2 )
781- self .volume = jnp . ones (( len ( self .ids ), 1 , 1 )) * vol
793+ self .volume = self .volume . at [:]. set ( vol )
782794
783795
784796if __name__ == "__main__" :
0 commit comments