Skip to content

Commit 667e549

Browse files
committed
EXAMPLE: update jax fast tutorial
1 parent ae1a3d9 commit 667e549

1 file changed

Lines changed: 118 additions & 125 deletions

File tree

examples/gw_examples/injection_examples/jax_fast_tutorial.py

Lines changed: 118 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -10,96 +10,129 @@
1010
We optionally use ripple waveforms and a JIT-compiled likelihood.
1111
"""
1212
import os
13-
from itertools import product
1413

1514
# Set OMP_NUM_THREADS to stop lalsimulation taking over my computer
1615
os.environ["OMP_NUM_THREADS"] = "1"
1716

1817
import bilby
19-
import bilby.gw.jaxstuff
2018
import numpy as np
2119
import jax
2220
import jax.numpy as jnp
23-
from jax import random
24-
from numpyro.infer import AIES, ESS # noqa
25-
from numpyro.infer.ensemble_util import get_nondiagonal_indices
21+
from bilby.compat.jax import JittedLikelihood
22+
from ripple.waveforms import IMRPhenomPv2
2623

2724
jax.config.update("jax_enable_x64", True)
2825

29-
bilby.core.utils.setup_logger() # log_level="WARNING")
3026

31-
32-
def setup_prior():
33-
# Set up a PriorDict, which inherits from dict.
34-
# By default we will sample all terms in the signal models. However, this will
35-
# take a long time for the calculation, so for this example we will set almost
36-
# all of the priors to be equall to their injected values. This implies the
37-
# prior is a delta function at the true, injected value. In reality, the
38-
# sampler implementation is smart enough to not sample any parameter that has
39-
# a delta-function prior.
40-
# The above list does *not* include mass_1, mass_2, theta_jn and luminosity
41-
# distance, which means those are the parameters that will be included in the
42-
# sampler. If we do nothing, then the default priors get used.
43-
priors = bilby.gw.prior.BBHPriorDict()
44-
del priors["mass_1"], priors["mass_2"]
45-
priors["geocent_time"] = bilby.core.prior.Uniform(1126249642, 1126269642)
46-
priors["luminosity_distance"].minimum = 1
47-
priors["luminosity_distance"].maximum = 500
48-
priors["chirp_mass"].minimum = 2.35
49-
priors["chirp_mass"].maximum = 2.45
50-
# priors["luminosity_distance"] = bilby.core.prior.PowerLaw(2.0, 10.0, 500.0)
51-
# priors["sky_x"] = bilby.core.prior.Normal(mu=0, sigma=1)
52-
# priors["sky_y"] = bilby.core.prior.Normal(mu=0, sigma=1)
53-
# priors["sky_z"] = bilby.core.prior.Normal(mu=0, sigma=1)
54-
# priors["delta_phase"] = priors.pop("phase")
55-
# del priors["tilt_1"], priors["tilt_2"], priors["phi_12"], priors["phi_jl"]
56-
# priors["spin_1_x"] = bilby.core.prior.Normal(mu=0, sigma=1)
57-
# priors["spin_1_y"] = bilby.core.prior.Normal(mu=0, sigma=1)
58-
# priors["spin_1_z"] = bilby.core.prior.Normal(mu=0, sigma=1)
59-
# priors["spin_2_x"] = bilby.core.prior.Normal(mu=0, sigma=1)
60-
# priors["spin_2_y"] = bilby.core.prior.Normal(mu=0, sigma=1)
61-
# priors["spin_2_z"] = bilby.core.prior.Normal(mu=0, sigma=1)
62-
# # del priors["a_1"], priors["a_2"]
63-
# # priors["chi_1"] = bilby.core.prior.Uniform(-0.05, 0.05)
64-
# # priors["chi_2"] = bilby.core.prior.Uniform(-0.05, 0.05)
65-
# del priors["theta_jn"], priors["psi"], priors["delta_phase"]
66-
# priors["orientation_w"] = bilby.core.prior.Normal(mu=0, sigma=1)
67-
# priors["orientation_x"] = bilby.core.prior.Normal(mu=0, sigma=1)
68-
# priors["orientation_y"] = bilby.core.prior.Normal(mu=0, sigma=1)
69-
# priors["orientation_z"] = bilby.core.prior.Normal(mu=0, sigma=1)
70-
return priors
71-
72-
73-
def original_to_sampling_priors(priors, truth):
74-
del priors["ra"], priors["dec"]
75-
priors["zenith"] = bilby.core.prior.Cosine()
76-
priors["azimuth"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi)
77-
priors["L1_time"] = bilby.core.prior.Uniform(truth["geocent_time"] - 0.1, truth["geocent_time"] + 0.1)
27+
def bilby_to_ripple_spins(
28+
theta_jn,
29+
phi_jl,
30+
tilt_1,
31+
tilt_2,
32+
phi_12,
33+
a_1,
34+
a_2,
35+
):
36+
"""
37+
A simplified spherical to cartesian spin conversion function.
38+
This is not equivalent to the method used in `bilby.gw.conversion`
39+
which comes from `lalsimulation` and is not `JAX` compatible.
40+
"""
41+
iota = theta_jn
42+
spin_1x = a_1 * jnp.sin(tilt_1) * jnp.cos(phi_jl)
43+
spin_1y = a_1 * jnp.sin(tilt_1) * jnp.sin(phi_jl)
44+
spin_1z = a_1 * jnp.cos(tilt_1)
45+
spin_2x = a_2 * jnp.sin(tilt_2) * jnp.cos(phi_jl + phi_12)
46+
spin_2y = a_2 * jnp.sin(tilt_2) * jnp.sin(phi_jl + phi_12)
47+
spin_2z = a_2 * jnp.cos(tilt_2)
48+
return iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z
49+
50+
51+
def ripple_bbh(
52+
frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase,
53+
a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, **kwargs,
54+
):
55+
"""
56+
Source function wrapper to ripple's IMRPhenomPv2 waveform generator.
57+
This function cannot be jitted directly as the Bilby waveform generator
58+
relies on inspecting the function signature.
59+
60+
Parameters
61+
----------
62+
frequency: jnp.ndarray
63+
Frequencies at which to compute the waveform.
64+
mass_1: float | jnp.ndarray
65+
Mass of the primary component in solar masses.
66+
mass_2: float | jnp.ndarray
67+
Mass of the secondary component in solar masses.
68+
luminosity_distance: float | jnp.ndarray
69+
Luminosity distance to the source in Mpc.
70+
theta_jn: float | jnp.ndarray
71+
Angle between total angular momentum and line of sight in radians.
72+
phase: float | jnp.ndarray
73+
Phase at coalescence in radians.
74+
a_1: float | jnp.ndarray
75+
Dimensionless spin magnitude of the primary component.
76+
a_2: float | jnp.ndarray
77+
Dimensionless spin magnitude of the secondary component.
78+
tilt_1: float | jnp.ndarray
79+
Tilt angle of the primary component spin in radians.
80+
tilt_2: float | jnp.ndarray
81+
Tilt angle of the secondary component spin in radians.
82+
phi_12: float | jnp.ndarray
83+
Azimuthal angle between the two spin vectors in radians.
84+
phi_jl: float | jnp.ndarray
85+
Azimuthal angle of the total angular momentum vector in radians.
86+
**kwargs
87+
Additional keyword arguments. Must include 'minimum_frequency'.
88+
89+
Returns
90+
-------
91+
dict
92+
Dictionary containing the plus and cross polarizations of the waveform.
93+
"""
94+
iota, *cartesian_spins = bilby_to_ripple_spins(
95+
# iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_ripple_spins(
96+
theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2
97+
)
98+
frequencies = jnp.maximum(frequency, kwargs["minimum_frequency"])
99+
theta = jnp.array([
100+
mass_1, mass_2, *cartesian_spins,
101+
luminosity_distance, jnp.array(0.0), phase, iota
102+
])
103+
wf_func = jax.jit(IMRPhenomPv2.gen_IMRPhenomPv2)
104+
hp, hc = wf_func(frequencies, theta, jnp.array(20.0))
105+
return dict(plus=hp, cross=hc)
78106

79107

80-
def main(use_jax, model, idx):
108+
def main():
81109
# Set the duration and sampling frequency of the data segment that we're
82110
# going to inject the signal into
83111
duration = 64.0
84112
sampling_frequency = 2048.0
85113
minimum_frequency = 20.0
86-
if use_jax:
87-
duration = jax.numpy.array(duration)
88-
sampling_frequency = jax.numpy.array(sampling_frequency)
89-
minimum_frequency = jax.numpy.array(minimum_frequency)
114+
duration = jnp.array(duration)
115+
sampling_frequency = jnp.array(sampling_frequency)
116+
minimum_frequency = jnp.array(minimum_frequency)
90117

91118
# Specify the output directory and the name of the simulation.
92-
outdir = "pp-test-2"
93-
label = f"{model}_{'jax' if use_jax else 'numpy'}_{idx}"
119+
outdir = "outdir"
120+
label = f"jax_fast_tutorial"
94121

95122
# Set up a random seed for result reproducibility. This is optional!
96-
bilby.core.utils.random.seed(88170235 + idx * 1000)
123+
bilby.core.utils.random.seed(88170235)
97124

98-
priors = setup_prior()
125+
priors = bilby.gw.prior.BBHPriorDict()
99126
injection_parameters = priors.sample()
100-
if model == "relbin":
101-
injection_parameters["fiducial"] = 1
102-
original_to_sampling_priors(priors, injection_parameters)
127+
injection_parameters["geocent_time"] = 1000000000.0
128+
injection_parameters["luminosity_distance"] = 400.0
129+
del priors["ra"], priors["dec"]
130+
priors["zenith"] = bilby.core.prior.Cosine()
131+
priors["azimuth"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi)
132+
priors["L1_time"] = bilby.core.prior.Uniform(
133+
injection_parameters["geocent_time"] - 0.1,
134+
injection_parameters["geocent_time"] + 0.1,
135+
)
103136

104137
# Fixed arguments passed into the source model
105138
waveform_arguments = dict(
@@ -108,28 +141,13 @@ def main(use_jax, model, idx):
108141
minimum_frequency=minimum_frequency,
109142
)
110143

111-
if use_jax:
112-
match model:
113-
case "relbin":
114-
fdsm = bilby.gw.jaxstuff.ripple_bbh_relbin
115-
case _:
116-
fdsm = bilby.gw.jaxstuff.ripple_bbh
117-
else:
118-
match model:
119-
case "relbin":
120-
fdsm = bilby.gw.source.lal_binary_black_hole_relative_binning
121-
case _:
122-
fdsm = bilby.gw.source.lal_binary_black_hole
123-
# fdsm = bilby.gw.source.sinegaussian
124-
125144
# Create the waveform_generator using a LAL BinaryBlackHole source function
126145
waveform_generator = bilby.gw.WaveformGenerator(
127146
duration=duration,
128147
sampling_frequency=sampling_frequency,
129-
frequency_domain_source_model=fdsm,
130-
# parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
148+
frequency_domain_source_model=ripple_bbh,
131149
waveform_arguments=waveform_arguments,
132-
use_cache=not use_jax,
150+
use_cache=False,
133151
)
134152

135153
# Set up interferometers. In this case we'll use two interferometers
@@ -145,74 +163,49 @@ def main(use_jax, model, idx):
145163
waveform_generator=waveform_generator, parameters=injection_parameters,
146164
raise_error=False,
147165
)
148-
if use_jax:
149-
ifos.set_array_backend(jax.numpy)
150-
151-
if model == "mb":
152-
if use_jax:
153-
pass
154-
else:
155-
waveform_generator.frequency_domain_source_model = (
156-
bilby.gw.source.binary_black_hole_frequency_sequence
157-
)
158-
del waveform_generator.waveform_arguments["minimum_frequency"]
166+
ifos.set_array_backend(jnp)
159167

160168
# Initialise the likelihood by passing in the interferometer data (ifos) and
161169
# the waveform generator
162-
match model:
163-
case "relbin":
164-
likelihood_class = (
165-
bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient
166-
)
167-
case "mb":
168-
likelihood_class = bilby.gw.likelihood.MBGravitationalWaveTransient
169-
case _:
170-
likelihood_class = bilby.gw.likelihood.GravitationalWaveTransient
171-
likelihood = likelihood_class(
170+
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
172171
interferometers=ifos,
173172
waveform_generator=waveform_generator,
174173
priors=priors,
175174
phase_marginalization=True,
176175
distance_marginalization=True,
177176
reference_frame=ifos,
178177
time_reference="L1",
179-
# epsilon=0.1,
180-
# update_fiducial_parameters=True,
181178
)
179+
# Do an initial likelihood evaluation to trigger any internal setup
180+
likelihood.log_likelihood_ratio(priors.sample())
181+
# Wrap the likelihood with the JittedLikelihood to JIT compile the likelihood
182+
# evaluation
183+
likelihood = JittedLikelihood(likelihood)
184+
# Evaluate the likelihood once to trigger the JIT compilation, this will take
185+
# a few seconds as compiling the waveform takes some time
186+
likelihood.log_likelihood_ratio(priors.sample())
182187

183188
# use the log_compiles context so we can make sure there aren't recompilations
184189
# inside the sampling loop
185-
if True:
186-
# with jax.log_compiles():
190+
with jax.log_compiles():
187191
result = bilby.run_sampler(
188192
likelihood=likelihood,
189193
priors=priors,
190-
sampler="jaxted" if use_jax else "dynesty",
191-
nlive=1000,
194+
sampler="dynesty",
195+
nlive=100,
192196
sample="acceptance-walk",
193-
method="nest",
194-
nsteps=100,
195-
naccept=30,
197+
naccept=5,
196198
injection_parameters=injection_parameters,
197199
outdir=outdir,
198200
label=label,
199-
npool=None if use_jax else 16,
200-
# save="hdf5",
201-
save=False,
201+
npool=None,
202+
save="hdf5",
202203
rseed=np.random.randint(0, 100000),
203204
)
204205

205206
# Make a corner plot.
206-
# result.plot_corner()
207-
import IPython; IPython.embed()
208-
return result.sampling_time
207+
result.plot_corner()
209208

210209

211210
if __name__ == "__main__":
212-
times = dict()
213-
# for arg in product([True, False][:], ["relbin", "mb", "regular"][2:3]):
214-
# times[arg] = main(*arg)
215-
with jax.log_compiles():
216-
for idx in np.arange(100):
217-
times[idx] = main(True, "mb", idx)
218-
print(times)
211+
main()

0 commit comments

Comments
 (0)