|
| 1 | +import json |
| 2 | +import tomllib as tl |
| 3 | + |
| 4 | +import jax.numpy as jnp |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +from diffmpm import element as mpel |
| 8 | +from diffmpm import material as mpmat |
| 9 | +from diffmpm import mesh as mpmesh |
| 10 | +from diffmpm.node import Nodes |
| 11 | +from diffmpm.particle import Particles |
| 12 | + |
| 13 | + |
| 14 | +class Config: |
| 15 | + def __init__(self, filepath): |
| 16 | + self._filepath = filepath |
| 17 | + self.config = {} |
| 18 | + self.parse() |
| 19 | + |
| 20 | + def parse(self): |
| 21 | + with open(self._filepath, "rb") as f: |
| 22 | + self._fileconfig = tl.load(f) |
| 23 | + |
| 24 | + self._parse_meta(self._fileconfig) |
| 25 | + self._parse_output(self._fileconfig) |
| 26 | + self._parse_materials(self._fileconfig) |
| 27 | + self._parse_particles(self._fileconfig) |
| 28 | + mesh = self._parse_mesh(self._fileconfig) |
| 29 | + return mesh |
| 30 | + |
| 31 | + def _parse_meta(self, config): |
| 32 | + self.config["meta"] = config["meta"] |
| 33 | + |
| 34 | + def _parse_output(self, config): |
| 35 | + self.config["output"] = config["output"] |
| 36 | + |
| 37 | + def _parse_materials(self, config): |
| 38 | + materials = [] |
| 39 | + for mat_config in config["materials"]: |
| 40 | + mat_type = mat_config.pop("type") |
| 41 | + mat_cls = getattr(mpmat, mat_type) |
| 42 | + mat = mat_cls(mat_config) |
| 43 | + materials.append(mat) |
| 44 | + self.config["materials"] = materials |
| 45 | + |
| 46 | + def _parse_particles(self, config): |
| 47 | + particle_sets = [] |
| 48 | + for pset_config in config["particles"]: |
| 49 | + pmat = self.config["materials"][pset_config["material_id"]] |
| 50 | + with open(pset_config["file"], "r") as f: |
| 51 | + ploc = jnp.asarray(json.load(f)) |
| 52 | + peids = jnp.zeros(ploc.shape[0], dtype=jnp.int32) |
| 53 | + pset = Particles(ploc, pmat, peids) |
| 54 | + pset.velocity = pset.velocity.at[:].set( |
| 55 | + pset_config["init_velocity"] |
| 56 | + ) |
| 57 | + particle_sets.append(pset) |
| 58 | + self.config["particles"] = particle_sets |
| 59 | + |
| 60 | + def _parse_mesh(self, config): |
| 61 | + element_cls = getattr(mpel, config["mesh"]["element"]) |
| 62 | + mesh_cls = getattr(mpmesh, f"Mesh{config['meta']['dimension']}D") |
| 63 | + if config["mesh"]["type"] == "file": |
| 64 | + nodes_loc = jnp.asarray(np.loadtxt(config["mesh"]["file"])) |
| 65 | + nodes = Nodes(len(nodes_loc), nodes_loc) |
| 66 | + elements = element_cls(nelements, el_len, boundary_nodes) |
| 67 | + elif config["mesh"]["type"] == "generator": |
| 68 | + elements = element_cls( |
| 69 | + config["mesh"]["nelements"], |
| 70 | + config["mesh"]["element_length"], |
| 71 | + jnp.asarray(config["mesh"]["boundary_nodes"]), |
| 72 | + ) |
| 73 | + self.config["elements"] = elements |
| 74 | + mesh = mesh_cls(self.config) |
| 75 | + return mesh |
0 commit comments