Skip to content

Commit 693a0df

Browse files
authored
Merge pull request #1 from chahak13/jaxport
[WIP] Port library to JAX.
2 parents 2fc82a5 + 31a7c1f commit 693a0df

11 files changed

Lines changed: 2760 additions & 85 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# DiffMPM

diffmpm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#!/usr/bin/env python3

diffmpm/material.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,33 @@
1-
import jax.numpy as jnp
1+
from jax.tree_util import register_pytree_node_class
22

3-
class LinearElastic:
4-
def __init__(self, E, density):
3+
4+
@register_pytree_node_class
5+
class Material:
6+
"""
7+
Base material class.
8+
"""
9+
10+
def __init__(self, E, density):
11+
"""
12+
Initialize material properties.
13+
14+
Arguments
15+
---------
16+
E : float
17+
Young's modulus of the material.
18+
density : float
19+
Density of the material.
20+
"""
521
self.E = E
622
self.density = density
723

8-
def update_stress(self, particle, dt):
9-
particle.stress+=particle.dstrain*self.E
24+
def tree_flatten(self):
25+
return (tuple(), (self.E, self.density))
26+
27+
@classmethod
28+
def tree_unflatten(cls, aux_data, children):
29+
del children
30+
return cls(*aux_data)
31+
32+
def __repr__(self):
33+
return f"Material(E={self.E}, density={self.density})"

0 commit comments

Comments
 (0)