|
5 | 5 | from diffmpm.material import SimpleMaterial |
6 | 6 | from diffmpm.mesh import Mesh1D |
7 | 7 | from diffmpm.particle import Particles |
| 8 | +from diffmpm.constraint import Constraint |
8 | 9 | from diffmpm.solver import MPMExplicit |
9 | 10 | from jax import value_and_grad, grad, jit |
10 | 11 | from tqdm import tqdm |
11 | 12 |
|
12 | 13 | E_true = 100 |
13 | 14 | material = SimpleMaterial({"E": E_true, "density": 1}) |
14 | | -elements = Linear1D(1, 1, jnp.array([0])) |
15 | | -particles = Particles( |
16 | | - jnp.array([0.5]).reshape(1, 1, 1), material, jnp.array([0]) |
17 | | -) |
| 15 | +cons = [(jnp.array([0]), Constraint(0, 0.0))] |
| 16 | +elements = Linear1D(1, 1, 1, cons) |
| 17 | +particles = Particles(jnp.array([0.5]).reshape(1, 1, 1), material, jnp.array([0])) |
18 | 18 | b1 = jnp.pi * 0.5 |
19 | 19 | particles.velocity += 0.1 |
20 | | -particles.set_mass_volume(1.0) |
| 20 | +# particles.set_mass_volume(1.0) |
21 | 21 | dt = 0.01 |
22 | 22 | nsteps = 1000 |
23 | 23 | mesh = Mesh1D({"particles": [particles], "elements": elements}) |
@@ -60,20 +60,17 @@ def optax_adam(params, niter, mpm, target_vel): |
60 | 60 | return param_list, loss_list |
61 | 61 |
|
62 | 62 |
|
63 | | -params = 107.5 |
| 63 | +params = 105.0 |
64 | 64 | material = SimpleMaterial({"E": params, "density": 1}) |
65 | | -elements = Linear1D(1, 1, jnp.array([0])) |
66 | | -particles = Particles( |
67 | | - jnp.array([0.5]).reshape(1, 1, 1), material, jnp.array([0]) |
68 | | -) |
| 65 | +cons = [(jnp.array([0]), Constraint(0, 0.0))] |
| 66 | +elements = Linear1D(1, 1, 1, cons) |
| 67 | +particles = Particles(jnp.array([0.5]).reshape(1, 1, 1), material, jnp.array([0])) |
69 | 68 | particles.velocity += 0.1 |
70 | 69 | particles.set_mass_volume(1.0) |
71 | 70 | mesh = Mesh1D({"particles": [particles], "elements": elements}) |
72 | 71 |
|
73 | 72 | mpm = MPMExplicit(mesh, dt, scheme="usl") |
74 | | -param_list, loss_list = optax_adam( |
75 | | - params, 400, mpm, target_vel |
76 | | -) # ADAM optimizer |
| 73 | +param_list, loss_list = optax_adam(params, 400, mpm, target_vel) # ADAM optimizer |
77 | 74 | # print("E: {}".format(result)) |
78 | 75 |
|
79 | 76 | fig, ax = plt.subplots(1, 2, figsize=(16, 6)) |
|
0 commit comments