Skip to content

Commit 4c3f7ed

Browse files
authored
Merge pull request #5 from chahak13/2d
2d `solve` works now. Bug fixed!
2 parents bcaed6a + c76842b commit 4c3f7ed

5 files changed

Lines changed: 39 additions & 40 deletions

File tree

diffmpm/element.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -127,25 +127,36 @@ def compute_nodal_momentum(self, particles):
127127
"""
128128

129129
def _step(pid, args):
130-
pmom, pvel, mom, vel, mapped_pos, el_nodes = args
130+
pmom, mom, mapped_pos, el_nodes = args
131131
mom = mom.at[el_nodes[pid]].add(mapped_pos[pid] @ pmom[pid])
132-
vel = vel.at[el_nodes[pid]].add(mapped_pos[pid] @ pvel[pid])
133-
return pmom, pvel, mom, vel, mapped_pos, el_nodes
132+
return pmom, mom, mapped_pos, el_nodes
134133

135134
self.nodes.momentum = self.nodes.momentum.at[:].set(0)
136-
self.nodes.velocity = self.nodes.velocity.at[:].set(0)
137135
mapped_positions = self.shapefn(particles.reference_loc)
138136
mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
139137
args = (
140138
particles.mass * particles.velocity,
141-
particles.velocity,
142139
self.nodes.momentum,
143-
self.nodes.velocity,
144140
mapped_positions,
145141
mapped_nodes,
146142
)
147-
_, _, self.nodes.momentum, self.nodes.velocity, _, _ = lax.fori_loop(
148-
0, len(particles), _step, args
143+
_, self.nodes.momentum, _, _ = lax.fori_loop(0, len(particles), _step, args)
144+
self.nodes.momentum = jnp.where(
145+
self.nodes.momentum < 1e-12,
146+
jnp.zeros_like(self.nodes.momentum),
147+
self.nodes.momentum,
148+
)
149+
150+
def compute_velocity(self, particles):
151+
self.nodes.velocity = jnp.where(
152+
self.nodes.mass == 0,
153+
self.nodes.velocity,
154+
self.nodes.momentum / self.nodes.mass,
155+
)
156+
self.nodes.velocity = jnp.where(
157+
self.nodes.velocity < 1e-12,
158+
jnp.zeros_like(self.nodes.velocity),
159+
self.nodes.velocity,
149160
)
150161

151162
def compute_external_force(self, particles):
@@ -213,29 +224,11 @@ def _step(pid, args):
213224
self.nodes.f_ext, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
214225

215226
def apply_concentrated_nodal_forces(self, particles, curr_time):
216-
# def _step(fid, args):
217-
# f_ext, cnf, curr_time = args
218-
# breakpoint()
219-
# factor = cnf[fid].function.value(curr_time)
220-
# f_ext = f_ext.at[cnf[fid].node_ids].add(factor * cnf[fid].force)
221-
# return f_ext, cnf, curr_time
222-
223-
# args = (self.nodes.f_ext, self.concentrated_nodal_forces, curr_time)
224-
# self.nodes.f_ext, _, _ = lax.fori_loop(
225-
# 0, len(self.concentrated_nodal_forces), _step, args
226-
# )
227-
# breakpoint()
228-
import jax.debug as db
229-
230227
for cnf in self.concentrated_nodal_forces:
231228
factor = cnf.function.value(curr_time)
232229
self.nodes.f_ext = self.nodes.f_ext.at[cnf.node_ids, 0, cnf.dir].add(
233230
factor * cnf.force
234231
)
235-
db.print(
236-
f"Factor: {factor}, curr_time: {curr_time}, "
237-
f"f_ext[3]: {self.nodes.f_ext[3].squeeze()}"
238-
)
239232

240233
def compute_internal_force(self, particles):
241234
r"""
@@ -291,11 +284,7 @@ def _step(pid, args):
291284

