Skip to content

Commit 153a5b6

Browse files
committed
Code cleanup; doesn't change any functionality
1 parent 8912dd7 commit 153a5b6

7 files changed

Lines changed: 84 additions & 334 deletions

File tree

diffmpm/element.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ def tree_flatten(self):
3131
def tree_unflatten(self):
3232
...
3333

34+
@abc.abstractmethod
35+
def shapefn(self):
36+
...
37+
38+
@abc.abstractmethod
39+
def shapefn_grad(self):
40+
...
41+
3442

3543
@register_pytree_node_class
3644
class Linear1D(_Element):
@@ -64,10 +72,10 @@ def __init__(
6472
IDs of nodes that are supposed to be fixed (boundary).
6573
"""
6674
self.nelements = nelements
67-
self.ids: jnp.ndarray = jnp.arange(nelements)
68-
self.el_len: float = el_len
75+
self.ids = jnp.arange(nelements)
76+
self.el_len = el_len
6977
if nodes is None:
70-
self.nodes: Nodes = Nodes(
78+
self.nodes = Nodes(
7179
nelements + 1,
7280
jnp.arange(nelements + 1).reshape(-1, 1, 1) * el_len,
7381
)
@@ -136,7 +144,7 @@ def id_to_node_vel(self, id: int):
136144
"""
137145
return self.nodes.velocity[jnp.array([id, id + 1])].reshape(2, 1)
138146

139-
def shapefn(self, xi):
147+
def shapefn(self, xi: float | jnp.ndarray):
140148
"""
141149
Evaluate linear shape function.
142150
@@ -164,7 +172,7 @@ def shapefn(self, xi):
164172
)
165173
return result
166174

167-
def _shapefn_natural_grad(self, xi):
175+
def _shapefn_natural_grad(self, xi: float | jnp.ndarray):
168176
"""
169177
Calculate the gradient of shape function.
170178
@@ -197,7 +205,7 @@ def _shapefn_natural_grad(self, xi):
197205
# )
198206
return result.reshape(-1, 2)
199207

200-
def shapefn_grad(self, xi, coords):
208+
def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray):
201209
"""
202210
Gradient of shape function in physical coordinates.
203211
@@ -219,8 +227,6 @@ def shapefn_grad(self, xi, coords):
219227
raise ValueError(
220228
f"`x` should be of size (npoints, 1, ndim); found {xi.shape}"
221229
)
222-
# natural_grad = self._shapefn_natural_grad(x)
223-
# result = natural_grad * 2 / (coords[1] - coords[0])
224230
grad_sf = self._shapefn_natural_grad(xi)
225231
_jacobian = grad_sf @ coords
226232

@@ -286,8 +292,6 @@ def _step(pid, args):
286292
mapped_positions,
287293
mapped_nodes,
288294
)
289-
# breakpoint()
290-
# _step(0, args)
291295
_, self.nodes.mass, _, _ = lax.fori_loop(0, len(particles), _step, args)
292296

293297
def compute_nodal_momentum(self, particles):
@@ -489,7 +493,8 @@ def _step(pid, args):
489493
0, len(particles), _step, args
490494
)
491495

492-
def update_nodal_momentum(self, particles, dt, *args):
496+
def update_nodal_momentum(self, particles, dt: float, *args):
497+
"""Update the nodal momentum based on total force on nodes."""
493498
total_force = self.nodes.get_total_force()
494499
self.nodes.acceleration = self.nodes.acceleration.at[:].set(
495500
jnp.divide(total_force, self.nodes.mass)
@@ -500,13 +505,15 @@ def update_nodal_momentum(self, particles, dt, *args):
500505
self.nodes.momentum = self.nodes.momentum.at[:].add(total_force * dt)
501506

502507
def apply_boundary_constraints(self, *args):
508+
"""Apply boundary conditions for nodal velocity."""
503509
self.nodes.velocity = self.nodes.velocity.at[self.boundary_nodes].set(0)
504510
self.nodes.momentum = self.nodes.momentum.at[self.boundary_nodes].set(0)
505511
self.nodes.acceleration = self.nodes.acceleration.at[
506512
self.boundary_nodes
507513
].set(0)
508514

509515
def apply_force_boundary_constraints(self, *args):
516+
"""Apply boundary conditions for nodal forces."""
510517
self.nodes.f_int = self.nodes.f_int.at[self.boundary_nodes].set(0)
511518
self.nodes.f_ext = self.nodes.f_ext.at[self.boundary_nodes].set(0)
512519
self.nodes.f_damp = self.nodes.f_damp.at[self.boundary_nodes].set(0)

diffmpm/mesh.py

Lines changed: 24 additions & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
1-
import itertools
2-
from typing import Sequence
1+
import abc
2+
from typing import Iterable
33

44
import jax.numpy as jnp
5-
from jax import jit, lax, vmap
65
from jax.tree_util import register_pytree_node_class
7-
from jax_tqdm import loop_tqdm
8-
from tqdm import tqdm
96

107
from diffmpm.element import _Element
11-
from diffmpm.node import Nodes
12-
from diffmpm.shapefn import Linear1DShapeFn, Linear4NodeQuad
8+
from diffmpm.particle import Particles
139

1410

15-
@register_pytree_node_class
16-
class _MeshBase:
11+
class _MeshBase(abc.ABC):
12+
"""
13+
Base class for Meshes.
14+
15+
Note: If attributes other than elements and particles are added
16+
then the child class should also implement `tree_flatten` and
17+
`tree_unflatten` correctly or that information will get lost.
18+
"""
19+
1720
def __init__(self, config: dict):
1821
"""Initialize mesh using configuration."""
19-
self.particles: Sequence = config["particles"]
22+
self.particles: Iterable[Particles, ...] = config["particles"]
2023
self.elements: _Element = config["elements"]
2124

2225
# TODO: Convert to using jax directives for loop
@@ -42,211 +45,32 @@ def tree_unflatten(cls, aux_data, children):
4245
return cls({"particles": children[0], "elements": children[1]})
4346

4447

45-
class Mesh1D:
46-
"""
47-
1D Mesh class with nodes, elements, and particles.
48-
"""
49-
50-
def __init__(
51-
self,
52-
nelements,
53-
material,
54-
domain_size,
55-
boundary_nodes,
56-
*,
57-
ppe=1,
58-
particle_distribution="uniform",
59-
elements=None,
60-
nodes=None,
61-
particles=None,
62-
shapefn=None,
63-
dim=1,
64-
):
65-
"""
66-
Construct a 1D Mesh.
67-
68-
Arguments
69-
---------
70-
nelements : int
71-
Number of elements in the mesh.
72-
material : diffmpm.material.Material
73-
Material to meshed.
74-
domain_size : float
75-
The size of the domain in consideration.
76-
boundary_nodes : array_like
77-
Node ids of boundary nodes of the mesh. Needs to be a JAX
78-
array.
79-
ppe : int
80-
Number of particles per element in Mesh.
81-
"""
82-
self.dim = dim
83-
self.material = material
84-
self.shapefn = (
85-
Linear1DShapeFn(self.dim)
86-
if (
87-
shapefn is None
88-
or type(shapefn) is object
89-
or isinstance(shapefn, Mesh1D)
90-
)
91-
else shapefn
92-
)
93-
self.domain_size = domain_size
94-
self.nelements = nelements
95-
self.element_length = domain_size / nelements
96-
self.elements = jnp.arange(nelements) if elements is None else elements
97-
nnodes = nelements + 1
98-
self.nodes = (
99-
Nodes(
100-
nnodes,
101-
jnp.arange(nelements + 1) * self.element_length,
102-
jnp.zeros(nnodes),
103-
jnp.zeros(nnodes),
104-
jnp.zeros(nnodes),
105-
jnp.zeros(nnodes),
106-
jnp.zeros(nnodes),
107-
jnp.zeros(nnodes),
108-
)
109-
if (
110-
nodes is None
111-
or type(nodes) is object
112-
or isinstance(nodes, Mesh1D)
113-
)
114-
else nodes
115-
)
116-
self.boundary_nodes = boundary_nodes
117-
self.ppe = ppe
118-
self.particles = (
119-
self._init_particles(particle_distribution)
120-
if (
121-
particles is None
122-
or type(particles) is object
123-
or isinstance(particles, Mesh1D)
124-
)
125-
else particles
126-
)
127-
return
128-
129-
130-
class Mesh2D:
131-
"""
132-
2D Mesh class with nodes, elements, and particles.
133-
"""
48+
@register_pytree_node_class
49+
class Mesh1D(_MeshBase):
50+
"""1D Mesh class with nodes, elements, and particles."""
13451

135-
def __init__(
136-
self,
137-
nelements,
138-
material,
139-
domain_size,
140-
boundary_nodes,
141-
*,
142-
ppe=1,
143-
particle_distribution="uniform",
144-
elements=None,
145-
nodes=None,
146-
particles=None,
147-
shapefn=None,
148-
dim=1,
149-
):
52+
def __init__(self, config: dict):
15053
"""
151-
Construct a 2D Mesh using 4-Node Quadrilateral Elements.
152-
153-
Nodes and elements are numbered as
154-
155-
0---0---0---0---0
156-
| 8 | 9 | 10| 11|
157-
10 0---0---0---0---0
158-
| 4 | 5 | 6 | 7 |
159-
5 0---0---0---0---0 9
160-
| 0 | 1 | 2 | 3 |
161-
0---0---0---0---0
162-
0 1 2 3 4
163-
54+
Initialize a 1D Mesh.
16455
16556
Arguments
16657
---------
167-
nelements : array_like
168-
Number of elements in the mesh in the x and y direction.
169-
material : diffmpm.material.Material
170-
Material to meshed.
171-
domain_size : 4-tuple, array_like
172-
The boundaries of the domain. Should be of the form
173-
(x_min, x_max, y_min, y_max)
174-
boundary_nodes : array_like
175-
Node ids of boundary nodes of the mesh. Needs to be a JAX
176-
array.
177-
ppe : int
178-
Number of particles per element in Mesh.
58+
config: dict
59+
Configuration to be used for initialization. It _should_
60+
contain `elements` and `particles` keys.
17961
"""
180-
self.dim = 2
181-
self.material = material
182-
self.shapefn = (
183-
Linear4NodeQuad()
184-
if (
185-
shapefn is None
186-
or type(shapefn) is object
187-
or isinstance(shapefn, Mesh1D)
188-
)
189-
else shapefn
190-
)
191-
self.domain_size = domain_size
192-
self.nelements = jnp.asarray(nelements)
193-
self.element_length = jnp.array(
194-
[
195-
(domain_size[1] - domain_size[0]) / nelements[0],
196-
(domain_size[3] - domain_size[2]) / nelements[1],
197-
]
198-
)
199-
self.elements = (
200-
jnp.arange(self.nelements[0] * self.nelements[1])
201-
if elements is None
202-
else elements
203-
)
204-
nnodes = jnp.product(self.nelements + 1)
205-
coords = jnp.asarray(
206-
list(
207-
itertools.product(
208-
jnp.arange(nelements[1] + 1), jnp.arange(nelements[0] + 1)
209-
)
210-
)
211-
)
212-
node_positions = (
213-
jnp.asarray([coords[:, 1], coords[:, 0]]).T * self.element_length
214-
)
215-
216-
self.nodes = (
217-
Nodes(
218-
nnodes,
219-
node_positions,
220-
jnp.zeros((nnodes, 2)),
221-
jnp.zeros(nnodes),
222-
jnp.zeros((nnodes, 2)),
223-
jnp.zeros((nnodes, 2)),
224-
jnp.zeros((nnodes, 2)),
225-
jnp.zeros((nnodes, 2)),
226-
)
227-
if (
228-
nodes is None
229-
or type(nodes) is object
230-
or isinstance(nodes, Mesh1D)
231-
)
232-
else nodes
233-
)
234-
self.boundary_nodes = boundary_nodes
235-
self.ppe = ppe
236-
self.particles = particles
237-
return
62+
super().__init__(config)
23863

23964

24065
if __name__ == "__main__":
241-
from diffmpm.utils import _show_example
242-
from diffmpm.particle import Particles
24366
from diffmpm.element import Linear1D
24467
from diffmpm.material import SimpleMaterial
68+
from diffmpm.utils import _show_example
24569

24670
particles = Particles(
24771
jnp.array([[[1]]]),
24872
SimpleMaterial({"E": 2, "density": 1}),
24973
jnp.array([0]),
25074
)
25175
elements = Linear1D(2, 1, jnp.array([0]))
252-
_show_example(_MeshBase({"particles": [particles], "elements": elements}))
76+
_show_example(Mesh1D({"particles": [particles], "elements": elements}))

0 commit comments

Comments
 (0)