Skip to content

Commit efd7843

Browse files
committed
Restructure code and add materials subdir
To make the imports easier from the materials subdir, also restructured other files. This moves `MPM` to a separate file so as to remove circular imports for materials module.
1 parent f6946b6 commit efd7843

16 files changed

Lines changed: 126 additions & 120 deletions

File tree

benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import jax.numpy as jnp
55

6-
from diffmpm import MPM
6+
from diffmpm.mpm import MPM
77

88

99
def test_benchmarks():

benchmarks/2d/uniaxial_particle_traction/test_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import jax.numpy as jnp
55

6-
from diffmpm import MPM
6+
from diffmpm.mpm import MPM
77

88

99
def test_benchmarks():

benchmarks/2d/uniaxial_stress/test_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import jax.numpy as jnp
55

6-
from diffmpm import MPM
6+
from diffmpm.mpm import MPM
77

88

99
def test_benchmarks():

diffmpm/__init__.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,5 @@
11
from importlib.metadata import version
2-
from pathlib import Path
32

4-
import diffmpm.writers as writers
5-
from diffmpm.io import Config
6-
from diffmpm.solver import MPMExplicit
7-
8-
__all__ = ["MPM", "__version__"]
3+
__all__ = ["__version__"]
94

105
__version__ = version("diffmpm")
11-
12-
13-
class MPM:
14-
def __init__(self, filepath):
15-
self._config = Config(filepath)
16-
mesh = self._config.parse()
17-
out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath(
18-
self._config.parsed_config["meta"]["title"],
19-
)
20-
21-
write_format = self._config.parsed_config["output"].get("format", None)
22-
if write_format is None or write_format.lower() == "none":
23-
writer_func = None
24-
elif write_format == "npz":
25-
writer_func = writers.NPZWriter().write
26-
else:
27-
raise ValueError(f"Specified output format not supported: {write_format}")
28-
29-
if self._config.parsed_config["meta"]["type"] == "MPMExplicit":
30-
self.solver = MPMExplicit(
31-
mesh,
32-
self._config.parsed_config["meta"]["dt"],
33-
velocity_update=self._config.parsed_config["meta"]["velocity_update"],
34-
sim_steps=self._config.parsed_config["meta"]["nsteps"],
35-
out_steps=self._config.parsed_config["output"]["step_frequency"],
36-
out_dir=out_dir,
37-
writer_func=writer_func,
38-
)
39-
else:
40-
raise ValueError("Wrong type of solver specified.")
41-
42-
def solve(self):
43-
"""Solve the MPM simulation using JIT solver."""
44-
arrays = self.solver.solve_jit(
45-
self._config.parsed_config["external_loading"]["gravity"],
46-
)
47-
return arrays

diffmpm/cli/mpm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import click
22

3-
from diffmpm import MPM
3+
from diffmpm.mpm import MPM
44

55

66
@click.command() # type: ignore

diffmpm/io.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import json
22
import tomllib as tl
3-
from collections import namedtuple
43

54
import jax.numpy as jnp
65

76
from diffmpm import element as mpel
8-
from diffmpm import material as mpmat
7+
from diffmpm import materials as mpmat
98
from diffmpm import mesh as mpmesh
109
from diffmpm.constraint import Constraint
1110
from diffmpm.forces import NodalForce, ParticleTraction

diffmpm/materials/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from diffmpm.materials._base import _Material
2+
from diffmpm.materials.simple import SimpleMaterial
3+
from diffmpm.materials.linear_elastic import LinearElastic

diffmpm/materials/_base.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import abc
2+
from typing import Tuple
3+
4+
5+
class _Material(abc.ABC):
6+
"""Base material class."""
7+
8+
_props: Tuple[str, ...]
9+
10+
def __init__(self, material_properties):
11+
"""Initialize material properties.
12+
13+
Parameters
14+
----------
15+
material_properties: dict
16+
A key-value map for various material properties.
17+
"""
18+
self.properties = material_properties
19+
20+
# @abc.abstractmethod
21+
def tree_flatten(self):
22+
"""Flatten this class as PyTree Node."""
23+
return (tuple(), self.properties)
24+
25+
# @abc.abstractmethod
26+
@classmethod
27+
def tree_unflatten(cls, aux_data, children):
28+
"""Unflatten this class as PyTree Node."""
29+
del children
30+
return cls(aux_data)
31+
32+
@abc.abstractmethod
33+
def __repr__(self):
34+
"""Repr for Material class."""
35+
...
36+
37+
@abc.abstractmethod
38+
def compute_stress(self):
39+
"""Compute stress for the material."""
40+
...
41+
42+
def validate_props(self, material_properties):
43+
for key in self._props:
44+
if key not in material_properties:
45+
raise KeyError(
46+
f"'{key}' should be present in `material_properties` "
47+
f"for {self.__class__.__name__} materials."
48+
)
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,11 @@
1-
import abc
2-
from typing import Tuple
3-
41
import jax.numpy as jnp
52
from jax.tree_util import register_pytree_node_class
63

7-
8-
class Material(abc.ABC):
9-
"""Base material class."""
10-
11-
_props: Tuple[str, ...]
12-
13-
def __init__(self, material_properties):
14-
"""Initialize material properties.
15-
16-
Parameters
17-
----------
18-
material_properties: dict
19-
A key-value map for various material properties.
20-
"""
21-
self.properties = material_properties
22-
23-
# @abc.abstractmethod
24-
def tree_flatten(self):
25-
"""Flatten this class as PyTree Node."""
26-
return (tuple(), self.properties)
27-
28-
# @abc.abstractmethod
29-
@classmethod
30-
def tree_unflatten(cls, aux_data, children):
31-
"""Unflatten this class as PyTree Node."""
32-
del children
33-
return cls(aux_data)
34-
35-
@abc.abstractmethod
36-
def __repr__(self):
37-
"""Repr for Material class."""
38-
...
39-
40-
@abc.abstractmethod
41-
def compute_stress(self):
42-
"""Compute stress for the material."""
43-
...
44-
45-
def validate_props(self, material_properties):
46-
for key in self._props:
47-
if key not in material_properties:
48-
raise KeyError(
49-
f"'{key}' should be present in `material_properties` "
50-
f"for {self.__class__.__name__} materials."
51-
)
4+
from ._base import _Material
525

536

547
@register_pytree_node_class
55-
class LinearElastic(Material):
8+
class LinearElastic(_Material):
569
"""Linear Elastic Material."""
5710

5811
_props = ("density", "youngs_modulus", "poisson_ratio")
@@ -114,18 +67,3 @@ def compute_stress(self, dstrain):
11467
"""Compute material stress."""
11568
dstress = self.de @ dstrain
11669
return dstress
117-
118-
119-
@register_pytree_node_class
120-
class SimpleMaterial(Material):
121-
_props = ("E", "density")
122-
123-
def __init__(self, material_properties):
124-
self.validate_props(material_properties)
125-
self.properties = material_properties
126-
127-
def __repr__(self):
128-
return f"SimpleMaterial(props={self.properties})"
129-
130-
def compute_stress(self, dstrain):
131-
return dstrain * self.properties["E"]

diffmpm/materials/newtonian.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#!/usr/bin/env python3

0 commit comments

Comments
 (0)