@@ -28,7 +28,7 @@ def id_to_node_loc(self, id: int):
2828 -------
2929 jax.numpy.ndarray
3030 Nodal locations for the element. Shape of returned
31- array is (nodes_in_element, ndim)
31+ array is (nodes_in_element, 1, ndim)
3232 """
3333 node_ids = self .id_to_node_ids (id ).squeeze ()
3434 return self .nodes .loc [node_ids ]
@@ -46,7 +46,7 @@ def id_to_node_vel(self, id: int):
4646 -------
4747 jax.numpy.ndarray
4848 Nodal velocities for the element. Shape of returned
49- array is (nodes_in_element, ndim)
49+ array is (nodes_in_element, 1, ndim)
5050 """
5151 node_ids = self .id_to_node_ids (id ).squeeze ()
5252 return self .nodes .velocity [node_ids ]
@@ -84,6 +84,10 @@ def shapefn(self):
8484 def shapefn_grad (self ):
8585 ...
8686
87+ @abc .abstractmethod
88+ def set_particle_element_ids (self ):
89+ ...
90+
8791 # Mapping from particles to nodes (P2G)
8892 def compute_nodal_mass (self , particles ):
8993 r"""
@@ -195,7 +199,7 @@ def _step(pid, args):
195199 )
196200 self .nodes .f_ext , _ , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
197201
198- def compute_body_force (self , particles , gravity : float ):
202+ def compute_body_force (self , particles , gravity : float | jnp . ndarray ):
199203 r"""
200204 Update the nodal external force based on particle mass.
201205
@@ -235,58 +239,6 @@ def apply_concentrated_nodal_forces(self, particles, curr_time):
235239 factor * cnf .force
236240 )
237241
238- def compute_internal_force (self , particles ):
239- r"""
240- Update the nodal internal force based on particle mass.
241-
242- The nodal force is updated as a sum of internal forces for
243- all particles mapped to the node.
244-
245- :math:`(f_{int})_i = -\sum_p V_p * stress_p * \nabla N_i(x_p)`
246-
247- Arguments
248- ---------
249- particles: diffmpm.particle.Particles
250- Particles to map to the nodal values.
251- """
252-
253- def _step (pid , args ):
254- (
255- f_int ,
256- pvol ,
257- mapped_grads ,
258- el_nodes ,
259- pstress ,
260- ) = args
261- # TODO: correct matrix multiplication for n-d
262- # update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid]
263- update = - pvol [pid ] * pstress [pid ][0 ] * mapped_grads [pid ]
264- f_int = f_int .at [el_nodes [pid ]].add (update [..., jnp .newaxis ])
265- return (
266- f_int ,
267- pvol ,
268- mapped_grads ,
269- el_nodes ,
270- pstress ,
271- )
272-
273- self .nodes .f_int = self .nodes .f_int .at [:].set (0 )
274- mapped_nodes = vmap (self .id_to_node_ids )(particles .element_ids ).squeeze (- 1 )
275- mapped_coords = vmap (self .id_to_node_loc )(particles .element_ids ).squeeze (2 )
276- mapped_grads = vmap (self .shapefn_grad )(
277- particles .reference_loc [:, jnp .newaxis , ...],
278- mapped_coords ,
279- )
280- args = (
281- self .nodes .f_int ,
282- particles .volume ,
283- mapped_grads ,
284- mapped_nodes ,
285- particles .stress ,
286- )
287- # _step(0, args)
288- self .nodes .f_int , _ , _ , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
289-
290242 def update_nodal_acceleration_velocity (self , particles , dt : float , * args ):
291243 """Update the nodal momentum based on total force on nodes."""
292244 total_force = self .nodes .get_total_force ()
@@ -512,6 +464,57 @@ def compute_volume(self, *args):
512464 vol = jnp .ediff1d (self .nodes .loc )
513465 self .volume = jnp .ones ((self .total_elements , 1 , 1 )) * vol
514466
467+ def compute_internal_force (self , particles ):
468+ r"""
469+ Update the nodal internal force based on particle mass.
470+
471+ The nodal force is updated as a sum of internal forces for
472+ all particles mapped to the node.
473+
474+ :math:`(f_{int})_i = -\sum_p V_p * stress_p * \nabla N_i(x_p)`
475+
476+ Arguments
477+ ---------
478+ particles: diffmpm.particle.Particles
479+ Particles to map to the nodal values.
480+ """
481+
482+ def _step (pid , args ):
483+ (
484+ f_int ,
485+ pvol ,
486+ mapped_grads ,
487+ el_nodes ,
488+ pstress ,
489+ ) = args
490+ # TODO: correct matrix multiplication for n-d
491+ # update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid]
492+ update = - pvol [pid ] * pstress [pid ][0 ] * mapped_grads [pid ]
493+ f_int = f_int .at [el_nodes [pid ]].add (update [..., jnp .newaxis ])
494+ return (
495+ f_int ,
496+ pvol ,
497+ mapped_grads ,
498+ el_nodes ,
499+ pstress ,
500+ )
501+
502+ self .nodes .f_int = self .nodes .f_int .at [:].set (0 )
503+ mapped_nodes = vmap (self .id_to_node_ids )(particles .element_ids ).squeeze (- 1 )
504+ mapped_coords = vmap (self .id_to_node_loc )(particles .element_ids ).squeeze (2 )
505+ mapped_grads = vmap (self .shapefn_grad )(
506+ particles .reference_loc [:, jnp .newaxis , ...],
507+ mapped_coords ,
508+ )
509+ args = (
510+ self .nodes .f_int ,
511+ particles .volume ,
512+ mapped_grads ,
513+ mapped_nodes ,
514+ particles .stress ,
515+ )
516+ self .nodes .f_int , _ , _ , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
517+
515518
516519@register_pytree_node_class
517520class Quadrilateral4Node (_Element ):
@@ -560,7 +563,7 @@ def __init__(
560563 self .total_elements = total_elements
561564
562565 if nodes is None :
563- total_nodes = jnp .product (self .nelements + 1 )
566+ total_nodes = jnp .prod (self .nelements + 1 )
564567 coords = jnp .asarray (
565568 list (
566569 itertools .product (
@@ -758,8 +761,6 @@ def _step(pid, args):
758761 el_nodes ,
759762 pstress ,
760763 ) = args
761- # TODO: correct matrix multiplication for n-d
762- # update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid]
763764 force = jnp .zeros ((mapped_grads .shape [1 ], 1 , 2 ))
764765 force = force .at [:, 0 , 0 ].set (
765766 mapped_grads [pid ][:, 0 ] * pstress [pid ][0 ]
0 commit comments