Skip to content

Commit 5416bc8

Browse files
committed
Implement correct 2d JIT solve!
To correct the 2D JIT solve with correct volume computation, 2 major changes were made: 1. The elements class constructor now takes the total number of elements as an argument along with number of elements in both directions. This is so that `jnp.bincount` can work correctly without having to pass a tracer value to its `length` parameter. 2. Update the `MPMExplicit` flattening/unflattening to include the `velocity_update` parameter to make it work with JIT.
1 parent c76842b commit 5416bc8

6 files changed

Lines changed: 50 additions & 56 deletions

File tree

diffmpm/element.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,14 @@ def id_to_node_vel(self, id: int):
5252
return self.nodes.velocity[node_ids]
5353

5454
def tree_flatten(self):
55-
children = (self.nodes,)
55+
children = (self.nodes, self.volume)
5656
aux_data = (
5757
self.nelements,
58+
self.total_elements,
5859
self.el_len,
5960
self.constraints,
6061
self.concentrated_nodal_forces,
62+
self.initialized,
6163
)
6264
return children, aux_data
6365

@@ -67,8 +69,11 @@ def tree_unflatten(cls, aux_data, children):
6769
aux_data[0],
6870
aux_data[1],
6971
aux_data[2],
72+
aux_data[3],
7073
nodes=children[0],
71-
concentrated_nodal_forces=aux_data[3],
74+
concentrated_nodal_forces=aux_data[4],
75+
initialized=aux_data[5],
76+
volume=children[1],
7277
)
7378

