@@ -23,7 +23,7 @@ def __init__(self, filepath):
2323 raise ValueError ("Wrong type of solver specified." )
2424
2525 def solve (self ):
26- res = self .solver .solve (
26+ res = self .solver .solve_jit (
2727 self ._config .parsed_config ["meta" ]["nsteps" ],
2828 self ._config .parsed_config ["external_loading" ]["gravity" ],
2929 )
@@ -42,22 +42,27 @@ def __init__(self, mesh, dt, scheme="usf", velocity_update=False):
4242 self .mesh = mesh
4343 self .dt = dt
4444 self .scheme = scheme
45+ self .velocity_update = velocity_update
4546 self .mesh .apply_on_elements ("set_particle_element_ids" )
46- self .mesh .apply_on_particles ("compute_volume" )
47+ self .mesh .apply_on_elements ("compute_volume" )
48+ self .mesh .apply_on_particles (
49+ "compute_volume" , args = (self .mesh .elements .total_elements ,)
50+ )
4751
4852 def tree_flatten (self ):
4953 children = (self .mesh ,)
50- aux_data = (self .dt , self .scheme )
54+ aux_data = (self .dt , self .scheme , self . velocity_update )
5155 return children , aux_data
5256
5357 @classmethod
5458 def tree_unflatten (cls , aux_data , children ):
55- return cls (* children , aux_data [0 ], scheme = aux_data [1 ])
59+ return cls (
60+ * children , aux_data [0 ], scheme = aux_data [1 ], velocity_update = aux_data [2 ]
61+ )
5662
5763 def solve (self , nsteps : int , gravity : float | jnp .ndarray ):
5864 result = defaultdict (list )
5965 for step in tqdm (range (nsteps )):
60- # breakpoint()
6166 self .mpm_scheme .compute_nodal_kinematics ()
6267 self .mpm_scheme .precompute_stress_strain ()
6368 self .mpm_scheme .compute_forces (gravity , step )
@@ -75,21 +80,17 @@ def solve(self, nsteps: int, gravity: float | jnp.ndarray):
7580 def solve_jit (self , nsteps : int , gravity : float | jnp .ndarray ):
7681 nparticles = sum (pset .loc .shape [0 ] for pset in self .mesh .particles )
7782 result = {
78- "position" : jnp .zeros ((nsteps , nparticles )),
79- "velocity" : jnp .zeros ((nsteps , nparticles )),
80- "strain_energy" : jnp .zeros ((nsteps , nparticles )),
81- "kinetic_energy" : jnp .zeros ((nsteps , nparticles )),
82- "total_energy" : jnp .zeros ((nsteps , nparticles )),
83- "stress" : jnp .zeros ((nsteps , nparticles )),
84- "strain" : jnp .zeros ((nsteps , nparticles )),
83+ "position" : jnp .zeros ((nsteps , nparticles , 2 )),
84+ "velocity" : jnp .zeros ((nsteps , nparticles , 2 )),
85+ "stress" : jnp .zeros ((nsteps , nparticles , 6 )),
86+ "strain" : jnp .zeros ((nsteps , nparticles , 6 )),
8587 }
8688
8789 def _step (i , data ):
8890 self , result = data
8991 self .mpm_scheme .compute_nodal_kinematics ()
9092 self .mpm_scheme .precompute_stress_strain ()
91- self .mpm_scheme .compute_forces (gravity )
92- # self.mpm_scheme.update_nodal_momentum()
93+ self .mpm_scheme .compute_forces (gravity , i )
9394 self .mpm_scheme .compute_particle_kinematics ()
9495 self .mpm_scheme .postcompute_stress_strain ()
9596
@@ -99,45 +100,23 @@ def _step(i, data):
99100 idu += len (self .mesh .particles [j ])
100101 result ["position" ] = (
101102 result ["position" ]
102- .at [i , idl :idu ]
103+ .at [i , idl :idu , : ]
103104 .set (self .mesh .particles [j ].loc .squeeze ())
104105 )
105106 result ["velocity" ] = (
106107 result ["velocity" ]
107- .at [i , idl :idu ]
108+ .at [i , idl :idu , : ]
108109 .set (self .mesh .particles [j ].velocity .squeeze ())
109110 )
110111 result ["stress" ] = (
111112 result ["stress" ]
112- .at [i , idl :idu ]
113- .set (self .mesh .particles [j ].stress [:, 0 , : ].squeeze ())
113+ .at [i , idl :idu , : ]
114+ .set (self .mesh .particles [j ].stress [:, :, 0 ].squeeze ())
114115 )
115116 result ["strain" ] = (
116117 result ["strain" ]
117- .at [i , idl :idu ]
118- .set (self .mesh .particles [j ].strain [:, 0 , :].squeeze ())
119- )
120- strain_energy = (
121- 0.5
122- * self .mesh .particles [j ].stress [:, 0 , :].squeeze ()
123- * self .mesh .particles [j ].strain [:, 0 , :].squeeze ()
124- * self .mesh .particles [j ].volume .squeeze ()
125- )
126- kinetic_energy = (
127- 0.5
128- * self .mesh .particles [j ].velocity .squeeze () ** 2
129- * self .mesh .particles [j ].mass .squeeze ()
130- )
131- result ["strain_energy" ] = (
132- result ["strain_energy" ].at [i , idl :idu ].set (strain_energy )
133- )
134- result ["kinetic_energy" ] = (
135- result ["kinetic_energy" ].at [i , idl :idu ].set (kinetic_energy )
136- )
137- result ["total_energy" ] = (
138- result ["total_energy" ]
139- .at [i , idl :idu ]
140- .set (strain_energy + kinetic_energy )
118+ .at [i , idl :idu , :]
119+ .set (self .mesh .particles [j ].strain [:, :, 0 ].squeeze ())
141120 )
142121 return (self , result )
143122
0 commit comments