|
1 | 1 | from jax.tree_util import register_pytree_node_class |
2 | 2 | import abc |
3 | 3 | import jax.numpy as jnp |
| 4 | +from jax import vmap, lax, jit |
4 | 5 |
|
5 | 6 |
|
6 | 7 | class Material(abc.ABC): |
@@ -133,6 +134,186 @@ def compute_stress(self, dstrain): |
133 | 134 | return dstrain * self.properties["E"] |
134 | 135 |
|
135 | 136 |
|
| 137 | +@register_pytree_node_class |
| 138 | +class Bingham(Material): |
| 139 | + _props = ( |
| 140 | + "density", |
| 141 | + "youngs_modulus", |
| 142 | + "poisson_ratio", |
| 143 | + "tau0", |
| 144 | + "mu", |
| 145 | + "critical_shear_rate", |
| 146 | + "ndim", |
| 147 | + ) |
| 148 | + |
| 149 | + def __init__(self, material_properties): |
| 150 | + """ |
| 151 | + Create a Bingham material model. |
| 152 | +
|
| 153 | + Arguments |
| 154 | + --------- |
| 155 | + material_properties: dict |
| 156 | + Dictionary with material properties. For Bingham |
| 157 | + material, 'density','youngs_modulus','poisson_ratio','tau0','mu', |
| 158 | + 'ndim and 'critical_shear_rate' are required keys. |
| 159 | +
|
| 160 | + Methods |
| 161 | + ------- |
| 162 | + initialise_state_variables |
| 163 | + Initialises the state variables for the Bingham material |
| 164 | + compute_stress |
| 165 | + computes the stress for the Bingham material particles |
| 166 | + |
| 167 | + """ |
| 168 | + self.validate_props(material_properties) |
| 169 | + self.ndim = material_properties["ndim"] |
| 170 | + youngs_modulus = material_properties["youngs_modulus"] |
| 171 | + poisson_ratio = material_properties["poisson_ratio"] |
| 172 | + self.state_variables=["pressure"] |
| 173 | + # Calculate the bulk modulus |
| 174 | + bulk_modulus = youngs_modulus / (3.0 * (1.0 - 2.0 * poisson_ratio)) |
| 175 | + compressibility_multiplier_ = 1.0 |
| 176 | + # Special Material Properties |
| 177 | + if material_properties.get("incompressible", False): |
| 178 | + compressibility_multiplier_ = 0.0 |
| 179 | + self.properties = { |
| 180 | + **material_properties, |
| 181 | + "bulk_modulus": bulk_modulus, |
| 182 | + "compressibility_multiplier": compressibility_multiplier_, |
| 183 | + } |
| 184 | + |
| 185 | + def __repr__(self): |
| 186 | + return f"Bingham(props={self.properties})" |
| 187 | + |
| 188 | + # Initialise history variables |
| 189 | + def initialise_state_variables(particles): |
| 190 | + state_vars = {} |
| 191 | + state_vars["pressure"] = jnp.zeros((particles.loc.shape[0])) |
| 192 | + return state_vars |
| 193 | + |
| 194 | + # Compute the pressure |
| 195 | + def __thermodynamic_pressure(self, volumetric_strain): |
| 196 | + return -self.properties["bulk_modulus"] * volumetric_strain |
| 197 | + |
| 198 | + # Compute the stress |
| 199 | + def compute_stress(self, dstrain, particles, state_vars:dict): |
| 200 | + """ |
| 201 | + Computes the stress for the Bingham material |
| 202 | +
|
| 203 | + Arguments |
| 204 | + --------- |
| 205 | + dstrain: array_like |
| 206 | + The strain rate tensor for the particles |
| 207 | + particles: diffmpm.particles.Particles |
| 208 | + state_vars: dict {str: jnp.ndarray} |
| 209 | + dictionary containig the string as the name of the |
| 210 | + property and the jnp.ndarray shape (nparticles, 1) as the |
| 211 | + values of the property at each particle. |
| 212 | + |
| 213 | + Returns |
| 214 | + ------- |
| 215 | + updated_stress: jnp.ndarray |
| 216 | + The updated stress for the particles expected shape (nparticles, 6,1) |
| 217 | + """ |
| 218 | + shear_rate_threshold = 1e-15 |
| 219 | + # dirac delta in Voigt notation |
| 220 | + dirac_delta = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape((6, 1)) |
| 221 | + dirac_delta = lax.cond( |
| 222 | + self.ndim == 1, |
| 223 | + lambda x: x.at[0, 0].set(1.0), |
| 224 | + lambda x: x.at[0:2, 0].set(1.0), |
| 225 | + dirac_delta, |
| 226 | + ) |
| 227 | + # Set threshold for minimum critical shear rate |
| 228 | + self.properties["critical_shear_rate"] = lax.select( |
| 229 | + self.properties["critical_shear_rate"] < shear_rate_threshold, |
| 230 | + shear_rate_threshold, |
| 231 | + self.properties["critical_shear_rate"], |
| 232 | + ) |
| 233 | + |
| 234 | + @jit |
| 235 | + def __compute_stress_per_particle( |
| 236 | + particle_strain_rate, |
| 237 | + self, |
| 238 | + state_vars_pressure, |
| 239 | + dvolumetric_strain_per_particle, |
| 240 | + dirac_delta, |
| 241 | + ): |
| 242 | + strain_r = particle_strain_rate |
| 243 | + # Convert strain rate to rate of deformation tensor |
| 244 | + strain_r = strain_r.at[-3:].multiply(0.5) |
| 245 | + |
| 246 | + # Rate of shear = sqrt(2 * Strain_r_ij * Strain_r_ij) |
| 247 | + # Strain_r is in Voigt notation so the above formula reduces in the voigt notation to |
| 248 | + # sqrt(Strain_r_0^2 + Strain_r_1^2 + Strain_r_2^2 + 2*Strain_r_3^2 |
| 249 | + # + 2*Strain_r_4^2 + 2*Strain_r_5^2) |
| 250 | + # When shear rate> critical_shear_rate^2 then the material is yielding |
| 251 | + |
| 252 | + shear_rate = jnp.sqrt( |
| 253 | + 2.0 * (strain_r.T @ (strain_r) + strain_r[-3:].T @ strain_r[-3:]) |
| 254 | + ).squeeze() |
| 255 | + |
| 256 | + # Apparent_viscosity maps shear rate to shear stress |
| 257 | + # Check if shear rate is 0 |
| 258 | + |
| 259 | + apparent_viscosity_true = 2.0 * ( |
| 260 | + (self.properties["tau0"] / shear_rate) + self.properties["mu"] |
| 261 | + ) |
| 262 | + condition = (shear_rate * shear_rate) > ( |
| 263 | + self.properties["critical_shear_rate"] |
| 264 | + * self.properties["critical_shear_rate"] |
| 265 | + ) |
| 266 | + apparent_viscosity = lax.select(condition, apparent_viscosity_true, 0.0) |
| 267 | + |
| 268 | + # Compute volumetric tau |
| 269 | + |
| 270 | + tau = apparent_viscosity * strain_r |
| 271 | + # von Mises criterion |
| 272 | + # yield condition trace of the invariant > tau0^2 |
| 273 | + # and trace can be found using the first 3 numbers of tau |
| 274 | + # as tau is in voigt notation |
| 275 | + |
| 276 | + trace_invariant = 0.5 * jnp.dot(tau[:3, 0], tau[:3, 0]) |
| 277 | + tau = lax.cond( |
| 278 | + trace_invariant < (self.properties["tau0"] * self.properties["tau0"]), |
| 279 | + lambda x: x.at[:].set(0), |
| 280 | + lambda x: x, |
| 281 | + tau, |
| 282 | + ) |
| 283 | + # update pressure |
| 284 | + state_vars_pressure += self.properties[ |
| 285 | + "compressibility_multiplier" |
| 286 | + ] * self.__thermodynamic_pressure(dvolumetric_strain_per_particle) |
| 287 | + |
| 288 | + # Update volumetric and deviatoric stress |
| 289 | + # thermodynamic pressure is from material point |
| 290 | + # stress = -thermodynamic_pressure I + tau, where I is identity matrix or |
| 291 | + # direc_delta in Voigt notation |
| 292 | + |
| 293 | + updated_stress_per_particle = ( |
| 294 | + -(state_vars_pressure) |
| 295 | + * dirac_delta |
| 296 | + * self.properties["compressibility_multiplier"] |
| 297 | + + tau |
| 298 | + ) |
| 299 | + return updated_stress_per_particle, state_vars_pressure |
| 300 | + |
| 301 | + # using vmap to vectorise the function compute stress per particle |
| 302 | + # for all the particles using the first dimension of the strain rate |
| 303 | + # and the dvolumetric_strain and state_vars pressure |
| 304 | + updated_stress, state_vars["pressure"] = vmap( |
| 305 | + __compute_stress_per_particle, in_axes=(0, None, 0, 0, None) |
| 306 | + )( |
| 307 | + particles.strain_rate, |
| 308 | + self, |
| 309 | + state_vars["pressure"], |
| 310 | + particles.dvolumetric_strain, |
| 311 | + dirac_delta, |
| 312 | + ) |
| 313 | + |
| 314 | + return updated_stress |
| 315 | + |
| 316 | + |
136 | 317 | if __name__ == "__main__": |
137 | 318 | from diffmpm.utils import _show_example |
138 | 319 |
|
|
0 commit comments