@@ -147,7 +147,7 @@ def _step(pid, args):
147147 )
148148 _ , self .nodes .momentum , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
149149 self .nodes .momentum = jnp .where (
150- self .nodes .momentum < 1e-12 ,
150+ jnp . abs ( self .nodes .momentum ) < 1e-12 ,
151151 jnp .zeros_like (self .nodes .momentum ),
152152 self .nodes .momentum ,
153153 )
@@ -159,7 +159,7 @@ def compute_velocity(self, particles):
159159 self .nodes .momentum / self .nodes .mass ,
160160 )
161161 self .nodes .velocity = jnp .where (
162- self .nodes .velocity < 1e-12 ,
162+ jnp . abs ( self .nodes .velocity ) < 1e-12 ,
163163 jnp .zeros_like (self .nodes .velocity ),
164164 self .nodes .velocity ,
165165 )
@@ -301,12 +301,12 @@ def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
301301 self .nodes .mass * self .nodes .velocity
302302 )
303303 self .nodes .velocity = jnp .where (
304- self .nodes .velocity < 1e-12 ,
304+ jnp . abs ( self .nodes .velocity ) < 1e-12 ,
305305 jnp .zeros_like (self .nodes .velocity ),
306306 self .nodes .velocity ,
307307 )
308308 self .nodes .acceleration = jnp .where (
309- self .nodes .acceleration < 1e-12 ,
309+ jnp . abs ( self .nodes .acceleration ) < 1e-12 ,
310310 jnp .zeros_like (self .nodes .acceleration ),
311311 self .nodes .acceleration ,
312312 )
0 commit comments