Skip to content

Commit d32f207

Browse files
committed
Code cleanup
1 parent c7f3f6f commit d32f207

4 files changed

Lines changed: 15 additions & 29 deletions

File tree

diffmpm/element.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -224,29 +224,11 @@ def _step(pid, args):
224224
self.nodes.f_ext, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
225225

226226
def apply_concentrated_nodal_forces(self, particles, curr_time):
227-
# def _step(fid, args):
228-
# f_ext, cnf, curr_time = args
229-
# breakpoint()
230-
# factor = cnf[fid].function.value(curr_time)
231-
# f_ext = f_ext.at[cnf[fid].node_ids].add(factor * cnf[fid].force)
232-
# return f_ext, cnf, curr_time
233-
234-
# args = (self.nodes.f_ext, self.concentrated_nodal_forces, curr_time)
235-
# self.nodes.f_ext, _, _ = lax.fori_loop(
236-
# 0, len(self.concentrated_nodal_forces), _step, args
237-
# )
238-
# breakpoint()
239-
import jax.debug as db
240-
241227
for cnf in self.concentrated_nodal_forces:
242228
factor = cnf.function.value(curr_time)
243229
self.nodes.f_ext = self.nodes.f_ext.at[cnf.node_ids, 0, cnf.dir].add(
244230
factor * cnf.force
245231
)
246-
db.print(
247-
f"Factor: {factor}, curr_time: {curr_time}, "
248-
f"f_ext[3]: {self.nodes.f_ext[3].squeeze()}"
249-
)
250232

251233
def compute_internal_force(self, particles):
252234
r"""
@@ -302,11 +284,7 @@ def _step(pid, args):
302284

303285
def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
304286
"""Update the nodal momentum based on total force on nodes."""
305-
import jax.debug as db
306-
307287
total_force = self.nodes.get_total_force()
308-
db.print(f"Before: nodes.velocity[3]: {self.nodes.velocity[3].squeeze()}")
309-
breakpoint()
310288
self.nodes.acceleration = self.nodes.acceleration.at[:].set(
311289
jnp.nan_to_num(jnp.divide(total_force, self.nodes.mass))
312290
)
@@ -317,7 +295,16 @@ def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
317295
self.nodes.momentum = self.nodes.momentum.at[:].set(
318296
self.nodes.mass * self.nodes.velocity
319297
)
320-
db.print(f"After: nodes.velocity[3]: {self.nodes.velocity[3].squeeze()}")
298+
self.nodes.velocity = jnp.where(
299+
self.nodes.velocity < 1e-12,
300+
jnp.zeros_like(self.nodes.velocity),
301+
self.nodes.velocity,
302+
)
303+
self.nodes.acceleration = jnp.where(
304+
self.nodes.acceleration < 1e-12,
305+
jnp.zeros_like(self.nodes.acceleration),
306+
self.nodes.acceleration,
307+
)
321308

322309
def apply_boundary_constraints(self, *args):
323310
"""Apply boundary conditions for nodal velocity."""
@@ -784,7 +771,6 @@ def _step(pid, args):
784771
mapped_nodes,
785772
particles.stress,
786773
)
787-
_step(0, args)
788774
self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
789775

790776
def compute_volume(self):

diffmpm/particle.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,17 +286,17 @@ def _compute_strain_rate(self, dn_dx: jnp.ndarray, elements: _Element):
286286

287287
def _step(pid, args):
288288
dndx, nvel, strain_rate = args
289-
# breakpoint()
290289
matmul = dndx[pid].T @ nvel[pid]
291290
strain_rate = strain_rate.at[pid, 0].add(matmul[0, 0])
292291
strain_rate = strain_rate.at[pid, 1].add(matmul[1, 1])
293292
strain_rate = strain_rate.at[pid, 3].add(matmul[0, 1] + matmul[1, 0])
294293
return dndx, nvel, strain_rate
295294

296295
args = (dn_dx, temp, strain_rate)
297-
# _step(0, args)
298296
_, _, strain_rate = lax.fori_loop(0, self.loc.shape[0], _step, args)
299-
# breakpoint()
297+
strain_rate = jnp.where(
298+
strain_rate < 1e-12, jnp.zeros_like(strain_rate), strain_rate
299+
)
300300
return strain_rate
301301

302302
def compute_stress(self, *args):

diffmpm/scheme.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def compute_nodal_kinematics(self):
1414
self.mesh.apply_on_particles("update_natural_coords")
1515
self.mesh.apply_on_elements("compute_nodal_mass")
1616
self.mesh.apply_on_elements("compute_nodal_momentum")
17+
self.mesh.apply_on_elements("compute_velocity")
1718
self.mesh.apply_on_elements("apply_boundary_constraints")
1819

1920
def compute_stress_strain(self):
@@ -39,8 +40,6 @@ def compute_particle_kinematics(self):
3940
args=(self.dt, self.velocity_update),
4041
)
4142
# TODO: Apply particle velocity constraints.
42-
self.mesh.apply_on_elements("compute_nodal_momentum")
43-
self.mesh.apply_on_elements("apply_boundary_constraints")
4443

4544
@abc.abstractmethod
4645
def precompute_stress_strain():

examples/simple_2d_file.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
mpm = MPM(sys.argv[1])
77
result = mpm.solve()
88
# breakpoint()
9+
print(result["stress"][-1])
910
exit()
1011

1112

0 commit comments

Comments
 (0)