@@ -127,25 +127,36 @@ def compute_nodal_momentum(self, particles):
127127 """
128128
129129 def _step (pid , args ):
130- pmom , pvel , mom , vel , mapped_pos , el_nodes = args
130+ pmom , mom , mapped_pos , el_nodes = args
131131 mom = mom .at [el_nodes [pid ]].add (mapped_pos [pid ] @ pmom [pid ])
132- vel = vel .at [el_nodes [pid ]].add (mapped_pos [pid ] @ pvel [pid ])
133- return pmom , pvel , mom , vel , mapped_pos , el_nodes
132+ return pmom , mom , mapped_pos , el_nodes
134133
135134 self .nodes .momentum = self .nodes .momentum .at [:].set (0 )
136- self .nodes .velocity = self .nodes .velocity .at [:].set (0 )
137135 mapped_positions = self .shapefn (particles .reference_loc )
138136 mapped_nodes = vmap (self .id_to_node_ids )(particles .element_ids ).squeeze (- 1 )
139137 args = (
140138 particles .mass * particles .velocity ,
141- particles .velocity ,
142139 self .nodes .momentum ,
143- self .nodes .velocity ,
144140 mapped_positions ,
145141 mapped_nodes ,
146142 )
147- _ , _ , self .nodes .momentum , self .nodes .velocity , _ , _ = lax .fori_loop (
148- 0 , len (particles ), _step , args
143+ _ , self .nodes .momentum , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
144+ self .nodes .momentum = jnp .where (
145+ self .nodes .momentum < 1e-12 ,
146+ jnp .zeros_like (self .nodes .momentum ),
147+ self .nodes .momentum ,
148+ )
149+
150+ def compute_velocity (self , particles ):
151+ self .nodes .velocity = jnp .where (
152+ self .nodes .mass == 0 ,
153+ self .nodes .velocity ,
154+ self .nodes .momentum / self .nodes .mass ,
155+ )
156+ self .nodes .velocity = jnp .where (
157+ self .nodes .velocity < 1e-12 ,
158+ jnp .zeros_like (self .nodes .velocity ),
159+ self .nodes .velocity ,
149160 )
150161
151162 def compute_external_force (self , particles ):
@@ -213,29 +224,11 @@ def _step(pid, args):
213224 self .nodes .f_ext , _ , _ , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
214225
215226 def apply_concentrated_nodal_forces (self , particles , curr_time ):
216- # def _step(fid, args):
217- # f_ext, cnf, curr_time = args
218- # breakpoint()
219- # factor = cnf[fid].function.value(curr_time)
220- # f_ext = f_ext.at[cnf[fid].node_ids].add(factor * cnf[fid].force)
221- # return f_ext, cnf, curr_time
222-
223- # args = (self.nodes.f_ext, self.concentrated_nodal_forces, curr_time)
224- # self.nodes.f_ext, _, _ = lax.fori_loop(
225- # 0, len(self.concentrated_nodal_forces), _step, args
226- # )
227- # breakpoint()
228- import jax .debug as db
229-
230227 for cnf in self .concentrated_nodal_forces :
231228 factor = cnf .function .value (curr_time )
232229 self .nodes .f_ext = self .nodes .f_ext .at [cnf .node_ids , 0 , cnf .dir ].add (
233230 factor * cnf .force
234231 )
235- db .print (
236- f"Factor: { factor } , curr_time: { curr_time } , "
237- f"f_ext[3]: { self .nodes .f_ext [3 ].squeeze ()} "
238- )
239232
240233 def compute_internal_force (self , particles ):
241234 r"""
@@ -291,11 +284,7 @@ def _step(pid, args):
291284
292285 def update_nodal_acceleration_velocity (self , particles , dt : float , * args ):
293286 """Update the nodal momentum based on total force on nodes."""
294- import jax .debug as db
295-
296287 total_force = self .nodes .get_total_force ()
297- db .print (f"Before: nodes.velocity[3]: { self .nodes .velocity [3 ].squeeze ()} " )
298- breakpoint ()
299288 self .nodes .acceleration = self .nodes .acceleration .at [:].set (
300289 jnp .nan_to_num (jnp .divide (total_force , self .nodes .mass ))
301290 )
@@ -306,7 +295,16 @@ def update_nodal_acceleration_velocity(self, particles, dt: float, *args):
306295 self .nodes .momentum = self .nodes .momentum .at [:].set (
307296 self .nodes .mass * self .nodes .velocity
308297 )
309- db .print (f"After: nodes.velocity[3]: { self .nodes .velocity [3 ].squeeze ()} " )
298+ self .nodes .velocity = jnp .where (
299+ self .nodes .velocity < 1e-12 ,
300+ jnp .zeros_like (self .nodes .velocity ),
301+ self .nodes .velocity ,
302+ )
303+ self .nodes .acceleration = jnp .where (
304+ self .nodes .acceleration < 1e-12 ,
305+ jnp .zeros_like (self .nodes .acceleration ),
306+ self .nodes .acceleration ,
307+ )
310308
311309 def apply_boundary_constraints (self , * args ):
312310 """Apply boundary conditions for nodal velocity."""
@@ -773,7 +771,6 @@ def _step(pid, args):
773771 mapped_nodes ,
774772 particles .stress ,
775773 )
776- _step (0 , args )
777774 self .nodes .f_int , _ , _ , _ , _ = lax .fori_loop (0 , len (particles ), _step , args )
778775
779776 def compute_volume (self ):
0 commit comments