Skip to content

Commit 5991f27

Browse files
committed
Add option to not write results, return final val
1 parent ed78795 commit 5991f27

3 files changed

Lines changed: 93 additions & 39 deletions

File tree

diffmpm/io.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ def parse(self):
2929
self._parse_particles(self._fileconfig)
3030
if "math_functions" in self._fileconfig:
3131
self._parse_math_functions(self._fileconfig)
32-
if "external_loading" in self._fileconfig:
33-
self._parse_external_loading(self._fileconfig)
32+
self._parse_external_loading(self._fileconfig)
3433
mesh = self._parse_mesh(self._fileconfig)
3534
return mesh
3635

diffmpm/solver.py

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ def __init__(self, filepath):
2020
self._config.parsed_config["meta"]["title"],
2121
)
2222

23-
write_format = self._config.parsed_config["output"]["format"]
24-
if write_format == "npz":
25-
writer = writers.NPZWriter()
23+
write_format = self._config.parsed_config["output"].get("format", None)
24+
if write_format is None or write_format.lower() == "none":
25+
writer_func = None
26+
elif write_format == "npz":
27+
writer_func = writers.NPZWriter().write
2628
else:
2729
raise ValueError(f"Specified output format not supported: {write_format}")
2830

@@ -31,20 +33,20 @@ def __init__(self, filepath):
3133
mesh,
3234
self._config.parsed_config["meta"]["dt"],
3335
velocity_update=self._config.parsed_config["meta"]["velocity_update"],
36+
sim_steps=self._config.parsed_config["meta"]["nsteps"],
3437
out_steps=self._config.parsed_config["output"]["step_frequency"],
3538
out_dir=out_dir,
36-
writer_func=writer.write,
39+
writer_func=writer_func,
3740
)
3841
else:
3942
raise ValueError("Wrong type of solver specified.")
4043

