Skip to content

Commit 66f34e1

Browse files
authored
Merge pull request #3 from chahak13/refactor
Refactor to decouple classes
2 parents 693a0df + e3bf6db commit 66f34e1

20 files changed

Lines changed: 2607 additions & 1129 deletions

diffmpm-1dbar-old.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
# from tensorflow_probability.substrates import jax as tfp
2+
import jax.numpy as jnp
3+
from jax import grad, jit, vmap, lax
4+
import jax.scipy as jsp
5+
import jax.scipy.optimize as jsp_opt
6+
import optax
7+
import jaxopt
8+
from jaxopt import ScipyBoundedMinimize
9+
import matplotlib.pyplot as plt
10+
import jax
11+
from tqdm import tqdm
12+
13+
14+
@jit
15+
def mpm(E):
16+
# nsteps
17+
nsteps = 5200
18+
19+
# mom tolerance
20+
tol = 1e-12
21+
22+
# Domain
23+
L = 25
24+
25+
# Material properties
26+
# E = 100
27+
rho = 1
28+
29+
# Computational grid
30+
31+
nelements = 13 # number of elements
32+
dx = L / nelements # element length
33+
34+
# Create equally spaced nodes
35+
x_n = jnp.linspace(0, L, nelements + 1)
36+
nnodes = len(x_n)
37+
38+
# Set-up a 2D array of elements with node ids
39+
elements = jnp.zeros((nelements, 2), dtype=int)
40+
for nid in range(nelements):
41+
elements = elements.at[nid, 0].set(nid)
42+
elements = elements.at[nid, 1].set(nid + 1)
43+
44+
# Loading conditions
45+
v0 = 0.1 # initial velocity
46+
c = jnp.sqrt(E / rho) # speed of sound
47+
b1 = jnp.pi / (2 * L) # beta1
48+
w1 = b1 * c # omega1
49+
50+
# Create material points at the center of each element
51+
nparticles = nelements # number of particles
52+
# Id of the particle in the central element
53+
pmid = 6
54+
55+
# Material point properties
56+
x_p = jnp.zeros(nparticles) # positions
57+
vol_p = jnp.ones(nparticles) * dx # volume
58+
mass_p = vol_p * rho # mass
59+
stress_p = jnp.zeros(nparticles) # stress
60+
vel_p = jnp.zeros(nparticles) # velocity
61+
62+
# Create particle at the center
63+
x_p = 0.5 * (x_n[:-1] + x_n[1:])
64+
65+
# set initial velocities
66+
vel_p = v0 * jnp.sin(b1 * x_p)
67+
68+
# Time steps and duration
69+
dt_crit = dx / c
70+
dt = 0.02
71+
72+
# results
73+
tt = jnp.zeros(nsteps)
74+
vt = jnp.zeros(nsteps)
75+
xt = jnp.zeros(nsteps)
76+
77+
def step(i, carry):
78+
x_p, mass_p, vel_p, vol_p, stress_p, vt, xt = carry
79+
# reset nodal values
80+
mass_n = jnp.zeros(nnodes) # mass
81+
mom_n = jnp.zeros(nnodes) # momentum
82+
fint_n = jnp.zeros(nnodes) # internal force
83+
84+
# iterate through each element
85+
for eid in range(nelements):
86+
# get nodal ids
87+
nid1, nid2 = elements[eid]
88+
89+
# compute shape functions and derivatives
90+
N1 = 1 - abs(x_p[eid] - x_n[nid1]) / dx
91+
N2 = 1 - abs(x_p[eid] - x_n[nid2]) / dx
92+
dN1 = -1 / dx
93+
dN2 = 1 / dx
94+
95+
# map particle mass and momentum to nodes
96+
mass_n = mass_n.at[nid1].set(mass_n[nid1] + N1 * mass_p[eid])
97+
mass_n = mass_n.at[nid2].set(mass_n[nid2] + N2 * mass_p[eid])
98+
99+
mom_n = mom_n.at[nid1].set(
100+
mom_n[nid1] + N1 * mass_p[eid] * vel_p[eid]
101+
)
102+
mom_n = mom_n.at[nid2].set(
103+
mom_n[nid2] + N2 * mass_p[eid] * vel_p[eid]
104+
)
105+
106+
# compute nodal internal force
107+
fint_n = fint_n.at[nid1].set(
108+
fint_n[nid1] - vol_p[eid] * stress_p[eid] * dN1
109+
)
110+
fint_n = fint_n.at[nid2].set(
111+
fint_n[nid2] - vol_p[eid] * stress_p[eid] * dN2
112+
)
113+
114+
# apply boundary conditions
115+
mom_n = mom_n.at[0].set(0) # Nodal velocity v = 0 in m * v at node 0.
116+
fint_n = fint_n.at[0].set(
117+
0
118+
) # Nodal force f = m * a, where a = 0 at node 0.
119+
120+
# update nodal momentum
121+
mom_n = mom_n + fint_n * dt
122+
123+
# update particle velocity position and stress
124+
# iterate through each element
125+
for eid in range(nelements):
126+
# get nodal ids
127+
nid1, nid2 = elements[eid]
128+
129+
# compute shape functions and derivatives
130+
N1 = 1 - abs(x_p[eid] - x_n[nid1]) / dx
131+
N2 = 1 - abs(x_p[eid] - x_n[nid2]) / dx
132+
dN1 = -1 / dx
133+
dN2 = 1 / dx
134+
135+
# compute particle velocity
136+
# if (mass_n[nid1]) > tol:
137+
vel_p = vel_p.at[eid].set(
138+
vel_p[eid] + dt * N1 * fint_n[nid1] / mass_n[nid1]
139+
)
140+
# if (mass_n[nid2]) > tol:
141+
vel_p = vel_p.at[eid].set(
142+
vel_p[eid] + dt * N2 * fint_n[nid2] / mass_n[nid2]
143+
)
144+
145+
# update particle position based on nodal momentum
146+
x_p = x_p.at[eid].set(
147+
x_p[eid]
148+
+ dt
149+
* (
150+
N1 * mom_n[nid1] / mass_n[nid1]
151+
+ N2 * mom_n[nid2] / mass_n[nid2]
152+
)
153+
)
154+
155+
# nodal velocity
156+
nv1 = mom_n[nid1] / mass_n[nid1]
157+
nv2 = mom_n[nid2] / mass_n[nid2]
158+
159+
# rate of strain increment
160+
grad_v = dN1 * nv1 + dN2 * nv2
161+
# particle dstrain
162+
dstrain = grad_v * dt
163+
# particle volume
164+
vol_p = vol_p.at[eid].set((1 + dstrain) * vol_p[eid])
165+
# update stress using linear elastic model
166+
stress_p = stress_p.at[eid].set(stress_p[eid] + E * dstrain)
167+
168+
# results
169+
vt = vt.at[i].set(vel_p[pmid])
170+
xt = xt.at[i].set(x_p[pmid])
171+
172+
return (x_p, mass_p, vel_p, vol_p, stress_p, vt, xt)
173+
174+
x_p, mass_p, vel_p, vol_p, stress_p, vt, xt = lax.fori_loop(
175+
0, nsteps, step, (x_p, mass_p, vel_p, vol_p, stress_p, vt, xt)
176+
)
177+
178+
return vt
179+
180+
181+
# Assign target
182+
Etarget = 100
183+
target = mpm(Etarget)
184+
185+
186+
#############################################################
187+
# NOTE: Uncomment the line only for TFP optimizer and
188+
# jaxopt value_and_grad = True
189+
#############################################################
190+
# @jax.value_and_grad
191+
@jit
192+
def compute_loss(E):
193+
vt = mpm(E)
194+
return jnp.linalg.norm(vt - target)
195+
196+
197+
# BFGS Optimizer
198+
# TODO: Implement box constrained optimizer
199+
def jaxopt_bfgs(params, niter):
200+
opt = jaxopt.BFGS(
201+
fun=compute_loss,
202+
value_and_grad=True,
203+
tol=1e-5,
204+
implicit_diff=False,
205+
maxiter=niter,
206+
)
207+
res = opt.run(init_params=params)
208+
result, _ = res
209+
return result
210+
211+
212+
# Optimizers
213+
def optax_adam(params, niter):
214+
# Initialize parameters of the model + optimizer.
215+
start_learning_rate = 1e-1
216+
optimizer = optax.adam(start_learning_rate)
217+
opt_state = optimizer.init(params)
218+
219+
# A simple update loop.
220+
for _ in tqdm(range(niter)):
221+
grads = grad(compute_loss)(params)
222+
updates, opt_state = optimizer.update(grads, opt_state)
223+
params = optax.apply_updates(params, updates)
224+
return params
225+
226+
227+
# # Tensor Flow Probability Optimization library
228+
# def tfp_lbfgs(params):
229+
# results = tfp.optimizer.lbfgs_minimize(
230+
# jax.jit(compute_loss), initial_position=params, tolerance=1e-5)
231+
# return results.position
232+
233+
# Initial model - Young's modulus
234+
params = 95.0
235+
236+
# vt = tfp_lbfgs(params) # LBFGS optimizer
237+
result = optax_adam(params, 1000) # ADAM optimizer
238+
239+
"""
240+
f = jax.jit(compute_loss)
241+
df = jax.jit(jax.grad(compute_loss))
242+
E = 95.0
243+
print(0, E)
244+
for i in range(10):
245+
E = E - f(E)/df(E)
246+
print(i, E)
247+
"""
248+
print("E: {}".format(result))
249+
vel = mpm(result)
250+
# update time steps
251+
dt = 0.02
252+
nsteps = 5200
253+
tt = jnp.arange(0, nsteps) * dt
254+
255+
# Plot results
256+
plt.plot(tt, vel, "r", markersize=1, label="mpm")
257+
plt.plot(tt, target, "ob", markersize=1, label="mpm-target")
258+
plt.xlabel("time (s)")
259+
plt.ylabel("velocity (m/s)")
260+
plt.legend()
261+
plt.show()

0 commit comments

Comments
 (0)