Skip to content

Commit 53f057f

Browse files
Merge pull request #1999 from reubenharry:dirichlet
PiperOrigin-RevId: 752493986
2 parents c1cdc92 + f4a83b1 commit 53f057f

4 files changed

Lines changed: 215 additions & 2 deletions

File tree

spinoffs/inference_gym/inference_gym/internal/test_util.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,17 +223,22 @@ def target_log_prob_fn(*x):
223223
seed = test_util.test_seed(sampler_type='stateless')
224224
current_state = tf.nest.map_structure(
225225
lambda b, e: b( # pylint: disable=g-long-lambda
226-
tf.zeros([num_chains] + list(e), dtype=dtype)),
226+
tf.zeros([num_chains] + list(b.inverse_event_shape(e)), dtype=dtype)),
227+
model.default_event_space_bijector,
228+
model.event_shape)
229+
step_size = tf.nest.map_structure(
230+
lambda b, e: tf.fill(b.inverse_event_shape(e), step_size),
227231
model.default_event_space_bijector,
228232
model.event_shape)
229233

230234
# tfp.mcmc only works well with lists.
231235
current_state = tf.nest.flatten(current_state)
236+
step_size = tf.nest.flatten(step_size)
232237

233238
hmc = tfp.mcmc.HamiltonianMonteCarlo(
234239
target_log_prob_fn=target_log_prob_fn,
235240
num_leapfrog_steps=num_leapfrog_steps,
236-
step_size=[tf.fill(s.shape, step_size) for s in current_state])
241+
step_size=step_size)
237242
hmc = tfp.mcmc.TransformedTransitionKernel(
238243
hmc, tf.nest.flatten(model.default_event_space_bijector))
239244
hmc = tfp.mcmc.DualAveragingStepSizeAdaptation(

spinoffs/inference_gym/inference_gym/targets/BUILD

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ py_library(
3838
":banana",
3939
":bayesian_model",
4040
":brownian_motion",
41+
":dirichlet",
4142
":eight_schools",
4243
":ill_conditioned_gaussian",
4344
":item_response_theory",
@@ -132,6 +133,29 @@ py_test(
132133
],
133134
)
134135

136+
py_library(
137+
name = "dirichlet",
138+
srcs = ["dirichlet.py"],
139+
deps = [
140+
":model",
141+
# tensorflow_probability dep,
142+
],
143+
)
144+
145+
# py_strict
146+
py_test(
147+
name = "dirichlet_test",
148+
srcs = ["dirichlet_test.py"],
149+
deps = [
150+
":dirichlet",
151+
# absl/testing:parameterized dep,
152+
# numpy dep,
153+
# tensorflow dep,
154+
# tensorflow_probability/python/internal:test_util dep,
155+
"//inference_gym/internal:test_util",
156+
],
157+
)
158+
135159
py_library(
136160
name = "eight_schools",
137161
srcs = ["eight_schools.py"],
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2025 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Dirichlet model."""
15+
16+
import numpy as np
17+
import tensorflow.compat.v2 as tf
18+
import tensorflow_probability as tfp
19+
from inference_gym.targets import model
20+
import tensorflow_probability.substrates.numpy as tfp_np
21+
22+
23+
tfb = tfp.bijectors
24+
tfd = tfp.distributions
25+
26+
__all__ = [
27+
'Dirichlet',
28+
]
29+
30+
31+
class Dirichlet(model.Model):
32+
"""Creates a Dirichlet.
33+
34+
This function produces a Dirichlet distribution. Low concentration parameters
35+
(much below 1) produce a distribution that is typically difficult to sample
36+
from.
37+
"""
38+
39+
def __init__(
40+
self,
41+
concentration_vector=np.ones(100) * 0.1,
42+
dtype=tf.float32,
43+
name='dirichlet',
44+
pretty_name='Dirichlet',
45+
):
46+
"""Construct the Dirichlet.
47+
48+
Args:
49+
concentration_vector: The concentration parameters of the Dirichlet
50+
distribution.
51+
dtype: Dtype to use for floating point quantities.
52+
name: Python `str` name prefixed to Ops created by this class.
53+
pretty_name: A Python `str`. The pretty name of this model.
54+
"""
55+
56+
dirichlet = tfp.distributions.Dirichlet(
57+
concentration=tf.cast(concentration_vector, dtype),
58+
)
59+
dirichlet_np = tfp_np.distributions.Dirichlet(
60+
concentration=np.asarray(concentration_vector),
61+
)
62+
63+
sample_transformations = {
64+
'identity': model.Model.SampleTransformation(
65+
fn=lambda params: params,
66+
pretty_name='Identity',
67+
ground_truth_mean=dirichlet_np.mean(),
68+
ground_truth_standard_deviation=dirichlet_np.stddev(),
69+
dtype=dtype,
70+
)
71+
}
72+
73+
self._dirichlet = dirichlet
74+
75+
super(Dirichlet, self).__init__(
76+
default_event_space_bijector=tfb.IteratedSigmoidCentered(
77+
validate_args=False, name='iterated_sigmoid'
78+
),
79+
event_shape=dirichlet.event_shape,
80+
dtype=dirichlet.dtype,
81+
name=name,
82+
pretty_name=pretty_name,
83+
sample_transformations=sample_transformations,
84+
)
85+
86+
def _unnormalized_log_prob(self, value):
87+
return self._dirichlet.log_prob(value)
88+
89+
def sample(self, sample_shape=(), seed=None, name='sample'):
90+
"""Generate samples of the specified shape from the target distribution.
91+
92+
The returned samples are exact (and independent) samples from the target
93+
distribution of this model.
94+
95+
Note that a call to `sample()` without arguments will generate a single
96+
sample.
97+
98+
Args:
99+
sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
100+
seed: Python integer or `tfp.util.SeedStream` instance, for seeding PRNG.
101+
name: Name to give to the prefix the generated ops.
102+
103+
Returns:
104+
samples: a `Tensor` with prepended dimensions `sample_shape`.
105+
"""
106+
return self._dirichlet.sample(sample_shape, seed=seed, name=name)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
from absl.testing import parameterized
16+
import numpy as np
17+
import tensorflow.compat.v2 as tf
18+
19+
from tensorflow_probability.python.internal import test_util as tfp_test_util
20+
from inference_gym.internal import test_util
21+
from inference_gym.targets import dirichlet
22+
23+
24+
@test_util.multi_backend_test(globals(),
25+
'targets.dirichlet_test')
26+
class DirichletTest(test_util.InferenceGymTestCase):
27+
28+
@parameterized.parameters(np.float32, np.float64)
29+
def testBasic(self, dtype):
30+
"""Checks that you get finite values given unconstrained samples.
31+
32+
We check `unnormalized_log_prob` as well as the values of the sample
33+
transformations.
34+
35+
Args:
36+
dtype: Dtype to use for floating point computations.
37+
"""
38+
model = dirichlet.Dirichlet(
39+
concentration_vector=np.ones(10, dtype=dtype),
40+
dtype=dtype
41+
)
42+
self.validate_log_prob_and_transforms(
43+
model,
44+
sample_transformation_shapes=dict(identity=[10],),
45+
check_ground_truth_mean=True,
46+
check_ground_truth_standard_deviation=True,
47+
dtype=dtype,
48+
)
49+
50+
def testMC(self):
51+
"""Checks true samples from the model against the ground truth."""
52+
model = dirichlet.Dirichlet(
53+
concentration_vector=np.ones(10, dtype=np.float32),
54+
dtype=tf.float32,
55+
)
56+
self.validate_ground_truth_using_monte_carlo(
57+
model,
58+
num_samples=int(1e6),
59+
)
60+
61+
@test_util.numpy_disable_gradient_test
62+
def testHMC(self):
63+
"""Checks true samples from the model against the ground truth."""
64+
model = dirichlet.Dirichlet(
65+
concentration_vector=np.ones(10, dtype=np.float32),
66+
dtype=np.float32,
67+
)
68+
self.validate_ground_truth_using_hmc(
69+
model,
70+
num_chains=4,
71+
num_steps=4000,
72+
num_leapfrog_steps=10,
73+
step_size=0.025,
74+
)
75+
76+
77+
if __name__ == '__main__':
78+
tfp_test_util.main()

0 commit comments

Comments
 (0)