Skip to content

Commit e0377da

Browse files
authored
Merge pull request #11 from chahak13/unittests
Merging this as it just adds tests and doesn't affect the main library. Merging this would make merging #14 easier.
2 parents c6291cc + 7a28a38 commit e0377da

7 files changed

Lines changed: 349 additions & 1051 deletions

File tree

diffmpm/element.py

Lines changed: 59 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def id_to_node_loc(self, id: int):
2828
-------
2929
jax.numpy.ndarray
3030
Nodal locations for the element. Shape of returned
31-
array is (nodes_in_element, ndim)
31+
array is (nodes_in_element, 1, ndim)
3232
"""
3333
node_ids = self.id_to_node_ids(id).squeeze()
3434
return self.nodes.loc[node_ids]
@@ -46,7 +46,7 @@ def id_to_node_vel(self, id: int):
4646
-------
4747
jax.numpy.ndarray
4848
Nodal velocities for the element. Shape of returned
49-
array is (nodes_in_element, ndim)
49+
array is (nodes_in_element, 1, ndim)
5050
"""
5151
node_ids = self.id_to_node_ids(id).squeeze()
5252
return self.nodes.velocity[node_ids]
@@ -84,6 +84,10 @@ def shapefn(self):
8484
def shapefn_grad(self):
8585
...
8686

87+
@abc.abstractmethod
88+
def set_particle_element_ids(self):
89+
...
90+
8791
# Mapping from particles to nodes (P2G)
8892
def compute_nodal_mass(self, particles):
8993
r"""
@@ -195,7 +199,7 @@ def _step(pid, args):
195199
)
196200
self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
197201

198-
def compute_body_force(self, particles, gravity: float):
202+
def compute_body_force(self, particles, gravity: float | jnp.ndarray):
199203
r"""
200204
Update the nodal external force based on particle mass.
201205
@@ -235,58 +239,6 @@ def apply_concentrated_nodal_forces(self, particles, curr_time):
235239
factor * cnf.force
236240
)
237241

238-
def compute_internal_force(self, particles):
239-
r"""
240-
Update the nodal internal force based on particle mass.
241-
242-
The nodal force is updated as a sum of internal forces for
243-
all particles mapped to the node.
244-
245-
:math:`(f_{int})_i = -\sum_p V_p * stress_p * \nabla N_i(x_p)`
246-
247-
Arguments
248-
---------
249-
particles: diffmpm.particle.Particles
250-
Particles to map to the nodal values.
251-
"""
252-
253-
def _step(pid, args):
254-
(
255-
f_int,
256-
pvol,
257-
mapped_grads,
258-
el_nodes,
259-
pstress,
260-
) = args
261-
# TODO: correct matrix multiplication for n-d
262-
# update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid]
263-
update = -pvol[pid] * pstress[pid][0] * mapped_grads[pid]
264-
f_int = f_int.at[el_nodes[pid]].add(update[..., jnp.newaxis])
265-
return (
266-
f_int,
267-
pvol,
268-
mapped_grads,
269-
el_nodes,
270-
pstress,
271-
)
272-
273-
self.nodes.f_int = self.nodes.f_int.at[:].set(0)
274-
mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
275-
mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2)
276-
mapped_grads = vmap(self.shapefn_grad)(
277-
particles.reference_loc[:, jnp.newaxis, ...],
278-
mapped_coords,
279-
)
280-
args = (
281-
self.nodes.f_int,
282-
particles.volume,
283-
mapped_grads,
284-
mapped_nodes,
285-
particles.stress,
286-
)
287-
# _step(0, args)
288-
self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
289-
290242
def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
291243
"""Update the nodal momentum based on total force on nodes."""
292244
total_force = self.nodes.get_total_force()
@@ -512,6 +464,57 @@ def compute_volume(self, *args):
512464
vol = jnp.ediff1d(self.nodes.loc)
513465
self.volume = jnp.ones((self.total_elements, 1, 1)) * vol
514466

467+
def compute_internal_force(self, particles):
468+
r"""
469+
Update the nodal internal force based on particle mass.
470+
471+
The nodal force is updated as a sum of internal forces for
472+
all particles mapped to the node.
473+
474+
:math:`(f_{int})_i = -\sum_p V_p * stress_p * \nabla N_i(x_p)`
475+
476+
Arguments
477+
---------
478+
particles: diffmpm.particle.Particles
479+
Particles to map to the nodal values.
480+
"""
481+
482+
def _step(pid, args):
483+
(
484+
f_int,
485+
pvol,
486+
mapped_grads,
487+
el_nodes,
488+
pstress,
489+
) = args
490+
# TODO: correct matrix multiplication for n-d
491+
# update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid]
492+
update = -pvol[pid] * pstress[pid][0] * mapped_grads[pid]
493+
f_int = f_int.at[el_nodes[pid]].add(update[..., jnp.newaxis])
494+
return (
495+
f_int,
496+
pvol,
497+
mapped_grads,
498+
el_nodes,
499+
pstress,
500+
)
501+
502+
self.nodes.f_int = self.nodes.f_int.at[:].set(0)
503+
mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
504+
mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2)
505+
mapped_grads = vmap(self.shapefn_grad)(
506+
particles.reference_loc[:, jnp.newaxis, ...],
507+
mapped_coords,
508+
)
509+
args = (
510+
self.nodes.f_int,
511+
particles.volume,
512+
mapped_grads,
513+
mapped_nodes,
514+
particles.stress,
515+
)
516+
self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
517+
515518

516519
@register_pytree_node_class
517520
class Quadrilateral4Node(_Element):
@@ -560,7 +563,7 @@ def __init__(
560563
self.total_elements = total_elements
561564

562565
if nodes is None:
563-
total_nodes = jnp.product(self.nelements + 1)
566+
total_nodes = jnp.prod(self.nelements + 1)
564567
coords = jnp.asarray(
565568
list(
566569
itertools.product(
@@ -758,8 +761,6 @@ def _step(pid, args):
758761
el_nodes,
759762
pstress,
760763
) = args
761-
# TODO: correct matrix multiplication for n-d
762-
# update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid]
763764
force = jnp.zeros((mapped_grads.shape[1], 1, 2))
764765
force = force.at[:, 0, 0].set(
765766
mapped_grads[pid][:, 0] * pstress[pid][0]

0 commit comments

Comments
 (0)