Skip to content

Commit 54ab811

Browse files
committed
Swap locations of two functions
`set_particle_element_ids` is moved to element.py since the element id mapping is based on what type of element the particles lie in. On the other hand, `update_natural_coords` was moved to particle.py since it is an intrinsic property of the particles and is just updated based on the elements passed as argument.
1 parent 153a5b6 commit 54ab811

4 files changed

Lines changed: 49 additions & 54 deletions

File tree

diffmpm/element.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -233,33 +233,35 @@ def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray):
233233
result = grad_sf.T @ jnp.linalg.inv(_jacobian)
234234
return result.T
235235

236-
# TODO: See if this is generic enough to be moved to
237-
# `Particles` instead.
238-
def update_particle_natural_coords(self, particles):
236+
def set_particle_element_ids(self, particles):
239237
"""
240-
Update natural coordinates for the particles.
238+
Set the element IDs for the particles.
241239
242-
Whenever the particles' physical coordinates change, their
243-
natural coordinates need to be updated. This function updates
244-
the natural coordinates of the particles based on the element
245-
a particle is a part of. The update formula is
246-
247-
:math:`xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e)`
240+
If the particle doesn't lie between the boundaries of any
241+
element, it sets the element index to -1.
242+
"""
248243

249-
If a particle is not in any element (element_id = -1), its
250-
natural coordinate is set to 0.
244+
@jit
245+
def f(x):
246+
idl = (
247+
len(self.nodes.loc)
248+
- 1
249+
- jnp.asarray(self.nodes.loc[::-1] <= x).nonzero(
250+
size=1, fill_value=-1
251+
)[0][-1]
252+
)
253+
idg = (
254+
jnp.asarray(self.nodes.loc > x).nonzero(size=1, fill_value=-1)[
255+
0
256+
][0]
257+
- 1
258+
)
259+
return (idl, idg)
251260

252-
Arguments
253-
---------
254-
particles: diffmpm.particle.Particles
255-
Particles whose natural coordinates need to be updated based
256-
on these elements.
257-
"""
258-
t = self.id_to_node_loc(particles.element_ids)
259-
xi_coords = (particles.loc - (t[:, 0, ...] + t[:, 1, ...]) / 2) * (
260-
2 / (t[:, 1, ...] - t[:, 0, ...])
261+
ids = vmap(f)(particles.loc)
262+
particles.element_ids = jnp.where(
263+
ids[0] == ids[1], ids[0], jnp.ones_like(ids[0]) * -1
261264
)
262-
particles.reference_loc = xi_coords
263265

264266
# Mapping from particles to nodes (P2G)
265267
def compute_nodal_mass(self, particles):

diffmpm/particle.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -143,38 +143,31 @@ def set_mass_volume(self, m: float | jnp.ndarray):
143143
)
144144
self.volume = jnp.divide(self.mass, self.material.properties["density"])
145145

146-
# TODO: This needs to be in element so that each type of element
147-
# can govern how ot check if particle is in an element.
148-
# This implementation is true only for 1D.
149-
def set_particle_element_ids(self, elements: _Element):
146+
def update_natural_coords(self, elements: _Element):
150147
"""
151-
Set the element IDs for the particles.
148+
Update natural coordinates for the particles.
152149
153-
If the particle doesn't lie between the boundaries of any
154-
element, it sets the element index to -1.
155-
"""
150+
Whenever the particles' physical coordinates change, their
151+
natural coordinates need to be updated. This function updates
152+
the natural coordinates of the particles based on the element
153+
a particle is a part of. The update formula is
156154
157-
@jit
158-
def f(x):
159-
idl = (
160-
len(elements.nodes.loc)
161-
- 1
162-
- jnp.asarray(elements.nodes.loc[::-1] <= x).nonzero(
163-
size=1, fill_value=-1
164-
)[0][-1]
165-
)
166-
idg = (
167-
jnp.asarray(elements.nodes.loc > x).nonzero(
168-
size=1, fill_value=-1
169-
)[0][0]
170-
- 1
171-
)
172-
return (idl, idg)
155+
:math:`xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e)`
156+
157+
If a particle is not in any element (element_id = -1), its
158+
natural coordinate is set to 0.
173159
174-
ids = vmap(f)(self.loc)
175-
self.element_ids = jnp.where(
176-
ids[0] == ids[1], ids[0], jnp.ones_like(ids[0]) * -1
160+
Arguments
161+
---------
162+
elements: diffmpm.element._Element
163+
Elements based on which to update the natural coordinates
164+
of the particles.
165+
"""
166+
t = elements.id_to_node_loc(self.element_ids)
167+
xi_coords = (self.loc - (t[:, 0, ...] + t[:, 1, ...]) / 2) * (
168+
2 / (t[:, 1, ...] - t[:, 0, ...])
177169
)
170+
self.reference_loc = xi_coords
178171

179172
def update_position_velocity(self, elements: _Element, dt: float):
180173
"""

diffmpm/scheme.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ def __init__(self, mesh, dt):
99
self.dt = dt
1010

1111
def compute_nodal_kinematics(self):
12-
self.mesh.apply_on_elements("update_particle_natural_coords")
13-
self.mesh.apply_on_particles("set_particle_element_ids")
12+
self.mesh.apply_on_particles("update_natural_coords")
13+
self.mesh.apply_on_elements("set_particle_element_ids")
1414
self.mesh.apply_on_elements("compute_nodal_mass")
1515
self.mesh.apply_on_elements("compute_nodal_momentum")
1616
self.mesh.apply_on_elements("apply_boundary_constraints")

examples/optim_1d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import optax
44
from diffmpm.element import Linear1D
55
from diffmpm.material import SimpleMaterial
6-
from diffmpm.mesh import _MeshBase
6+
from diffmpm.mesh import Mesh1D
77
from diffmpm.particle import Particles
88
from diffmpm.solver import MPMExplicit
99
from jax import value_and_grad, grad, jit
@@ -20,7 +20,7 @@
2020
particles.set_mass_volume(1.0)
2121
dt = 0.01
2222
nsteps = 1000
23-
mesh = _MeshBase({"particles": [particles], "elements": elements})
23+
mesh = Mesh1D({"particles": [particles], "elements": elements})
2424

2525
mpm = MPMExplicit(mesh, dt, scheme="usl")
2626
true_result = mpm.solve_jit(nsteps, 0)
@@ -68,7 +68,7 @@ def optax_adam(params, niter, mpm, target_vel):
6868
)
6969
particles.velocity += 0.1
7070
particles.set_mass_volume(1.0)
71-
mesh = _MeshBase({"particles": [particles], "elements": elements})
71+
mesh = Mesh1D({"particles": [particles], "elements": elements})
7272

7373
mpm = MPMExplicit(mesh, dt, scheme="usl")
7474
param_list, loss_list = optax_adam(

0 commit comments

Comments
 (0)