1- import itertools
2- from typing import Sequence
1+ import abc
2+ from typing import Iterable
33
44import jax .numpy as jnp
5- from jax import jit , lax , vmap
65from jax .tree_util import register_pytree_node_class
7- from jax_tqdm import loop_tqdm
8- from tqdm import tqdm
96
107from diffmpm .element import _Element
11- from diffmpm .node import Nodes
12- from diffmpm .shapefn import Linear1DShapeFn , Linear4NodeQuad
8+ from diffmpm .particle import Particles
139
1410
15- @register_pytree_node_class
16- class _MeshBase :
11+ class _MeshBase (abc .ABC ):
12+ """
13+ Base class for Meshes.
14+
15+ Note: If attributes other than elements and particles are added
16+ then the child class should also implement `tree_flatten` and
17+ `tree_unflatten` correctly or that information will get lost.
18+ """
19+
1720 def __init__ (self , config : dict ):
1821 """Initialize mesh using configuration."""
19- self .particles : Sequence = config ["particles" ]
22+ self .particles : Iterable [ Particles , ...] = config ["particles" ]
2023 self .elements : _Element = config ["elements" ]
2124
2225 # TODO: Convert to using jax directives for loop
@@ -42,211 +45,32 @@ def tree_unflatten(cls, aux_data, children):
4245 return cls ({"particles" : children [0 ], "elements" : children [1 ]})
4346
4447
45- class Mesh1D :
46- """
47- 1D Mesh class with nodes, elements, and particles.
48- """
49-
50- def __init__ (
51- self ,
52- nelements ,
53- material ,
54- domain_size ,
55- boundary_nodes ,
56- * ,
57- ppe = 1 ,
58- particle_distribution = "uniform" ,
59- elements = None ,
60- nodes = None ,
61- particles = None ,
62- shapefn = None ,
63- dim = 1 ,
64- ):
65- """
66- Construct a 1D Mesh.
67-
68- Arguments
69- ---------
70- nelements : int
71- Number of elements in the mesh.
72- material : diffmpm.material.Material
73- Material to meshed.
74- domain_size : float
75- The size of the domain in consideration.
76- boundary_nodes : array_like
77- Node ids of boundary nodes of the mesh. Needs to be a JAX
78- array.
79- ppe : int
80- Number of particles per element in Mesh.
81- """
82- self .dim = dim
83- self .material = material
84- self .shapefn = (
85- Linear1DShapeFn (self .dim )
86- if (
87- shapefn is None
88- or type (shapefn ) is object
89- or isinstance (shapefn , Mesh1D )
90- )
91- else shapefn
92- )
93- self .domain_size = domain_size
94- self .nelements = nelements
95- self .element_length = domain_size / nelements
96- self .elements = jnp .arange (nelements ) if elements is None else elements
97- nnodes = nelements + 1
98- self .nodes = (
99- Nodes (
100- nnodes ,
101- jnp .arange (nelements + 1 ) * self .element_length ,
102- jnp .zeros (nnodes ),
103- jnp .zeros (nnodes ),
104- jnp .zeros (nnodes ),
105- jnp .zeros (nnodes ),
106- jnp .zeros (nnodes ),
107- jnp .zeros (nnodes ),
108- )
109- if (
110- nodes is None
111- or type (nodes ) is object
112- or isinstance (nodes , Mesh1D )
113- )
114- else nodes
115- )
116- self .boundary_nodes = boundary_nodes
117- self .ppe = ppe
118- self .particles = (
119- self ._init_particles (particle_distribution )
120- if (
121- particles is None
122- or type (particles ) is object
123- or isinstance (particles , Mesh1D )
124- )
125- else particles
126- )
127- return
128-
129-
130- class Mesh2D :
131- """
132- 2D Mesh class with nodes, elements, and particles.
133- """
48+ @register_pytree_node_class
49+ class Mesh1D (_MeshBase ):
50+ """1D Mesh class with nodes, elements, and particles."""
13451
135- def __init__ (
136- self ,
137- nelements ,
138- material ,
139- domain_size ,
140- boundary_nodes ,
141- * ,
142- ppe = 1 ,
143- particle_distribution = "uniform" ,
144- elements = None ,
145- nodes = None ,
146- particles = None ,
147- shapefn = None ,
148- dim = 1 ,
149- ):
52+ def __init__ (self , config : dict ):
15053 """
151- Construct a 2D Mesh using 4-Node Quadrilateral Elements.
152-
153- Nodes and elements are numbered as
154-
155- 0---0---0---0---0
156- | 8 | 9 | 10| 11|
157- 10 0---0---0---0---0
158- | 4 | 5 | 6 | 7 |
159- 5 0---0---0---0---0 9
160- | 0 | 1 | 2 | 3 |
161- 0---0---0---0---0
162- 0 1 2 3 4
163-
54+ Initialize a 1D Mesh.
16455
16556 Arguments
16657 ---------
167- nelements : array_like
168- Number of elements in the mesh in the x and y direction.
169- material : diffmpm.material.Material
170- Material to meshed.
171- domain_size : 4-tuple, array_like
172- The boundaries of the domain. Should be of the form
173- (x_min, x_max, y_min, y_max)
174- boundary_nodes : array_like
175- Node ids of boundary nodes of the mesh. Needs to be a JAX
176- array.
177- ppe : int
178- Number of particles per element in Mesh.
58+ config: dict
59+ Configuration to be used for initialization. It _should_
60+ contain `elements` and `particles` keys.
17961 """
180- self .dim = 2
181- self .material = material
182- self .shapefn = (
183- Linear4NodeQuad ()
184- if (
185- shapefn is None
186- or type (shapefn ) is object
187- or isinstance (shapefn , Mesh1D )
188- )
189- else shapefn
190- )
191- self .domain_size = domain_size
192- self .nelements = jnp .asarray (nelements )
193- self .element_length = jnp .array (
194- [
195- (domain_size [1 ] - domain_size [0 ]) / nelements [0 ],
196- (domain_size [3 ] - domain_size [2 ]) / nelements [1 ],
197- ]
198- )
199- self .elements = (
200- jnp .arange (self .nelements [0 ] * self .nelements [1 ])
201- if elements is None
202- else elements
203- )
204- nnodes = jnp .product (self .nelements + 1 )
205- coords = jnp .asarray (
206- list (
207- itertools .product (
208- jnp .arange (nelements [1 ] + 1 ), jnp .arange (nelements [0 ] + 1 )
209- )
210- )
211- )
212- node_positions = (
213- jnp .asarray ([coords [:, 1 ], coords [:, 0 ]]).T * self .element_length
214- )
215-
216- self .nodes = (
217- Nodes (
218- nnodes ,
219- node_positions ,
220- jnp .zeros ((nnodes , 2 )),
221- jnp .zeros (nnodes ),
222- jnp .zeros ((nnodes , 2 )),
223- jnp .zeros ((nnodes , 2 )),
224- jnp .zeros ((nnodes , 2 )),
225- jnp .zeros ((nnodes , 2 )),
226- )
227- if (
228- nodes is None
229- or type (nodes ) is object
230- or isinstance (nodes , Mesh1D )
231- )
232- else nodes
233- )
234- self .boundary_nodes = boundary_nodes
235- self .ppe = ppe
236- self .particles = particles
237- return
62+ super ().__init__ (config )
23863
23964
24065if __name__ == "__main__" :
241- from diffmpm .utils import _show_example
242- from diffmpm .particle import Particles
24366 from diffmpm .element import Linear1D
24467 from diffmpm .material import SimpleMaterial
68+ from diffmpm .utils import _show_example
24569
24670 particles = Particles (
24771 jnp .array ([[[1 ]]]),
24872 SimpleMaterial ({"E" : 2 , "density" : 1 }),
24973 jnp .array ([0 ]),
25074 )
25175 elements = Linear1D (2 , 1 , jnp .array ([0 ]))
252- _show_example (_MeshBase ({"particles" : [particles ], "elements" : elements }))
76+ _show_example (Mesh1D ({"particles" : [particles ], "elements" : elements }))
0 commit comments