Skip to content

Commit 68e52e9

Browse files
authored
Merge pull request #6 from chahak13/2d
Implement correct 2d JIT solve!
2 parents 4c3f7ed + 9b1e77a commit 68e52e9

7 files changed

Lines changed: 97 additions & 97 deletions

File tree

diffmpm/element.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,14 @@ def id_to_node_vel(self, id: int):
5252
return self.nodes.velocity[node_ids]
5353

5454
def tree_flatten(self):
55-
children = (self.nodes,)
55+
children = (self.nodes, self.volume)
5656
aux_data = (
5757
self.nelements,
58+
self.total_elements,
5859
self.el_len,
5960
self.constraints,
6061
self.concentrated_nodal_forces,
62+
self.initialized,
6163
)
6264
return children, aux_data
6365

@@ -67,8 +69,11 @@ def tree_unflatten(cls, aux_data, children):
6769
aux_data[0],
6870
aux_data[1],
6971
aux_data[2],
72+
aux_data[3],
7073
nodes=children[0],
71-
concentrated_nodal_forces=aux_data[3],
74+
concentrated_nodal_forces=aux_data[4],
75+
initialized=aux_data[5],
76+
volume=children[1],
7277
)
7378

7479
@abc.abstractmethod
@@ -142,7 +147,7 @@ def _step(pid, args):
142147
)
143148
_, self.nodes.momentum, _, _ = lax.fori_loop(0, len(particles), _step, args)
144149
self.nodes.momentum = jnp.where(
145-
self.nodes.momentum < 1e-12,
150+
jnp.abs(self.nodes.momentum) < 1e-12,
146151
jnp.zeros_like(self.nodes.momentum),
147152
self.nodes.momentum,
148153
)
@@ -154,7 +159,7 @@ def compute_velocity(self, particles):
154159
self.nodes.momentum / self.nodes.mass,
155160
)
156161
self.nodes.velocity = jnp.where(
157-
self.nodes.velocity < 1e-12,
162+
jnp.abs(self.nodes.velocity) < 1e-12,
158163
jnp.zeros_like(self.nodes.velocity),
159164
self.nodes.velocity,
160165
)
@@ -296,12 +301,12 @@ def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
296301
self.nodes.mass * self.nodes.velocity
297302
)
298303
self.nodes.velocity = jnp.where(
299-
self.nodes.velocity < 1e-12,
304+
jnp.abs(self.nodes.velocity) < 1e-12,
300305
jnp.zeros_like(self.nodes.velocity),
301306
self.nodes.velocity,
302307
)
303308
self.nodes.acceleration = jnp.where(
304-
self.nodes.acceleration < 1e-12,
309+
jnp.abs(self.nodes.acceleration) < 1e-12,
305310
jnp.zeros_like(self.nodes.acceleration),
306311
self.nodes.acceleration,
307312
)
@@ -350,7 +355,7 @@ def __init__(
350355
IDs of nodes that are supposed to be fixed (boundary).
351356
"""
352357
self.nelements = nelements
353-
self.ids = jnp.arange(nelements)
358+
self.total_elements = nelements
354359
self.el_len = el_len
355360
if nodes is None:
356361
self.nodes = Nodes(
@@ -495,7 +500,7 @@ def f(x):
495500

496501
def compute_volume(self):
497502
vol = jnp.ediff1d(self.nodes.loc)
498-
self.volume = jnp.ones((len(self.ids), 1, 1)) * vol
503+
self.volume = jnp.ones((self.total_elements, 1, 1)) * vol
499504

500505

501506
@register_pytree_node_class
@@ -523,10 +528,13 @@ class Quadrilateral4Node(_Element):
523528
def __init__(
524529
self,
525530
nelements: Tuple[int, int],
531+
total_elements: int,
526532
el_len: Tuple[float, float],
527533
constraints: List[Tuple[jnp.ndarray, Constraint]],
528534
nodes: Nodes = None,
529535
concentrated_nodal_forces=[],
536+
initialized: bool = None,
537+
volume: jnp.ndarray = None,
530538
):
531539
"""Initialize Quadrilateral4Node.
532540
@@ -539,8 +547,7 @@ def __init__(
539547
"""
540548
self.nelements = jnp.asarray(nelements)
541549
self.el_len = jnp.asarray(el_len)
542-
total_elements = jnp.product(self.nelements)
543-
self.ids = jnp.arange(total_elements)
550+
self.total_elements = total_elements
544551

545552
if nodes is None:
546553
total_nodes = jnp.product(self.nelements + 1)
@@ -561,6 +568,11 @@ def __init__(
561568

562569
self.constraints = constraints
563570
self.concentrated_nodal_forces = concentrated_nodal_forces
571+
if initialized is None:
572+
self.volume = jnp.ones((self.total_elements, 1, 1))
573+
else:
574+
self.volume = volume
575+
self.initialized = True
564576

565577
def id_to_node_ids(self, id: int):
566578
"""
@@ -773,12 +785,12 @@ def _step(pid, args):
773785
)
774786
self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
775787

776-
def compute_volume(self):
788+
def compute_volume(self, *args):
777789
a = c = self.el_len[1]
778790
b = d = self.el_len[0]
779791
p = q = jnp.sqrt(a**2 + b**2)
780792
vol = 0.25 * jnp.sqrt(4 * p * p * q * q - (a * a + c * c - b * b - d * d) ** 2)
781-
self.volume = jnp.ones((len(self.ids), 1, 1)) * vol
793+
self.volume = self.volume.at[:].set(vol)
782794

783795

784796
if __name__ == "__main__":

diffmpm/io.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ def parse(self):
2727
self._parse_output(self._fileconfig)
2828
self._parse_materials(self._fileconfig)
2929
self._parse_particles(self._fileconfig)
30-
self._parse_math_functions(self._fileconfig)
31-
self._parse_external_loading(self._fileconfig)
30+
if "math_functions" in self._fileconfig:
31+
self._parse_math_functions(self._fileconfig)
32+
if "external_loading" in self._fileconfig:
33+
self._parse_external_loading(self._fileconfig)
3234
mesh = self._parse_mesh(self._fileconfig)
3335
return mesh
3436

@@ -78,33 +80,39 @@ def _parse_math_functions(self, config):
7880
def _parse_external_loading(self, config):
7981
external_loading = {}
8082
external_loading["gravity"] = jnp.array(config["external_loading"]["gravity"])
81-
cnf_list = []
82-
for cnfconfig in config["external_loading"]["concentrated_nodal_forces"]:
83-
if "math_function_id" in cnfconfig:
84-
fn = self.parsed_config["math_functions"][cnfconfig["math_function_id"]]
85-
else:
86-
fn = Unit(-1)
87-
cnf = NodalForce(
88-
node_ids=jnp.array(cnfconfig["node_ids"]),
89-
function=fn,
90-
dir=cnfconfig["dir"],
91-
force=cnfconfig["force"],
92-
)
93-
cnf_list.append(cnf)
83+
external_loading["concentrated_nodal_forces"] = []
84+
external_loading["particle_surface_traction"] = []
85+
if "concentrated_nodal_forces" in config["external_loading"]:
86+
cnf_list = []
87+
for cnfconfig in config["external_loading"]["concentrated_nodal_forces"]:
88+
if "math_function_id" in cnfconfig:
89+
fn = self.parsed_config["math_functions"][
90+
cnfconfig["math_function_id"]
91+
]
92+
else:
93+
fn = Unit(-1)
94+
cnf = NodalForce(
95+
node_ids=jnp.array(cnfconfig["node_ids"]),
96+
function=fn,
97+
dir=cnfconfig["dir"],
98+
force=cnfconfig["force"],
99+
)
100+
cnf_list.append(cnf)
101+
external_loading["concentrated_nodal_forces"] = cnf_list
94102

95-
pst_list = []
96-
for pstconfig in config["external_loading"]["particle_surface_traction"]:
97-
pst = ParticleTraction(
98-
pset=jnp.array(pstconfig["pset"]),
99-
function=self.parsed_config["math_functions"][
100-
pstconfig["math_function_id"]
101-
],
102-
dir=pstconfig["dir"],
103-
traction=pstconfig["traction"],
104-
)
105-
pst_list.append(pst)
106-
external_loading["concentrated_nodal_forces"] = cnf_list
107-
external_loading["particle_surface_traction"] = pst_list
103+
if "particle_surface_traction" in config["external_loading"]:
104+
pst_list = []
105+
for pstconfig in config["external_loading"]["particle_surface_traction"]:
106+
pst = ParticleTraction(
107+
pset=jnp.array(pstconfig["pset"]),
108+
function=self.parsed_config["math_functions"][
109+
pstconfig["math_function_id"]
110+
],
111+
dir=pstconfig["dir"],
112+
traction=pstconfig["traction"],
113+
)
114+
pst_list.append(pst)
115+
external_loading["particle_surface_traction"] = pst_list
108116
self.parsed_config["external_loading"] = external_loading
109117

110118
def _parse_mesh(self, config):
@@ -117,6 +125,7 @@ def _parse_mesh(self, config):
117125
if config["mesh"]["type"] == "generator":
118126
elements = element_cls(
119127
config["mesh"]["nelements"],
128+
jnp.product(jnp.array(config["mesh"]["nelements"])),
120129
config["mesh"]["element_length"],
121130
constraints,
122131
concentrated_nodal_forces=self.parsed_config["external_loading"][

diffmpm/particle.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import Tuple
22

33
import jax.numpy as jnp
4-
from jax import jit, vmap, lax
5-
import jax.debug as db
4+
from jax import vmap, lax
65
from jax.tree_util import register_pytree_node_class
76

87
from diffmpm.element import _Element
@@ -144,9 +143,10 @@ def set_mass_volume(self, m: float | jnp.ndarray):
144143
)
145144
self.volume = jnp.divide(self.mass, self.material.properties["density"])
146145

147-
def compute_volume(self, elements: _Element):
148-
elements.compute_volume()
149-
particles_per_element = jnp.bincount(self.element_ids, length=len(elements.ids))
146+
def compute_volume(self, elements, total_elements):
147+
particles_per_element = jnp.bincount(
148+
self.element_ids, length=elements.total_elements
149+
)
150150
vol = (
151151
elements.volume.squeeze((1, 2))[self.element_ids]
152152
/ particles_per_element[self.element_ids]
@@ -243,9 +243,7 @@ def compute_strain(self, elements: _Element, dt: float):
243243
self.strain_rate = self._compute_strain_rate(dn_dx_, elements)
244244
self.dstrain = self.dstrain.at[:].set(self.strain_rate * dt)
245245

246-
# db.print(f"compute_strain() - dstrain: {self.dstrain.squeeze()[3, :2]}")
247246
self.strain = self.strain.at[:].add(self.dstrain)
248-
# db.print(f"compute_strain() - strain: {self.strain.squeeze()[3, :2]}")
249247
centroids = jnp.zeros_like(self.loc)
250248
dn_dx_centroid_ = vmap(elements.shapefn_grad)(
251249
centroids[:, jnp.newaxis, ...], mapped_coords
@@ -296,7 +294,7 @@ def _step(pid, args):
296294
args = (dn_dx, temp, strain_rate)
297295
_, _, strain_rate = lax.fori_loop(0, self.loc.shape[0], _step, args)
298296
strain_rate = jnp.where(
299-
strain_rate < 1e-12, jnp.zeros_like(strain_rate), strain_rate
297+
jnp.abs(strain_rate) < 1e-12, jnp.zeros_like(strain_rate), strain_rate
300298
)
301299
return strain_rate
302300

diffmpm/solver.py

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, filepath):
2323
raise ValueError("Wrong type of solver specified.")
2424

2525
def solve(self):
26-
res = self.solver.solve(
26+
res = self.solver.solve_jit(
2727
self._config.parsed_config["meta"]["nsteps"],
2828
self._config.parsed_config["external_loading"]["gravity"],
2929
)
@@ -42,22 +42,27 @@ def __init__(self, mesh, dt, scheme="usf", velocity_update=False):
4242
self.mesh = mesh
4343
self.dt = dt
4444
self.scheme = scheme
45+
self.velocity_update = velocity_update
4546
self.mesh.apply_on_elements("set_particle_element_ids")
46-
self.mesh.apply_on_particles("compute_volume")
47+
self.mesh.apply_on_elements("compute_volume")
48+
self.mesh.apply_on_particles(
49+
"compute_volume", args=(self.mesh.elements.total_elements,)
50+
)
4751

4852
def tree_flatten(self):
4953
children = (self.mesh,)
50-
aux_data = (self.dt, self.scheme)
54+
aux_data = (self.dt, self.scheme, self.velocity_update)
5155
return children, aux_data
5256

5357
@classmethod
5458
def tree_unflatten(cls, aux_data, children):
55-
return cls(*children, aux_data[0], scheme=aux_data[1])
59+
return cls(
60+
*children, aux_data[0], scheme=aux_data[1], velocity_update=aux_data[2]
61+
)
5662

5763
def solve(self, nsteps: int, gravity: float | jnp.ndarray):
5864
result = defaultdict(list)
5965
for step in tqdm(range(nsteps)):
60-
# breakpoint()
6166
self.mpm_scheme.compute_nodal_kinematics()
6267
self.mpm_scheme.precompute_stress_strain()
6368
self.mpm_scheme.compute_forces(gravity, step)
@@ -75,21 +80,17 @@ def solve(self, nsteps: int, gravity: float | jnp.ndarray):
7580
def solve_jit(self, nsteps: int, gravity: float | jnp.ndarray):
7681
nparticles = sum(pset.loc.shape[0] for pset in self.mesh.particles)
7782
result = {
78-
"position": jnp.zeros((nsteps, nparticles)),
79-
"velocity": jnp.zeros((nsteps, nparticles)),
80-
"strain_energy": jnp.zeros((nsteps, nparticles)),
81-
"kinetic_energy": jnp.zeros((nsteps, nparticles)),
82-
"total_energy": jnp.zeros((nsteps, nparticles)),
83-
"stress": jnp.zeros((nsteps, nparticles)),
84-
"strain": jnp.zeros((nsteps, nparticles)),
83+
"position": jnp.zeros((nsteps, nparticles, 2)),
84+
"velocity": jnp.zeros((nsteps, nparticles, 2)),
85+
"stress": jnp.zeros((nsteps, nparticles, 6)),
86+
"strain": jnp.zeros((nsteps, nparticles, 6)),
8587
}
8688

8789
def _step(i, data):
8890
self, result = data
8991
self.mpm_scheme.compute_nodal_kinematics()
9092
self.mpm_scheme.precompute_stress_strain()
91-
self.mpm_scheme.compute_forces(gravity)
92-
# self.mpm_scheme.update_nodal_momentum()
93+
self.mpm_scheme.compute_forces(gravity, i)
9394
self.mpm_scheme.compute_particle_kinematics()
9495
self.mpm_scheme.postcompute_stress_strain()
9596

@@ -99,45 +100,23 @@ def _step(i, data):
99100
idu += len(self.mesh.particles[j])
100101
result["position"] = (
101102
result["position"]
102-
.at[i, idl:idu]
103+
.at[i, idl:idu, :]
103104
.set(self.mesh.particles[j].loc.squeeze())
104105
)
105106
result["velocity"] = (
106107
result["velocity"]
107-
.at[i, idl:idu]
108+
.at[i, idl:idu, :]
108109
.set(self.mesh.particles[j].velocity.squeeze())
109110
)
110111
result["stress"] = (
111112
result["stress"]
112-
.at[i, idl:idu]
113-
.set(self.mesh.particles[j].stress[:, 0, :].squeeze())
113+
.at[i, idl:idu, :]
114+
.set(self.mesh.particles[j].stress[:, :, 0].squeeze())
114115
)
115116
result["strain"] = (
116117
result["strain"]
117-
.at[i, idl:idu]
118-
.set(self.mesh.particles[j].strain[:, 0, :].squeeze())
119-
)
120-
strain_energy = (
121-
0.5
122-
* self.mesh.particles[j].stress[:, 0, :].squeeze()
123-
* self.mesh.particles[j].strain[:, 0, :].squeeze()
124-
* self.mesh.particles[j].volume.squeeze()
125-
)
126-
kinetic_energy = (
127-
0.5
128-
* self.mesh.particles[j].velocity.squeeze() ** 2
129-
* self.mesh.particles[j].mass.squeeze()
130-
)
131-
result["strain_energy"] = (
132-
result["strain_energy"].at[i, idl:idu].set(strain_energy)
133-
)
134-
result["kinetic_energy"] = (
135-
result["kinetic_energy"].at[i, idl:idu].set(kinetic_energy)
136-
)
137-
result["total_energy"] = (
138-
result["total_energy"]
139-
.at[i, idl:idu]
140-
.set(strain_energy + kinetic_energy)
118+
.at[i, idl:idu, :]
119+
.set(self.mesh.particles[j].strain[:, :, 0].squeeze())
141120
)
142121
return (self, result)
143122

examples/mpm-nodal-forces.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type = "MPMExplicit"
1616
dimension = 2
1717
scheme = "usf"
1818
dt = 0.001
19-
nsteps = 20
19+
nsteps = 301
2020
velocity_update = true
2121

2222
[output]
@@ -59,7 +59,7 @@ gravity = [0, 0]
5959
node_ids = [3, 7]
6060
math_function_id = 0
6161
dir = 0
62-
force = 0.01
62+
force = 0.05
6363

6464
[[external_loading.particle_surface_traction]]
6565
pset = [1]

0 commit comments

Comments
 (0)