Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit cf7cb7f

Browse files
committed
lint the code
1 parent 8ba4760 commit cf7cb7f

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
class OptaxHMCState(NamedTuple):
1616
"""Optax state for the HMC integrator."""
17+
1718
count: Array
1819
rng_key: PRNGKeyArray
1920
momentum: PyTree
@@ -38,6 +39,7 @@ def hmc_integrator(
3839
step_schedule: StepSchedule
3940
A function that takes training step as input and returns the step size.
4041
"""
42+
4143
def init_fn(params):
4244
return OptaxHMCState(
4345
count=jnp.zeros([], jnp.int32),
@@ -82,7 +84,7 @@ def mh_correction():
8284
momentum, _ = jax.flatten_util.ravel_pytree(momentum)
8385
kinetic = 0.5 * jnp.dot(momentum, momentum)
8486
hamiltonian = kinetic + state.log_prob
85-
accept_prob = jnp.minimum(1., jnp.exp(hamiltonian - state.hamiltonian))
87+
accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))
8688

8789
def _accept():
8890
empty_updates = jax.tree_util.tree_map(jnp.zeros_like, params)

0 commit comments

Comments
 (0)