292285
def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
293286
"""Update the nodal momentum based on total force on nodes."""
294-
import jax.debug as db
295-
296287
total_force = self.nodes.get_total_force()
297-
db.print(f"Before: nodes.velocity[3]: {self.nodes.velocity[3].squeeze()}")
298-
breakpoint()
299288
self.nodes.acceleration = self.nodes.acceleration.at[:].set(
300289
jnp.nan_to_num(jnp.divide(total_force, self.nodes.mass))
301290
)
@@ -306,7 +295,16 @@ def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
306295
self.nodes.momentum = self.nodes.momentum.at[:].set(
307296
self.nodes.mass * self.nodes.velocity
308297
)
309-
db.print(f"After: nodes.velocity[3]: {self.nodes.velocity[3].squeeze()}")
298+
self.nodes.velocity = jnp.where(
299+
self.nodes.velocity < 1e-12,
300+
jnp.zeros_like(self.nodes.velocity),
301+
self.nodes.velocity,
302+
)
303+
self.nodes.acceleration = jnp.where(
304+
self.nodes.acceleration < 1e-12,
305+
jnp.zeros_like(self.nodes.acceleration),
306+
self.nodes.acceleration,
307+
)
310308

311309
def apply_boundary_constraints(self, *args):
312310
"""Apply boundary conditions for nodal velocity."""
@@ -773,7 +771,6 @@ def _step(pid, args):
773771
mapped_nodes,
774772
particles.stress,
775773
)
776-
_step(0, args)
777774
self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
778775

779776
def compute_volume(self):

diffmpm/particle.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,12 @@ def update_position_velocity(
208208
mapped_positions * elements.nodes.acceleration[mapped_ids],
209209
axis=1,
210210
)
211-
self.velocity = self.velocity.at[:].add(
211+
self.velocity = self.velocity.at[:].set(
212212
lax.cond(
213213
velocity_update,
214-
lambda nv, na, t: nv,
215-
lambda nv, na, t: na * t,
214+
lambda sv, nv, na, t: nv,
215+
lambda sv, nv, na, t: sv + na * t,
216+
self.velocity,
216217
nodal_velocity,
217218
nodal_acceleration,
218219
dt,
@@ -286,17 +287,17 @@ def _compute_strain_rate(self, dn_dx: jnp.ndarray, elements: _Element):
286287

287288
def _step(pid, args):
288289
dndx, nvel, strain_rate = args
289-
# breakpoint()
290290
matmul = dndx[pid].T @ nvel[pid]
291291
strain_rate = strain_rate.at[pid, 0].add(matmul[0, 0])
292292
strain_rate = strain_rate.at[pid, 1].add(matmul[1, 1])
293293
strain_rate = strain_rate.at[pid, 3].add(matmul[0, 1] + matmul[1, 0])
294294
return dndx, nvel, strain_rate
295295

296296
args = (dn_dx, temp, strain_rate)
297-
# _step(0, args)
298297
_, _, strain_rate = lax.fori_loop(0, self.loc.shape[0], _step, args)
299-
# breakpoint()
298+
strain_rate = jnp.where(
299+
strain_rate < 1e-12, jnp.zeros_like(strain_rate), strain_rate
300+
)
300301
return strain_rate
301302

302303
def compute_stress(self, *args):

diffmpm/scheme.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def compute_nodal_kinematics(self):
1414
self.mesh.apply_on_particles("update_natural_coords")
1515
self.mesh.apply_on_elements("compute_nodal_mass")
1616
self.mesh.apply_on_elements("compute_nodal_momentum")
17+
self.mesh.apply_on_elements("compute_velocity")
1718
self.mesh.apply_on_elements("apply_boundary_constraints")
1819

1920
def compute_stress_strain(self):
@@ -39,8 +40,6 @@ def compute_particle_kinematics(self):
3940
args=(self.dt, self.velocity_update),
4041
)
4142
# TODO: Apply particle velocity constraints.
42-
self.mesh.apply_on_elements("compute_nodal_momentum")
43-
self.mesh.apply_on_elements("apply_boundary_constraints")
4443

4544
@abc.abstractmethod
4645
def precompute_stress_strain():

diffmpm/solver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self, mesh, dt, scheme="usf", velocity_update=False):
4242
self.mesh = mesh
4343
self.dt = dt
4444
self.scheme = scheme
45+
self.mesh.apply_on_elements("set_particle_element_ids")
4546
self.mesh.apply_on_particles("compute_volume")
4647

4748
def tree_flatten(self):

examples/simple_2d_file.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
mpm = MPM(sys.argv[1])
77
result = mpm.solve()
88
# breakpoint()
9+
print(result["stress"][-1])
910
exit()
1011

1112

0 commit comments

Comments
 (0)