4144
def solve(self):
4245
"""Solve the MPM simulation."""
43-
res = self.solver.solve_jit(
44-
self._config.parsed_config["meta"]["nsteps"],
46+
arrays = self.solver.solve_jit(
4547
self._config.parsed_config["external_loading"]["gravity"],
4648
)
47-
return res
49+
return arrays
4850

4951

5052
@register_pytree_node_class
@@ -57,6 +59,7 @@ def __init__(
5759
dt,
5860
scheme="usf",
5961
velocity_update=False,
62+
sim_steps=1,
6063
out_steps=1,
6164
out_dir="results/",
6265
writer_func=None,
@@ -71,48 +74,51 @@ def __init__(
7174
self.dt = dt
7275
self.scheme = scheme
7376
self.velocity_update = velocity_update
77+
self.sim_steps = sim_steps
7478
self.out_steps = out_steps
7579
self.out_dir = out_dir
7680
self.writer_func = writer_func
77-
self.mesh.apply_on_elements("set_particle_element_ids")
78-
self.mesh.apply_on_elements("compute_volume")
79-
self.mesh.apply_on_particles(
81+
self.mpm_scheme.mesh.apply_on_elements("set_particle_element_ids")
82+
self.mpm_scheme.mesh.apply_on_elements("compute_volume")
83+
self.mpm_scheme.mesh.apply_on_particles(
8084
"compute_volume", args=(self.mesh.elements.total_elements,)
8185
)
8286

8387
def tree_flatten(self):
8488
children = (self.mesh,)
85-
aux_data = (
86-
self.dt,
87-
self.scheme,
88-
self.velocity_update,
89-
self.out_steps,
90-
self.out_dir,
91-
self.writer_func,
92-
)
89+
aux_data = {
90+
"dt": self.dt,
91+
"scheme": self.scheme,
92+
"velocity_update": self.velocity_update,
93+
"sim_steps": self.sim_steps,
94+
"out_steps": self.out_steps,
95+
"out_dir": self.out_dir,
96+
"writer_func": self.writer_func,
97+
}
9398
return children, aux_data
9499

95100
@classmethod
96101
def tree_unflatten(cls, aux_data, children):
97102
return cls(
98103
*children,
99-
aux_data[0],
100-
scheme=aux_data[1],
101-
velocity_update=aux_data[2],
102-
out_steps=aux_data[3],
103-
out_dir=aux_data[4],
104-
writer_func=aux_data[5],
104+
aux_data["dt"],
105+
scheme=aux_data["scheme"],
106+
velocity_update=aux_data["velocity_update"],
107+
sim_steps=aux_data["sim_steps"],
108+
out_steps=aux_data["out_steps"],
109+
out_dir=aux_data["out_dir"],
110+
writer_func=aux_data["writer_func"],
105111
)
106112

107113
def jax_writer(self, func, args):
108114
id_tap(func, args)
109115

110-
def solve(self, nsteps: int, gravity: float | jnp.ndarray):
116+
def solve(self, gravity: float | jnp.ndarray):
111117
from collections import defaultdict
112118
from tqdm import tqdm
113119

114120
result = defaultdict(list)
115-
for step in tqdm(range(nsteps)):
121+
for step in tqdm(range(self.sim_steps)):
116122
self.mpm_scheme.compute_nodal_kinematics()
117123
self.mpm_scheme.precompute_stress_strain()
118124
self.mpm_scheme.compute_forces(gravity, step)
@@ -127,9 +133,9 @@ def solve(self, nsteps: int, gravity: float | jnp.ndarray):
127133
result = {k: jnp.asarray(v) for k, v in result.items()}
128134
return result
129135

130-
def solve_jit(self, nsteps: int, gravity: float | jnp.ndarray):
136+
def solve_jit(self, gravity: float | jnp.ndarray):
131137
def _step(i, data):
132-
self, nsteps = data
138+
self = data
133139
self.mpm_scheme.compute_nodal_kinematics()
134140
self.mpm_scheme.precompute_stress_strain()
135141
self.mpm_scheme.compute_forces(gravity, i)
@@ -141,18 +147,34 @@ def _write(self, i):
141147
for name in self.__particle_props:
142148
arrays[name] = jnp.array(
143149
[
144-
getattr(self.mesh.particles[j], name)
150+
getattr(self.mesh.particles[j], name).squeeze()
145151
for j in range(len(self.mesh.particles))
146152
]
147-
).squeeze()
153+
)
148154
self.jax_writer(
149-
functools.partial(self.writer_func, out_dir=self.out_dir),
155+
functools.partial(
156+
self.writer_func, out_dir=self.out_dir, max_steps=self.sim_steps
157+
),
150158
(arrays, i),
151159
)
152160

153-
lax.cond(
154-
(i + 1) % self.out_steps == 0, _write, lambda s, i: None, self, i + 1
155-
)
156-
return (self, nsteps)
157-
158-
_, nsteps = lax.fori_loop(0, nsteps, _step, (self, nsteps))
161+
if self.writer_func is not None:
162+
lax.cond(
163+
i % self.out_steps == 0,
164+
_write,
165+
lambda s, i: None,
166+
self,
167+
i,
168+
)
169+
return self
170+
171+
self = lax.fori_loop(0, self.sim_steps, _step, self)
172+
arrays = {}
173+
for name in self.__particle_props:
174+
arrays[name] = jnp.array(
175+
[
176+
getattr(self.mesh.particles[j], name)
177+
for j in range(len(self.mesh.particles))
178+
]
179+
).squeeze()
180+
return arrays

diffmpm/writers.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import abc
2+
import logging
3+
import numpy as np
4+
from pathlib import Path
5+
6+
logger = logging.getLogger(__file__)
7+
8+
9+
class Writer(abc.ABC):
10+
@abc.abstractmethod
11+
def write(self):
12+
...
13+
14+
15+
class EmptyWriter(Writer):
16+
def write(self, args, transforms, **kwargs):
17+
pass
18+
19+
20+
class NPZWriter(Writer):
21+
def write(self, args, transforms, **kwargs):
22+
arrays, step = args
23+
max_digits = int(np.log10(kwargs["max_steps"])) + 1
24+
if step == 0:
25+
req_zeros = max_digits - 1
26+
else:
27+
req_zeros = max_digits - (int(np.log10(step)) + 1)
28+
fileno = f"{'0' * req_zeros}{step}"
29+
filepath = Path(kwargs["out_dir"]).joinpath(f"particles_{fileno}.npz")
30+
if not filepath.parent.is_dir():
31+
filepath.parent.mkdir(parents=True)
32+
np.savez(filepath, **arrays)
33+
logger.info(f"Saved particle data for step {step} at {filepath}")

0 commit comments

Comments
 (0)