1+ from tensorflow_probability .substrates import jax as tfp
2+ import jax .numpy as jnp
3+ from jax import grad , jit , vmap , lax
4+ import jax .scipy as jsp
5+ import jax .scipy .optimize as jsp_opt
6+ import optax
7+ import jaxopt
8+ from jaxopt import ScipyBoundedMinimize
9+ import matplotlib .pyplot as plt
10+ import jax
11+
12+ @jit
13+ def mpm (E ):
14+ # nsteps
15+ nsteps = 5200
16+
17+ # mom tolerance
18+ tol = 1e-12
19+
20+ # Domain
21+ L = 25
22+
23+ # Material properties
24+ # E = 100
25+ rho = 1
26+
27+ # Computational grid
28+
29+ nelements = 13 # number of elements
30+ dx = L / nelements # element length
31+
32+ # Create equally spaced nodes
33+ x_n = jnp .linspace (0 , L , nelements + 1 )
34+ nnodes = len (x_n )
35+
36+ # Set-up a 2D array of elements with node ids
37+ elements = jnp .zeros ((nelements , 2 ), dtype = int )
38+ for nid in range (nelements ):
39+ elements = elements .at [nid ,0 ].set (nid )
40+ elements = elements .at [nid ,1 ].set (nid + 1 )
41+
42+ # Loading conditions
43+ v0 = 0.1 # initial velocity
44+ c = jnp .sqrt (E / rho ) # speed of sound
45+ b1 = jnp .pi / (2 * L ) # beta1
46+ w1 = b1 * c # omega1
47+
48+ # Create material points at the center of each element
49+ nparticles = nelements # number of particles
50+ # Id of the particle in the central element
51+ pmid = 6
52+
53+ # Material point properties
54+ x_p = jnp .zeros (nparticles ) # positions
55+ vol_p = jnp .ones (nparticles ) * dx # volume
56+ mass_p = vol_p * rho # mass
57+ stress_p = jnp .zeros (nparticles ) # stress
58+ vel_p = jnp .zeros (nparticles ) # velocity
59+
60+ # Create particle at the center
61+ x_p = 0.5 * (x_n [:- 1 ] + x_n [1 :])
62+
63+ # set initial velocities
64+ vel_p = v0 * jnp .sin (b1 * x_p )
65+
66+ # Time steps and duration
67+ dt_crit = dx / c
68+ dt = 0.02
69+
70+ # results
71+ tt = jnp .zeros (nsteps )
72+ vt = jnp .zeros (nsteps )
73+ xt = jnp .zeros (nsteps )
74+
75+ def step (i , carry ):
76+ x_p , mass_p , vel_p , vol_p , stress_p , vt , xt = carry
77+ # reset nodal values
78+ mass_n = jnp .zeros (nnodes ) # mass
79+ mom_n = jnp .zeros (nnodes ) # momentum
80+ fint_n = jnp .zeros (nnodes ) # internal force
81+
82+ # iterate through each element
83+ for eid in range (nelements ):
84+ # get nodal ids
85+ nid1 , nid2 = elements [eid ]
86+
87+ # compute shape functions and derivatives
88+ N1 = 1 - abs (x_p [eid ] - x_n [nid1 ])/ dx
89+ N2 = 1 - abs (x_p [eid ] - x_n [nid2 ])/ dx
90+ dN1 = - 1 / dx
91+ dN2 = 1 / dx
92+
93+ # map particle mass and momentum to nodes
94+ mass_n = mass_n .at [nid1 ].set (mass_n [nid1 ] + N1 * mass_p [eid ])
95+ mass_n = mass_n .at [nid2 ].set (mass_n [nid2 ] + N2 * mass_p [eid ])
96+
97+ mom_n = mom_n .at [nid1 ].set (mom_n [nid1 ] + N1 * mass_p [eid ] * vel_p [eid ])
98+ mom_n = mom_n .at [nid2 ].set (mom_n [nid2 ] + N2 * mass_p [eid ] * vel_p [eid ])
99+
100+ # compute nodal internal force
101+ fint_n = fint_n .at [nid1 ].set (fint_n [nid1 ] - vol_p [eid ] * stress_p [eid ] * dN1 )
102+ fint_n = fint_n .at [nid2 ].set (fint_n [nid2 ] - vol_p [eid ] * stress_p [eid ] * dN2 )
103+
104+ # apply boundary conditions
105+ mom_n = mom_n .at [0 ].set (0 ) # Nodal velocity v = 0 in m * v at node 0.
106+ fint_n = fint_n .at [0 ].set (0 ) # Nodal force f = m * a, where a = 0 at node 0.
107+
108+ # update nodal momentum
109+ mom_n = mom_n + fint_n * dt
110+
111+ # update particle velocity position and stress
112+ # iterate through each element
113+ for eid in range (nelements ):
114+ # get nodal ids
115+ nid1 , nid2 = elements [eid ]
116+
117+ # compute shape functions and derivatives
118+ N1 = 1 - abs (x_p [eid ] - x_n [nid1 ])/ dx
119+ N2 = 1 - abs (x_p [eid ] - x_n [nid2 ])/ dx
120+ dN1 = - 1 / dx
121+ dN2 = 1 / dx
122+
123+ # compute particle velocity
124+ # if (mass_n[nid1]) > tol:
125+ vel_p = vel_p .at [eid ].set (vel_p [eid ] + dt * N1 * fint_n [nid1 ] / mass_n [nid1 ])
126+ # if (mass_n[nid2]) > tol:
127+ vel_p = vel_p .at [eid ].set (vel_p [eid ] + dt * N2 * fint_n [nid2 ] / mass_n [nid2 ])
128+
129+ # update particle position based on nodal momentum
130+ x_p = x_p .at [eid ].set (x_p [eid ] + dt * (N1 * mom_n [nid1 ]/ mass_n [nid1 ] + N2 * mom_n [nid2 ]/ mass_n [nid2 ]))
131+
132+ # nodal velocity
133+ nv1 = mom_n [nid1 ]/ mass_n [nid1 ]
134+ nv2 = mom_n [nid2 ]/ mass_n [nid2 ]
135+
136+ # rate of strain increment
137+ grad_v = dN1 * nv1 + dN2 * nv2
138+ # particle dstrain
139+ dstrain = grad_v * dt
140+ # particle volume
141+ vol_p = vol_p .at [eid ].set ((1 + dstrain ) * vol_p [eid ])
142+ # update stress using linear elastic model
143+ stress_p = stress_p .at [eid ].set (stress_p [eid ] + E * dstrain )
144+
145+ # results
146+ vt = vt .at [i ].set (vel_p [pmid ])
147+ xt = xt .at [i ].set (x_p [pmid ])
148+
149+ return (x_p , mass_p , vel_p , vol_p , stress_p , vt , xt )
150+
151+ x_p , mass_p , vel_p , vol_p , stress_p , vt , xt = lax .fori_loop (0 , nsteps , step , (x_p , mass_p , vel_p , vol_p , stress_p , vt , xt ))
152+
153+
154+ return vt
155+
156+
157+ # Assign target
158+ Etarget = 100
159+ target = mpm (Etarget )
160+
161+
162+ #############################################################
163+ # NOTE: Uncomment the line only for TFP optimizer and
164+ # jaxopt value_and_grad = True
165+ #############################################################
166+ # @jax.value_and_grad
167+ @jit
168+ def compute_loss (E ):
169+ vt = mpm (E )
170+ return jnp .linalg .norm (vt - target )
171+
172+ # BFGS Optimizer
173+ # TODO: Implement box constrained optimizer
174+ def jaxopt_bfgs (params , niter ):
175+ opt = jaxopt .BFGS (fun = compute_loss , value_and_grad = True , tol = 1e-5 , implicit_diff = False , maxiter = niter )
176+ res = opt .run (init_params = params )
177+ result , _ = res
178+ return result
179+
180+ # Optimizers
181+ def optax_adam (params , niter ):
182+ # Initialize parameters of the model + optimizer.
183+ start_learning_rate = 1e-1
184+ optimizer = optax .adam (start_learning_rate )
185+ opt_state = optimizer .init (params )
186+
187+ # A simple update loop.
188+ for _ in range (niter ):
189+ grads = grad (compute_loss )(params )
190+ updates , opt_state = optimizer .update (grads , opt_state )
191+ params = optax .apply_updates (params , updates )
192+ return params
193+
194+ # Tensor Flow Probability Optimization library
195+ def tfp_lbfgs (params ):
196+ results = tfp .optimizer .lbfgs_minimize (
197+ jax .jit (compute_loss ), initial_position = params , tolerance = 1e-5 )
198+ return results .position
199+
200+ # Initial model - Young's modulus
201+ params = 95.0
202+
203+ # vt = tfp_lbfgs(params) # LBFGS optimizer
204+ result = optax_adam (params , 1000 ) # ADAM optimizer
205+
206+ """
207+ f = jax.jit(compute_loss)
208+ df = jax.jit(jax.grad(compute_loss))
209+ E = 95.0
210+ print(0, E)
211+ for i in range(10):
212+ E = E - f(E)/df(E)
213+ print(i, E)
214+ """
215+ print ("E: {}" .format (result ))
216+ vel = mpm (result )
217+ # update time steps
218+ dt = 0.02
219+ nsteps = 5200
220+ tt = jnp .arange (0 , nsteps ) * dt
221+
222+ # Plot results
223+ plt .plot (tt , vel , 'r' , markersize = 1 , label = 'mpm' )
224+ plt .plot (tt , target , 'ob' , markersize = 1 , label = 'mpm-target' )
225+ plt .xlabel ('time (s)' )
226+ plt .ylabel ('velocity (m/s)' )
227+ plt .legend ()
228+ plt .show ()
0 commit comments