Skip to content

Commit a1bcf5e

Browse files
authored
doc: replace XXX with Note in the code comments (#2191)
1 parent 151ed23 commit a1bcf5e

23 files changed

Lines changed: 38 additions & 38 deletions

numpyro/contrib/control_flow/scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def body_fn(wrapped_carry, x, prefix=None):
248248
# we haven't promote shapes of values yet during `lax.scan`, so we do it here
249249
site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])
250250

251-
# XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
251+
# Note: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
252252
# we don't record 1-size dimensions in this field
253253
time_dim = -min(
254254
len(site["fn"].batch_shape), jnp.ndim(site["value"]) - site["fn"].event_dim

numpyro/contrib/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def _update_params(params, new_params, prior, prefix=""):
161161
else:
162162
d = prior
163163
param_batch_shape = param_shape[: len(param_shape) - d.event_dim]
164-
# XXX: here we set all dimensions of prior to event dimensions.
164+
# Note: here we set all dimensions of prior to event dimensions.
165165
new_params[name] = numpyro.sample(
166166
flatten_name, d.expand(param_batch_shape).to_event()
167167
)

numpyro/contrib/tfp/distributions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _get_codomain(bijector):
3232
loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
3333
if not_jax_tracer(concentration) and np.all(np.less(concentration, 0)):
3434
return constraints.interval(loc, loc + scale / jnp.abs(concentration))
35-
# XXX: here we suppose concentration > 0
35+
# Note: here we suppose concentration > 0
3636
# which is not true in general, but should cover enough usage cases
3737
else:
3838
return constraints.greater_than(loc)
@@ -278,7 +278,7 @@ def support(self):
278278

279279
@property
280280
def is_discrete(self):
281-
# XXX: this should cover most cases
281+
# Note: this should cover most cases
282282
return self.support is None
283283

284284
def tree_flatten(self):

numpyro/contrib/tfp/mcmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def init(
193193
if is_prng_key(rng_key):
194194
init_state = self._init_fn(init_params, rng_key)
195195
else:
196-
# XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
196+
# note: it's safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
197197
# nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
198198
# wa_steps because those variables do not depend on traced args: init_params, rng_key.
199199
init_state = vmap(self._init_fn)(init_params, rng_key)

numpyro/distributions/conjugate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626

2727
def _log_beta_1(alpha, value):
28-
# XXX: support sparse `value`
28+
# Note: support sparse `value`
2929
return gammaln(1 + value) + gammaln(alpha) - gammaln(value + alpha)
3030

3131

numpyro/distributions/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def codomain(self) -> Constraint:
626626
raise NotImplementedError
627627

628628
def __call__(self, x: NumLike) -> NumLike:
629-
# XXX consider to clamp from below for stability if necessary
629+
# Note: consider to clamp from below for stability if necessary
630630
return jnp.exp(x)
631631

632632
def _inverse(self, y: NumLike) -> NumLike:
@@ -1318,7 +1318,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
13181318
raise NotImplementedError
13191319

13201320
def tree_flatten(self):
1321-
# XXX: what if unpack_fn is a parametrized callable pytree?
1321+
# Note: what if unpack_fn is a parametrized callable pytree?
13221322
return (), ((), {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn})
13231323

13241324
def eq(self, other: object, static: bool = False) -> ArrayLike:

numpyro/examples/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def _load_jsb_chorales() -> dict:
414414
)
415415
data = pickle.load(f)
416416

417-
# XXX: we might expose those in `load_dataset` keywords
417+
# Note: we might expose those in `load_dataset` keywords
418418
min_note = 21
419419
note_range = 88
420420
processed_dataset = {}

numpyro/infer/autoguide.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,7 @@ def __call__(self, *args, **kwargs):
777777
def _unpack_and_constrain(self, latent_sample, params):
778778
def unpack_single_latent(latent):
779779
unpacked_samples = self._unpack_latent(latent)
780-
# XXX: we need to add param here to be able to replay model
780+
# Note: we need to add param here to be able to replay model
781781
unpacked_samples.update(
782782
{
783783
k: v

numpyro/infer/elbo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,7 @@ def single_particle_elbo(rng_key: jax.Array) -> jax.Array:
11461146
if self.max_plate_nesting == float("inf"):
11471147
seeded_model = seed(model, model_seed)
11481148
seeded_guide = seed(guide, guide_seed)
1149-
# XXX: We can extract abstract latents here such that they
1149+
# Note: We can extract abstract latents here such that they
11501150
# can be reused in get_nonreparam_deps below.
11511151
self.max_plate_nesting = guess_max_plate_nesting(
11521152
seeded_model, seeded_guide, args, kwargs, param_map

numpyro/infer/ensemble.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def get_diagnostics_str(self, state):
307307
return "acc. prob={:.2f}".format(state.inner_state.mean_accept_prob)
308308

309309
def init_inner_state(self, rng_key):
310-
# XXX hack -- we don't know num_chains until we init the inner state
310+
# Note: hack -- we don't know num_chains until we init the inner state
311311
self._moves = [
312312
move(self._num_chains) if move.__name__ == "make_de_move" else move
313313
for move in self._moves
@@ -370,7 +370,7 @@ def de_move(rng_key, active, inactive):
370370
pairs_key, gamma_key = random.split(rng_key)
371371
n_active_chains, n_params = inactive.shape
372372

373-
# XXX: if we pass in n_params to parent scope we don't need to
373+
# Note: if we pass in n_params to parent scope we don't need to
374374
# recompute this each time
375375
g = 2.38 / jnp.sqrt(2.0 * n_params) if not g0 else g0
376376

@@ -535,7 +535,7 @@ def __init__(
535535
def init_inner_state(self, rng_key):
536536
self.batch_log_density = lambda x: self._batch_log_density(x)[:, jnp.newaxis]
537537

538-
# XXX hack -- we don't know num_chains until we init the inner state
538+
# Note: hack -- we don't know num_chains until we init the inner state
539539
self._moves = [
540540
move(self._num_chains)
541541
if move.__name__ == "make_differential_move"

0 commit comments

Comments
 (0)