Skip to content

Commit b514e3a

Browse files
committed
Use absolute values to check tolerance threshold
1 parent 5416bc8 commit b514e3a

2 files changed

Lines changed: 6 additions & 10 deletions

File tree

diffmpm/element.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def _step(pid, args):
147147
)
148148
_, self.nodes.momentum, _, _ = lax.fori_loop(0, len(particles), _step, args)
149149
self.nodes.momentum = jnp.where(
150-
self.nodes.momentum < 1e-12,
150+
jnp.abs(self.nodes.momentum) < 1e-12,
151151
jnp.zeros_like(self.nodes.momentum),
152152
self.nodes.momentum,
153153
)
@@ -159,7 +159,7 @@ def compute_velocity(self, particles):
159159
self.nodes.momentum / self.nodes.mass,
160160
)
161161
self.nodes.velocity = jnp.where(
162-
self.nodes.velocity < 1e-12,
162+
jnp.abs(self.nodes.velocity) < 1e-12,
163163
jnp.zeros_like(self.nodes.velocity),
164164
self.nodes.velocity,
165165
)
@@ -301,12 +301,12 @@ def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
301301
self.nodes.mass * self.nodes.velocity
302302
)
303303
self.nodes.velocity = jnp.where(
304-
self.nodes.velocity < 1e-12,
304+
jnp.abs(self.nodes.velocity) < 1e-12,
305305
jnp.zeros_like(self.nodes.velocity),
306306
self.nodes.velocity,
307307
)
308308
self.nodes.acceleration = jnp.where(
309-
self.nodes.acceleration < 1e-12,
309+
jnp.abs(self.nodes.acceleration) < 1e-12,
310310
jnp.zeros_like(self.nodes.acceleration),
311311
self.nodes.acceleration,
312312
)

diffmpm/particle.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from typing import Tuple
22

33
import jax.numpy as jnp
4-
from functools import partial
5-
from jax import jit, vmap, lax
6-
import jax.debug as db
4+
from jax import vmap, lax
75
from jax.tree_util import register_pytree_node_class
86

97
from diffmpm.element import _Element
@@ -245,9 +243,7 @@ def compute_strain(self, elements: _Element, dt: float):
245243
self.strain_rate = self._compute_strain_rate(dn_dx_, elements)
246244
self.dstrain = self.dstrain.at[:].set(self.strain_rate * dt)
247245

248-
# db.print(f"compute_strain() - dstrain: {self.dstrain.squeeze()[3, :2]}")
249246
self.strain = self.strain.at[:].add(self.dstrain)
250-
# db.print(f"compute_strain() - strain: {self.strain.squeeze()[3, :2]}")
251247
centroids = jnp.zeros_like(self.loc)
252248
dn_dx_centroid_ = vmap(elements.shapefn_grad)(
253249
centroids[:, jnp.newaxis, ...], mapped_coords
@@ -298,7 +294,7 @@ def _step(pid, args):
298294
args = (dn_dx, temp, strain_rate)
299295
_, _, strain_rate = lax.fori_loop(0, self.loc.shape[0], _step, args)
300296
strain_rate = jnp.where(
301-
strain_rate < 1e-12, jnp.zeros_like(strain_rate), strain_rate
297+
jnp.abs(strain_rate) < 1e-12, jnp.zeros_like(strain_rate), strain_rate
302298
)
303299
return strain_rate
304300

0 commit comments

Comments
 (0)