|
| 1 | +# from tensorflow_probability.substrates import jax as tfp |
| 2 | +import jax.numpy as jnp |
| 3 | +from jax import grad, jit, vmap, lax |
| 4 | +import jax.scipy as jsp |
| 5 | +import jax.scipy.optimize as jsp_opt |
| 6 | +import optax |
| 7 | +import jaxopt |
| 8 | +from jaxopt import ScipyBoundedMinimize |
| 9 | +import matplotlib.pyplot as plt |
| 10 | +import jax |
| 11 | +from tqdm import tqdm |
| 12 | + |
| 13 | + |
| 14 | +@jit |
| 15 | +def mpm(E): |
| 16 | + # nsteps |
| 17 | + nsteps = 5200 |
| 18 | + |
| 19 | + # mom tolerance |
| 20 | + tol = 1e-12 |
| 21 | + |
| 22 | + # Domain |
| 23 | + L = 25 |
| 24 | + |
| 25 | + # Material properties |
| 26 | + # E = 100 |
| 27 | + rho = 1 |
| 28 | + |
| 29 | + # Computational grid |
| 30 | + |
| 31 | + nelements = 13 # number of elements |
| 32 | + dx = L / nelements # element length |
| 33 | + |
| 34 | + # Create equally spaced nodes |
| 35 | + x_n = jnp.linspace(0, L, nelements + 1) |
| 36 | + nnodes = len(x_n) |
| 37 | + |
| 38 | + # Set-up a 2D array of elements with node ids |
| 39 | + elements = jnp.zeros((nelements, 2), dtype=int) |
| 40 | + for nid in range(nelements): |
| 41 | + elements = elements.at[nid, 0].set(nid) |
| 42 | + elements = elements.at[nid, 1].set(nid + 1) |
| 43 | + |
| 44 | + # Loading conditions |
| 45 | + v0 = 0.1 # initial velocity |
| 46 | + c = jnp.sqrt(E / rho) # speed of sound |
| 47 | + b1 = jnp.pi / (2 * L) # beta1 |
| 48 | + w1 = b1 * c # omega1 |
| 49 | + |
| 50 | + # Create material points at the center of each element |
| 51 | + nparticles = nelements # number of particles |
| 52 | + # Id of the particle in the central element |
| 53 | + pmid = 6 |
| 54 | + |
| 55 | + # Material point properties |
| 56 | + x_p = jnp.zeros(nparticles) # positions |
| 57 | + vol_p = jnp.ones(nparticles) * dx # volume |
| 58 | + mass_p = vol_p * rho # mass |
| 59 | + stress_p = jnp.zeros(nparticles) # stress |
| 60 | + vel_p = jnp.zeros(nparticles) # velocity |
| 61 | + |
| 62 | + # Create particle at the center |
| 63 | + x_p = 0.5 * (x_n[:-1] + x_n[1:]) |
| 64 | + |
| 65 | + # set initial velocities |
| 66 | + vel_p = v0 * jnp.sin(b1 * x_p) |
| 67 | + |
| 68 | + # Time steps and duration |
| 69 | + dt_crit = dx / c |
| 70 | + dt = 0.02 |
| 71 | + |
| 72 | + # results |
| 73 | + tt = jnp.zeros(nsteps) |
| 74 | + vt = jnp.zeros(nsteps) |
| 75 | + xt = jnp.zeros(nsteps) |
| 76 | + |
| 77 | + def step(i, carry): |
| 78 | + x_p, mass_p, vel_p, vol_p, stress_p, vt, xt = carry |
| 79 | + # reset nodal values |
| 80 | + mass_n = jnp.zeros(nnodes) # mass |
| 81 | + mom_n = jnp.zeros(nnodes) # momentum |
| 82 | + fint_n = jnp.zeros(nnodes) # internal force |
| 83 | + |
| 84 | + # iterate through each element |
| 85 | + for eid in range(nelements): |
| 86 | + # get nodal ids |
| 87 | + nid1, nid2 = elements[eid] |
| 88 | + |
| 89 | + # compute shape functions and derivatives |
| 90 | + N1 = 1 - abs(x_p[eid] - x_n[nid1]) / dx |
| 91 | + N2 = 1 - abs(x_p[eid] - x_n[nid2]) / dx |
| 92 | + dN1 = -1 / dx |
| 93 | + dN2 = 1 / dx |
| 94 | + |
| 95 | + # map particle mass and momentum to nodes |
| 96 | + mass_n = mass_n.at[nid1].set(mass_n[nid1] + N1 * mass_p[eid]) |
| 97 | + mass_n = mass_n.at[nid2].set(mass_n[nid2] + N2 * mass_p[eid]) |
| 98 | + |
| 99 | + mom_n = mom_n.at[nid1].set( |
| 100 | + mom_n[nid1] + N1 * mass_p[eid] * vel_p[eid] |
| 101 | + ) |
| 102 | + mom_n = mom_n.at[nid2].set( |
| 103 | + mom_n[nid2] + N2 * mass_p[eid] * vel_p[eid] |
| 104 | + ) |
| 105 | + |
| 106 | + # compute nodal internal force |
| 107 | + fint_n = fint_n.at[nid1].set( |
| 108 | + fint_n[nid1] - vol_p[eid] * stress_p[eid] * dN1 |
| 109 | + ) |
| 110 | + fint_n = fint_n.at[nid2].set( |
| 111 | + fint_n[nid2] - vol_p[eid] * stress_p[eid] * dN2 |
| 112 | + ) |
| 113 | + |
| 114 | + # apply boundary conditions |
| 115 | + mom_n = mom_n.at[0].set(0) # Nodal velocity v = 0 in m * v at node 0. |
| 116 | + fint_n = fint_n.at[0].set( |
| 117 | + 0 |
| 118 | + ) # Nodal force f = m * a, where a = 0 at node 0. |
| 119 | + |
| 120 | + # update nodal momentum |
| 121 | + mom_n = mom_n + fint_n * dt |
| 122 | + |
| 123 | + # update particle velocity position and stress |
| 124 | + # iterate through each element |
| 125 | + for eid in range(nelements): |
| 126 | + # get nodal ids |
| 127 | + nid1, nid2 = elements[eid] |
| 128 | + |
| 129 | + # compute shape functions and derivatives |
| 130 | + N1 = 1 - abs(x_p[eid] - x_n[nid1]) / dx |
| 131 | + N2 = 1 - abs(x_p[eid] - x_n[nid2]) / dx |
| 132 | + dN1 = -1 / dx |
| 133 | + dN2 = 1 / dx |
| 134 | + |
| 135 | + # compute particle velocity |
| 136 | + # if (mass_n[nid1]) > tol: |
| 137 | + vel_p = vel_p.at[eid].set( |
| 138 | + vel_p[eid] + dt * N1 * fint_n[nid1] / mass_n[nid1] |
| 139 | + ) |
| 140 | + # if (mass_n[nid2]) > tol: |
| 141 | + vel_p = vel_p.at[eid].set( |
| 142 | + vel_p[eid] + dt * N2 * fint_n[nid2] / mass_n[nid2] |
| 143 | + ) |
| 144 | + |
| 145 | + # update particle position based on nodal momentum |
| 146 | + x_p = x_p.at[eid].set( |
| 147 | + x_p[eid] |
| 148 | + + dt |
| 149 | + * ( |
| 150 | + N1 * mom_n[nid1] / mass_n[nid1] |
| 151 | + + N2 * mom_n[nid2] / mass_n[nid2] |
| 152 | + ) |
| 153 | + ) |
| 154 | + |
| 155 | + # nodal velocity |
| 156 | + nv1 = mom_n[nid1] / mass_n[nid1] |
| 157 | + nv2 = mom_n[nid2] / mass_n[nid2] |
| 158 | + |
| 159 | + # rate of strain increment |
| 160 | + grad_v = dN1 * nv1 + dN2 * nv2 |
| 161 | + # particle dstrain |
| 162 | + dstrain = grad_v * dt |
| 163 | + # particle volume |
| 164 | + vol_p = vol_p.at[eid].set((1 + dstrain) * vol_p[eid]) |
| 165 | + # update stress using linear elastic model |
| 166 | + stress_p = stress_p.at[eid].set(stress_p[eid] + E * dstrain) |
| 167 | + |
| 168 | + # results |
| 169 | + vt = vt.at[i].set(vel_p[pmid]) |
| 170 | + xt = xt.at[i].set(x_p[pmid]) |
| 171 | + |
| 172 | + return (x_p, mass_p, vel_p, vol_p, stress_p, vt, xt) |
| 173 | + |
| 174 | + x_p, mass_p, vel_p, vol_p, stress_p, vt, xt = lax.fori_loop( |
| 175 | + 0, nsteps, step, (x_p, mass_p, vel_p, vol_p, stress_p, vt, xt) |
| 176 | + ) |
| 177 | + |
| 178 | + return vt |
| 179 | + |
| 180 | + |
| 181 | +# Assign target |
| 182 | +Etarget = 100 |
| 183 | +target = mpm(Etarget) |
| 184 | + |
| 185 | + |
| 186 | +############################################################# |
| 187 | +# NOTE: Uncomment the line only for TFP optimizer and |
| 188 | +# jaxopt value_and_grad = True |
| 189 | +############################################################# |
| 190 | +# @jax.value_and_grad |
| 191 | +@jit |
| 192 | +def compute_loss(E): |
| 193 | + vt = mpm(E) |
| 194 | + return jnp.linalg.norm(vt - target) |
| 195 | + |
| 196 | + |
| 197 | +# BFGS Optimizer |
| 198 | +# TODO: Implement box constrained optimizer |
| 199 | +def jaxopt_bfgs(params, niter): |
| 200 | + opt = jaxopt.BFGS( |
| 201 | + fun=compute_loss, |
| 202 | + value_and_grad=True, |
| 203 | + tol=1e-5, |
| 204 | + implicit_diff=False, |
| 205 | + maxiter=niter, |
| 206 | + ) |
| 207 | + res = opt.run(init_params=params) |
| 208 | + result, _ = res |
| 209 | + return result |
| 210 | + |
| 211 | + |
| 212 | +# Optimizers |
| 213 | +def optax_adam(params, niter): |
| 214 | + # Initialize parameters of the model + optimizer. |
| 215 | + start_learning_rate = 1e-1 |
| 216 | + optimizer = optax.adam(start_learning_rate) |
| 217 | + opt_state = optimizer.init(params) |
| 218 | + |
| 219 | + # A simple update loop. |
| 220 | + for _ in tqdm(range(niter)): |
| 221 | + grads = grad(compute_loss)(params) |
| 222 | + updates, opt_state = optimizer.update(grads, opt_state) |
| 223 | + params = optax.apply_updates(params, updates) |
| 224 | + return params |
| 225 | + |
| 226 | + |
| 227 | +# # Tensor Flow Probability Optimization library |
| 228 | +# def tfp_lbfgs(params): |
| 229 | +# results = tfp.optimizer.lbfgs_minimize( |
| 230 | +# jax.jit(compute_loss), initial_position=params, tolerance=1e-5) |
| 231 | +# return results.position |
| 232 | + |
| 233 | +# Initial model - Young's modulus |
| 234 | +params = 95.0 |
| 235 | + |
| 236 | +# vt = tfp_lbfgs(params) # LBFGS optimizer |
| 237 | +result = optax_adam(params, 1000) # ADAM optimizer |
| 238 | + |
| 239 | +""" |
| 240 | +f = jax.jit(compute_loss) |
| 241 | +df = jax.jit(jax.grad(compute_loss)) |
| 242 | +E = 95.0 |
| 243 | +print(0, E) |
| 244 | +for i in range(10): |
| 245 | + E = E - f(E)/df(E) |
| 246 | + print(i, E) |
| 247 | +""" |
| 248 | +print("E: {}".format(result)) |
| 249 | +vel = mpm(result) |
| 250 | +# update time steps |
| 251 | +dt = 0.02 |
| 252 | +nsteps = 5200 |
| 253 | +tt = jnp.arange(0, nsteps) * dt |
| 254 | + |
| 255 | +# Plot results |
| 256 | +plt.plot(tt, vel, "r", markersize=1, label="mpm") |
| 257 | +plt.plot(tt, target, "ob", markersize=1, label="mpm-target") |
| 258 | +plt.xlabel("time (s)") |
| 259 | +plt.ylabel("velocity (m/s)") |
| 260 | +plt.legend() |
| 261 | +plt.show() |
0 commit comments