Skip to content

Commit 452c9fc

Browse files
authored
Fix typos in SGMCMC (#124)
* Fix typos * Minor
1 parent b3ba1ac commit 452c9fc

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

posteriors/sgmcmc/baoa.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def build(
5656
Defaults to random iid samples from N(0, 1).
5757
5858
Returns:
59-
SGHMC transform instance.
59+
BAOA transform instance.
6060
"""
6161
init_fn = partial(init, momenta=momenta)
6262
update_fn = partial(
@@ -95,7 +95,7 @@ def init(params: TensorTree, momenta: TensorTree | float | None = None) -> BAOAS
9595
Defaults to random iid samples from N(0, 1).
9696
9797
Returns:
98-
Initial SGHMCState containing momenta.
98+
Initial BAOAState containing momenta.
9999
"""
100100
if momenta is None:
101101
momenta = tree_map(
@@ -128,7 +128,7 @@ def update(
128128
See [build](baoa.md#posteriors.sgmcmc.baoa.build) for more details.
129129
130130
Args:
131-
state: SGHMCState containing params and momenta.
131+
state: BAOAState containing params and momenta.
132132
batch: Data batch to be send to log_posterior.
133133
log_posterior: Function that takes parameters and input batch and
134134
returns the log posterior value (which can be unnormalised)

posteriors/sgmcmc/sgnht.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class SGNHTState(TensorClass["frozen"]):
7373
Attributes:
7474
params: Parameters.
7575
momenta: Momenta for each parameter.
76+
xi: Scalar thermostat.
7677
log_posterior: Log posterior evaluation.
7778
aux: Auxiliary information from the log_posterior call.
7879
"""
@@ -98,7 +99,7 @@ def init(
9899
xi: Initial value for scalar thermostat ξ.
99100
100101
Returns:
101-
Initial SGNHTState containing momenta.
102+
Initial SGNHTState containing params, momenta and xi (thermostat).
102103
"""
103104
if momenta is None:
104105
momenta = tree_map(

0 commit comments

Comments
 (0)