@@ -20,9 +20,11 @@ def __init__(self, filepath):
2020 self ._config .parsed_config ["meta" ]["title" ],
2121 )
2222
23- write_format = self ._config .parsed_config ["output" ]["format" ]
24- if write_format == "npz" :
25- writer = writers .NPZWriter ()
23+ write_format = self ._config .parsed_config ["output" ].get ("format" , None )
24+ if write_format is None or write_format .lower () == "none" :
25+ writer_func = None
26+ elif write_format == "npz" :
27+ writer_func = writers .NPZWriter ().write
2628 else :
2729 raise ValueError (f"Specified output format not supported: { write_format } " )
2830
@@ -31,20 +33,20 @@ def __init__(self, filepath):
3133 mesh ,
3234 self ._config .parsed_config ["meta" ]["dt" ],
3335 velocity_update = self ._config .parsed_config ["meta" ]["velocity_update" ],
36+ sim_steps = self ._config .parsed_config ["meta" ]["nsteps" ],
3437 out_steps = self ._config .parsed_config ["output" ]["step_frequency" ],
3538 out_dir = out_dir ,
36- writer_func = writer . write ,
39+ writer_func = writer_func ,
3740 )
3841 else :
3942 raise ValueError ("Wrong type of solver specified." )
4043
4144 def solve (self ):
4245 """Solve the MPM simulation."""
43- res = self .solver .solve_jit (
44- self ._config .parsed_config ["meta" ]["nsteps" ],
46+ arrays = self .solver .solve_jit (
4547 self ._config .parsed_config ["external_loading" ]["gravity" ],
4648 )
47- return res
49+ return arrays
4850
4951
5052@register_pytree_node_class
@@ -57,6 +59,7 @@ def __init__(
5759 dt ,
5860 scheme = "usf" ,
5961 velocity_update = False ,
62+ sim_steps = 1 ,
6063 out_steps = 1 ,
6164 out_dir = "results/" ,
6265 writer_func = None ,
@@ -71,48 +74,51 @@ def __init__(
7174 self .dt = dt
7275 self .scheme = scheme
7376 self .velocity_update = velocity_update
77+ self .sim_steps = sim_steps
7478 self .out_steps = out_steps
7579 self .out_dir = out_dir
7680 self .writer_func = writer_func
77- self .mesh .apply_on_elements ("set_particle_element_ids" )
78- self .mesh .apply_on_elements ("compute_volume" )
79- self .mesh .apply_on_particles (
81+ self .mpm_scheme . mesh .apply_on_elements ("set_particle_element_ids" )
82+ self .mpm_scheme . mesh .apply_on_elements ("compute_volume" )
83+ self .mpm_scheme . mesh .apply_on_particles (
8084 "compute_volume" , args = (self .mesh .elements .total_elements ,)
8185 )
8286
8387 def tree_flatten (self ):
8488 children = (self .mesh ,)
85- aux_data = (
86- self .dt ,
87- self .scheme ,
88- self .velocity_update ,
89- self .out_steps ,
90- self .out_dir ,
91- self .writer_func ,
92- )
89+ aux_data = {
90+ "dt" : self .dt ,
91+ "scheme" : self .scheme ,
92+ "velocity_update" : self .velocity_update ,
93+ "sim_steps" : self .sim_steps ,
94+ "out_steps" : self .out_steps ,
95+ "out_dir" : self .out_dir ,
96+ "writer_func" : self .writer_func ,
97+ }
9398 return children , aux_data
9499
95100 @classmethod
96101 def tree_unflatten (cls , aux_data , children ):
97102 return cls (
98103 * children ,
99- aux_data [0 ],
100- scheme = aux_data [1 ],
101- velocity_update = aux_data [2 ],
102- out_steps = aux_data [3 ],
103- out_dir = aux_data [4 ],
104- writer_func = aux_data [5 ],
104+ aux_data ["dt" ],
105+ scheme = aux_data ["scheme" ],
106+ velocity_update = aux_data ["velocity_update" ],
107+ sim_steps = aux_data ["sim_steps" ],
108+ out_steps = aux_data ["out_steps" ],
109+ out_dir = aux_data ["out_dir" ],
110+ writer_func = aux_data ["writer_func" ],
105111 )
106112
107113 def jax_writer (self , func , args ):
108114 id_tap (func , args )
109115
110- def solve (self , nsteps : int , gravity : float | jnp .ndarray ):
116+ def solve (self , gravity : float | jnp .ndarray ):
111117 from collections import defaultdict
112118 from tqdm import tqdm
113119
114120 result = defaultdict (list )
115- for step in tqdm (range (nsteps )):
121+ for step in tqdm (range (self . sim_steps )):
116122 self .mpm_scheme .compute_nodal_kinematics ()
117123 self .mpm_scheme .precompute_stress_strain ()
118124 self .mpm_scheme .compute_forces (gravity , step )
@@ -127,9 +133,9 @@ def solve(self, nsteps: int, gravity: float | jnp.ndarray):
127133 result = {k : jnp .asarray (v ) for k , v in result .items ()}
128134 return result
129135
130- def solve_jit (self , nsteps : int , gravity : float | jnp .ndarray ):
136+ def solve_jit (self , gravity : float | jnp .ndarray ):
131137 def _step (i , data ):
132- self , nsteps = data
138+ self = data
133139 self .mpm_scheme .compute_nodal_kinematics ()
134140 self .mpm_scheme .precompute_stress_strain ()
135141 self .mpm_scheme .compute_forces (gravity , i )
@@ -141,18 +147,34 @@ def _write(self, i):
141147 for name in self .__particle_props :
142148 arrays [name ] = jnp .array (
143149 [
144- getattr (self .mesh .particles [j ], name )
150+ getattr (self .mesh .particles [j ], name ). squeeze ()
145151 for j in range (len (self .mesh .particles ))
146152 ]
147- ). squeeze ()
153+ )
148154 self .jax_writer (
149- functools .partial (self .writer_func , out_dir = self .out_dir ),
155+ functools .partial (
156+ self .writer_func , out_dir = self .out_dir , max_steps = self .sim_steps
157+ ),
150158 (arrays , i ),
151159 )
152160
153- lax .cond (
154- (i + 1 ) % self .out_steps == 0 , _write , lambda s , i : None , self , i + 1
155- )
156- return (self , nsteps )
157-
158- _ , nsteps = lax .fori_loop (0 , nsteps , _step , (self , nsteps ))
161+ if self .writer_func is not None :
162+ lax .cond (
163+ i % self .out_steps == 0 ,
164+ _write ,
165+ lambda s , i : None ,
166+ self ,
167+ i ,
168+ )
169+ return self
170+
171+ self = lax .fori_loop (0 , self .sim_steps , _step , self )
172+ arrays = {}
173+ for name in self .__particle_props :
174+ arrays [name ] = jnp .array (
175+ [
176+ getattr (self .mesh .particles [j ], name )
177+ for j in range (len (self .mesh .particles ))
178+ ]
179+ ).squeeze ()
180+ return arrays
0 commit comments