7479
@abc.abstractmethod
@@ -350,7 +355,7 @@ def __init__(
350355
IDs of nodes that are supposed to be fixed (boundary).
351356
"""
352357
self.nelements = nelements
353-
self.ids = jnp.arange(nelements)
358+
self.total_elements = nelements
354359
self.el_len = el_len
355360
if nodes is None:
356361
self.nodes = Nodes(
@@ -495,7 +500,7 @@ def f(x):
495500

496501
def compute_volume(self):
497502
vol = jnp.ediff1d(self.nodes.loc)
498-
self.volume = jnp.ones((len(self.ids), 1, 1)) * vol
503+
self.volume = jnp.ones((self.total_elements, 1, 1)) * vol
499504

500505

501506
@register_pytree_node_class
@@ -523,10 +528,13 @@ class Quadrilateral4Node(_Element):
523528
def __init__(
524529
self,
525530
nelements: Tuple[int, int],
531+
total_elements: int,
526532
el_len: Tuple[float, float],
527533
constraints: List[Tuple[jnp.ndarray, Constraint]],
528534
nodes: Nodes = None,
529535
concentrated_nodal_forces=[],
536+
initialized: bool = None,
537+
volume: jnp.ndarray = None,
530538
):
531539
"""Initialize Quadrilateral4Node.
532540
@@ -539,8 +547,7 @@ def __init__(
539547
"""
540548
self.nelements = jnp.asarray(nelements)
541549
self.el_len = jnp.asarray(el_len)
542-
total_elements = jnp.product(self.nelements)
543-
self.ids = jnp.arange(total_elements)
550+
self.total_elements = total_elements
544551

545552
if nodes is None:
546553
total_nodes = jnp.product(self.nelements + 1)
@@ -561,6 +568,11 @@ def __init__(
561568

562569
self.constraints = constraints
563570
self.concentrated_nodal_forces = concentrated_nodal_forces
571+
if initialized is None:
572+
self.volume = jnp.ones((self.total_elements, 1, 1))
573+
else:
574+
self.volume = volume
575+
self.initialized = True
564576

565577
def id_to_node_ids(self, id: int):
566578
"""
@@ -773,12 +785,12 @@ def _step(pid, args):
773785
)
774786
self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
775787

776-
def compute_volume(self):
788+
def compute_volume(self, *args):
777789
a = c = self.el_len[1]
778790
b = d = self.el_len[0]
779791
p = q = jnp.sqrt(a**2 + b**2)
780792
vol = 0.25 * jnp.sqrt(4 * p * p * q * q - (a * a + c * c - b * b - d * d) ** 2)
781-
self.volume = jnp.ones((len(self.ids), 1, 1)) * vol
793+
self.volume = self.volume.at[:].set(vol)
782794

783795

784796
if __name__ == "__main__":

diffmpm/io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def _parse_mesh(self, config):
117117
if config["mesh"]["type"] == "generator":
118118
elements = element_cls(
119119
config["mesh"]["nelements"],
120+
jnp.product(jnp.array(config["mesh"]["nelements"])),
120121
config["mesh"]["element_length"],
121122
constraints,
122123
concentrated_nodal_forces=self.parsed_config["external_loading"][

diffmpm/particle.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Tuple
22

33
import jax.numpy as jnp
4+
from functools import partial
45
from jax import jit, vmap, lax
56
import jax.debug as db
67
from jax.tree_util import register_pytree_node_class
@@ -144,9 +145,10 @@ def set_mass_volume(self, m: float | jnp.ndarray):
144145
)
145146
self.volume = jnp.divide(self.mass, self.material.properties["density"])
146147

147-
def compute_volume(self, elements: _Element):
148-
elements.compute_volume()
149-
particles_per_element = jnp.bincount(self.element_ids, length=len(elements.ids))
148+
def compute_volume(self, elements, total_elements):
149+
particles_per_element = jnp.bincount(
150+
self.element_ids, length=elements.total_elements
151+
)
150152
vol = (
151153
elements.volume.squeeze((1, 2))[self.element_ids]
152154
/ particles_per_element[self.element_ids]

diffmpm/solver.py

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, filepath):
2323
raise ValueError("Wrong type of solver specified.")
2424

2525
def solve(self):
26-
res = self.solver.solve(
26+
res = self.solver.solve_jit(
2727
self._config.parsed_config["meta"]["nsteps"],
2828
self._config.parsed_config["external_loading"]["gravity"],
2929
)
@@ -42,22 +42,27 @@ def __init__(self, mesh, dt, scheme="usf", velocity_update=False):
4242
self.mesh = mesh
4343
self.dt = dt
4444
self.scheme = scheme
45+
self.velocity_update = velocity_update
4546
self.mesh.apply_on_elements("set_particle_element_ids")
46-
self.mesh.apply_on_particles("compute_volume")
47+
self.mesh.apply_on_elements("compute_volume")
48+
self.mesh.apply_on_particles(
49+
"compute_volume", args=(self.mesh.elements.total_elements,)
50+
)
4751

4852
def tree_flatten(self):
4953
children = (self.mesh,)
50-
aux_data = (self.dt, self.scheme)
54+
aux_data = (self.dt, self.scheme, self.velocity_update)
5155
return children, aux_data
5256

5357
@classmethod
5458
def tree_unflatten(cls, aux_data, children):
55-
return cls(*children, aux_data[0], scheme=aux_data[1])
59+
return cls(
60+
*children, aux_data[0], scheme=aux_data[1], velocity_update=aux_data[2]
61+
)
5662

5763
def solve(self, nsteps: int, gravity: float | jnp.ndarray):
5864
result = defaultdict(list)
5965
for step in tqdm(range(nsteps)):
60-
# breakpoint()
6166
self.mpm_scheme.compute_nodal_kinematics()
6267
self.mpm_scheme.precompute_stress_strain()
6368
self.mpm_scheme.compute_forces(gravity, step)
@@ -75,21 +80,17 @@ def solve(self, nsteps: int, gravity: float | jnp.ndarray):
7580
def solve_jit(self, nsteps: int, gravity: float | jnp.ndarray):
7681
nparticles = sum(pset.loc.shape[0] for pset in self.mesh.particles)
7782
result = {
78-
"position": jnp.zeros((nsteps, nparticles)),
79-
"velocity": jnp.zeros((nsteps, nparticles)),
80-
"strain_energy": jnp.zeros((nsteps, nparticles)),
81-
"kinetic_energy": jnp.zeros((nsteps, nparticles)),
82-
"total_energy": jnp.zeros((nsteps, nparticles)),
83-
"stress": jnp.zeros((nsteps, nparticles)),
84-
"strain": jnp.zeros((nsteps, nparticles)),
83+
"position": jnp.zeros((nsteps, nparticles, 2)),
84+
"velocity": jnp.zeros((nsteps, nparticles, 2)),
85+
"stress": jnp.zeros((nsteps, nparticles, 6)),
86+
"strain": jnp.zeros((nsteps, nparticles, 6)),
8587
}
8688

8789
def _step(i, data):
8890
self, result = data
8991
self.mpm_scheme.compute_nodal_kinematics()
9092
self.mpm_scheme.precompute_stress_strain()
91-
self.mpm_scheme.compute_forces(gravity)
92-
# self.mpm_scheme.update_nodal_momentum()
93+
self.mpm_scheme.compute_forces(gravity, i)
9394
self.mpm_scheme.compute_particle_kinematics()
9495
self.mpm_scheme.postcompute_stress_strain()
9596

@@ -99,45 +100,23 @@ def _step(i, data):
99100
idu += len(self.mesh.particles[j])
100101
result["position"] = (
101102
result["position"]
102-
.at[i, idl:idu]
103+
.at[i, idl:idu, :]
103104
.set(self.mesh.particles[j].loc.squeeze())
104105
)
105106
result["velocity"] = (
106107
result["velocity"]
107-
.at[i, idl:idu]
108+
.at[i, idl:idu, :]
108109
.set(self.mesh.particles[j].velocity.squeeze())
109110
)
110111
result["stress"] = (
111112
result["stress"]
112-
.at[i, idl:idu]
113-
.set(self.mesh.particles[j].stress[:, 0, :].squeeze())
113+
.at[i, idl:idu, :]
114+
.set(self.mesh.particles[j].stress[:, :, 0].squeeze())
114115
)
115116
result["strain"] = (
116117
result["strain"]
117-
.at[i, idl:idu]
118-
.set(self.mesh.particles[j].strain[:, 0, :].squeeze())
119-
)
120-
strain_energy = (
121-
0.5
122-
* self.mesh.particles[j].stress[:, 0, :].squeeze()
123-
* self.mesh.particles[j].strain[:, 0, :].squeeze()
124-
* self.mesh.particles[j].volume.squeeze()
125-
)
126-
kinetic_energy = (
127-
0.5
128-
* self.mesh.particles[j].velocity.squeeze() ** 2
129-
* self.mesh.particles[j].mass.squeeze()
130-
)
131-
result["strain_energy"] = (
132-
result["strain_energy"].at[i, idl:idu].set(strain_energy)
133-
)
134-
result["kinetic_energy"] = (
135-
result["kinetic_energy"].at[i, idl:idu].set(kinetic_energy)
136-
)
137-
result["total_energy"] = (
138-
result["total_energy"]
139-
.at[i, idl:idu]
140-
.set(strain_energy + kinetic_energy)
118+
.at[i, idl:idu, :]
119+
.set(self.mesh.particles[j].strain[:, :, 0].squeeze())
141120
)
142121
return (self, result)
143122

examples/mpm-nodal-forces.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type = "MPMExplicit"
1616
dimension = 2
1717
scheme = "usf"
1818
dt = 0.001
19-
nsteps = 20
19+
nsteps = 301
2020
velocity_update = true
2121

2222
[output]
@@ -59,7 +59,7 @@ gravity = [0, 0]
5959
node_ids = [3, 7]
6060
math_function_id = 0
6161
dir = 0
62-
force = 0.01
62+
force = 0.05
6363

6464
[[external_loading.particle_surface_traction]]
6565
pset = [1]

examples/simple_2d_file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
mpm = MPM(sys.argv[1])
77
result = mpm.solve()
88
# breakpoint()
9-
print(result["stress"][-1])
9+
print(result["stress"][-1][:, :2])
1010
exit()
1111

1212

0 commit comments

Comments
 (0)