@@ -224,29 +224,11 @@ def _step(pid, args):
224224 self .nodes .f_ext , _ , _ , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
225225
226226 def apply_concentrated_nodal_forces (self , particles , curr_time ):
227- # def _step(fid, args):
228- # f_ext, cnf, curr_time = args
229- # breakpoint()
230- # factor = cnf[fid].function.value(curr_time)
231- # f_ext = f_ext.at[cnf[fid].node_ids].add(factor * cnf[fid].force)
232- # return f_ext, cnf, curr_time
233-
234- # args = (self.nodes.f_ext, self.concentrated_nodal_forces, curr_time)
235- # self.nodes.f_ext, _, _ = lax.fori_loop(
236- # 0, len(self.concentrated_nodal_forces), _step, args
237- # )
238- # breakpoint()
239- import jax .debug as db
240-
241227 for cnf in self .concentrated_nodal_forces :
242228 factor = cnf .function .value (curr_time )
243229 self .nodes .f_ext = self .nodes .f_ext .at [cnf .node_ids , 0 , cnf .dir ].add (
244230 factor * cnf .force
245231 )
246- db .print (
247- f"Factor: { factor } , curr_time: { curr_time } , "
248- f"f_ext[3]: { self .nodes .f_ext [3 ].squeeze ()} "
249- )
250232
251233 def compute_internal_force (self , particles ):
252234 r"""
@@ -302,11 +284,7 @@ def _step(pid, args):
302284
303285 def update_nodal_acceleration_velocity (self , particles , dt : float , * args ):
304286 """Update the nodal momentum based on total force on nodes."""
305- import jax .debug as db
306-
307287 total_force = self .nodes .get_total_force ()
308- db .print (f"Before: nodes.velocity[3]: { self .nodes .velocity [3 ].squeeze ()} " )
309- breakpoint ()
310288 self .nodes .acceleration = self .nodes .acceleration .at [:].set (
311289 jnp .nan_to_num (jnp .divide (total_force , self .nodes .mass ))
312290 )
@@ -317,7 +295,16 @@ def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
317295 self .nodes .momentum = self .nodes .momentum .at [:].set (
318296 self .nodes .mass * self .nodes .velocity
319297 )
320- db .print (f"After: nodes.velocity[3]: { self .nodes .velocity [3 ].squeeze ()} " )
298+ self .nodes .velocity = jnp .where (
299+ self .nodes .velocity < 1e-12 ,
300+ jnp .zeros_like (self .nodes .velocity ),
301+ self .nodes .velocity ,
302+ )
303+ self .nodes .acceleration = jnp .where (
304+ self .nodes .acceleration < 1e-12 ,
305+ jnp .zeros_like (self .nodes .acceleration ),
306+ self .nodes .acceleration ,
307+ )
321308
322309 def apply_boundary_constraints (self , * args ):
323310 """Apply boundary conditions for nodal velocity."""
@@ -784,7 +771,6 @@ def _step(pid, args):
784771 mapped_nodes ,
785772 particles .stress ,
786773 )
787- _step (0 , args )
788774 self .nodes .f_int , _ , _ , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
789775
790776 def compute_volume (self ):
0 commit comments