Skip to content

Commit 2fc82a5

Browse files
committed
1D MPM bar differentiable code
1 parent f64ff9d commit 2fc82a5

1 file changed

Lines changed: 228 additions & 0 deletions

File tree

diffmpm-1dbar.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
from tensorflow_probability.substrates import jax as tfp
2+
import jax.numpy as jnp
3+
from jax import grad, jit, vmap, lax
4+
import jax.scipy as jsp
5+
import jax.scipy.optimize as jsp_opt
6+
import optax
7+
import jaxopt
8+
from jaxopt import ScipyBoundedMinimize
9+
import matplotlib.pyplot as plt
10+
import jax
11+
12+
@jit
13+
def mpm(E):
14+
# nsteps
15+
nsteps = 5200
16+
17+
# mom tolerance
18+
tol = 1e-12
19+
20+
# Domain
21+
L = 25
22+
23+
# Material properties
24+
# E = 100
25+
rho = 1
26+
27+
# Computational grid
28+
29+
nelements = 13 # number of elements
30+
dx = L / nelements # element length
31+
32+
# Create equally spaced nodes
33+
x_n = jnp.linspace(0, L, nelements+1)
34+
nnodes = len(x_n)
35+
36+
# Set-up a 2D array of elements with node ids
37+
elements = jnp.zeros((nelements, 2), dtype = int)
38+
for nid in range(nelements):
39+
elements = elements.at[nid,0].set(nid)
40+
elements = elements.at[nid,1].set(nid+1)
41+
42+
# Loading conditions
43+
v0 = 0.1 # initial velocity
44+
c = jnp.sqrt(E/rho) # speed of sound
45+
b1 = jnp.pi / (2 * L) # beta1
46+
w1 = b1 * c # omega1
47+
48+
# Create material points at the center of each element
49+
nparticles = nelements # number of particles
50+
# Id of the particle in the central element
51+
pmid = 6
52+
53+
# Material point properties
54+
x_p = jnp.zeros(nparticles) # positions
55+
vol_p = jnp.ones(nparticles) * dx # volume
56+
mass_p = vol_p * rho # mass
57+
stress_p = jnp.zeros(nparticles) # stress
58+
vel_p = jnp.zeros(nparticles) # velocity
59+
60+
# Create particle at the center
61+
x_p = 0.5 * (x_n[:-1] + x_n[1:])
62+
63+
# set initial velocities
64+
vel_p = v0 * jnp.sin(b1 * x_p)
65+
66+
# Time steps and duration
67+
dt_crit = dx / c
68+
dt = 0.02
69+
70+
# results
71+
tt = jnp.zeros(nsteps)
72+
vt = jnp.zeros(nsteps)
73+
xt = jnp.zeros(nsteps)
74+
75+
def step(i, carry):
76+
x_p, mass_p, vel_p, vol_p, stress_p, vt, xt = carry
77+
# reset nodal values
78+
mass_n = jnp.zeros(nnodes) # mass
79+
mom_n = jnp.zeros(nnodes) # momentum
80+
fint_n = jnp.zeros(nnodes) # internal force
81+
82+
# iterate through each element
83+
for eid in range(nelements):
84+
# get nodal ids
85+
nid1, nid2 = elements[eid]
86+
87+
# compute shape functions and derivatives
88+
N1 = 1 - abs(x_p[eid] - x_n[nid1])/dx
89+
N2 = 1 - abs(x_p[eid] - x_n[nid2])/dx
90+
dN1 = -1/dx
91+
dN2 = 1/dx
92+
93+
# map particle mass and momentum to nodes
94+
mass_n = mass_n.at[nid1].set(mass_n[nid1] + N1 * mass_p[eid])
95+
mass_n = mass_n.at[nid2].set(mass_n[nid2] + N2 * mass_p[eid])
96+
97+
mom_n = mom_n.at[nid1].set(mom_n[nid1] + N1 * mass_p[eid] * vel_p[eid])
98+
mom_n = mom_n.at[nid2].set(mom_n[nid2] + N2 * mass_p[eid] * vel_p[eid])
99+
100+
# compute nodal internal force
101+
fint_n = fint_n.at[nid1].set(fint_n[nid1] - vol_p[eid] * stress_p[eid] * dN1)
102+
fint_n = fint_n.at[nid2].set(fint_n[nid2] - vol_p[eid] * stress_p[eid] * dN2)
103+
104+
# apply boundary conditions
105+
mom_n = mom_n.at[0].set(0) # Nodal velocity v = 0 in m * v at node 0.
106+
fint_n = fint_n.at[0].set(0) # Nodal force f = m * a, where a = 0 at node 0.
107+
108+
# update nodal momentum
109+
mom_n = mom_n + fint_n * dt
110+
111+
# update particle velocity position and stress
112+
# iterate through each element
113+
for eid in range(nelements):
114+
# get nodal ids
115+
nid1, nid2 = elements[eid]
116+
117+
# compute shape functions and derivatives
118+
N1 = 1 - abs(x_p[eid] - x_n[nid1])/dx
119+
N2 = 1 - abs(x_p[eid] - x_n[nid2])/dx
120+
dN1 = -1/dx
121+
dN2 = 1/dx
122+
123+
# compute particle velocity
124+
# if (mass_n[nid1]) > tol:
125+
vel_p = vel_p.at[eid].set(vel_p[eid] + dt * N1 * fint_n[nid1] / mass_n[nid1])
126+
# if (mass_n[nid2]) > tol:
127+
vel_p = vel_p.at[eid].set(vel_p[eid] + dt * N2 * fint_n[nid2] / mass_n[nid2])
128+
129+
# update particle position based on nodal momentum
130+
x_p = x_p.at[eid].set(x_p[eid] + dt * (N1 * mom_n[nid1]/mass_n[nid1] + N2 * mom_n[nid2]/mass_n[nid2]))
131+
132+
# nodal velocity
133+
nv1 = mom_n[nid1]/mass_n[nid1]
134+
nv2 = mom_n[nid2]/mass_n[nid2]
135+
136+
# rate of strain increment
137+
grad_v = dN1 * nv1 + dN2 * nv2
138+
# particle dstrain
139+
dstrain = grad_v * dt
140+
# particle volume
141+
vol_p = vol_p.at[eid].set((1 + dstrain) * vol_p[eid])
142+
# update stress using linear elastic model
143+
stress_p = stress_p.at[eid].set(stress_p[eid] + E * dstrain)
144+
145+
# results
146+
vt = vt.at[i].set(vel_p[pmid])
147+
xt = xt.at[i].set(x_p[pmid])
148+
149+
return (x_p, mass_p, vel_p, vol_p, stress_p, vt, xt)
150+
151+
x_p, mass_p, vel_p, vol_p, stress_p, vt, xt = lax.fori_loop(0, nsteps, step, (x_p, mass_p, vel_p, vol_p, stress_p, vt, xt))
152+
153+
154+
return vt
155+
156+
157+
# Assign target
158+
Etarget = 100
159+
target = mpm(Etarget)
160+
161+
162+
#############################################################
163+
# NOTE: Uncomment the line only for TFP optimizer and
164+
# jaxopt value_and_grad = True
165+
#############################################################
166+
# @jax.value_and_grad
167+
@jit
168+
def compute_loss(E):
169+
vt = mpm(E)
170+
return jnp.linalg.norm(vt - target)
171+
172+
# BFGS Optimizer
173+
# TODO: Implement box constrained optimizer
174+
def jaxopt_bfgs(params, niter):
175+
opt= jaxopt.BFGS(fun=compute_loss, value_and_grad=True, tol=1e-5, implicit_diff=False, maxiter=niter)
176+
res = opt.run(init_params=params)
177+
result, _ = res
178+
return result
179+
180+
# Optimizers
181+
def optax_adam(params, niter):
182+
# Initialize parameters of the model + optimizer.
183+
start_learning_rate = 1e-1
184+
optimizer = optax.adam(start_learning_rate)
185+
opt_state = optimizer.init(params)
186+
187+
# A simple update loop.
188+
for _ in range(niter):
189+
grads = grad(compute_loss)(params)
190+
updates, opt_state = optimizer.update(grads, opt_state)
191+
params = optax.apply_updates(params, updates)
192+
return params
193+
194+
# Tensor Flow Probability Optimization library
195+
def tfp_lbfgs(params):
196+
results = tfp.optimizer.lbfgs_minimize(
197+
jax.jit(compute_loss), initial_position=params, tolerance=1e-5)
198+
return results.position
199+
200+
# Initial model - Young's modulus
201+
params = 95.0
202+
203+
# vt = tfp_lbfgs(params) # LBFGS optimizer
204+
result = optax_adam(params, 1000) # ADAM optimizer
205+
206+
"""
207+
f = jax.jit(compute_loss)
208+
df = jax.jit(jax.grad(compute_loss))
209+
E = 95.0
210+
print(0, E)
211+
for i in range(10):
212+
E = E - f(E)/df(E)
213+
print(i, E)
214+
"""
215+
print("E: {}".format(result))
216+
vel = mpm(result)
217+
# update time steps
218+
dt = 0.02
219+
nsteps = 5200
220+
tt = jnp.arange(0, nsteps) * dt
221+
222+
# Plot results
223+
plt.plot(tt, vel, 'r', markersize=1, label='mpm')
224+
plt.plot(tt, target, 'ob', markersize=1, label='mpm-target')
225+
plt.xlabel('time (s)')
226+
plt.ylabel('velocity (m/s)')
227+
plt.legend()
228+
plt.show()

0 commit comments

Comments
 (0)