1010We optionally use ripple waveforms and a JIT-compiled likelihood.
1111"""
1212import os
13- from itertools import product
1413
1514# Set OMP_NUM_THREADS to stop lalsimulation taking over my computer
1615os .environ ["OMP_NUM_THREADS" ] = "1"
1716
1817import bilby
19- import bilby .gw .jaxstuff
2018import numpy as np
2119import jax
2220import jax .numpy as jnp
23- from jax import random
24- from numpyro .infer import AIES , ESS # noqa
25- from numpyro .infer .ensemble_util import get_nondiagonal_indices
21+ from bilby .compat .jax import JittedLikelihood
22+ from ripple .waveforms import IMRPhenomPv2
2623
2724jax .config .update ("jax_enable_x64" , True )
2825
29- bilby .core .utils .setup_logger () # log_level="WARNING")
3026
31-
32- def setup_prior ():
33- # Set up a PriorDict, which inherits from dict.
34- # By default we will sample all terms in the signal models. However, this will
35- # take a long time for the calculation, so for this example we will set almost
36- # all of the priors to be equall to their injected values. This implies the
37- # prior is a delta function at the true, injected value. In reality, the
38- # sampler implementation is smart enough to not sample any parameter that has
39- # a delta-function prior.
40- # The above list does *not* include mass_1, mass_2, theta_jn and luminosity
41- # distance, which means those are the parameters that will be included in the
42- # sampler. If we do nothing, then the default priors get used.
43- priors = bilby .gw .prior .BBHPriorDict ()
44- del priors ["mass_1" ], priors ["mass_2" ]
45- priors ["geocent_time" ] = bilby .core .prior .Uniform (1126249642 , 1126269642 )
46- priors ["luminosity_distance" ].minimum = 1
47- priors ["luminosity_distance" ].maximum = 500
48- priors ["chirp_mass" ].minimum = 2.35
49- priors ["chirp_mass" ].maximum = 2.45
50- # priors["luminosity_distance"] = bilby.core.prior.PowerLaw(2.0, 10.0, 500.0)
51- # priors["sky_x"] = bilby.core.prior.Normal(mu=0, sigma=1)
52- # priors["sky_y"] = bilby.core.prior.Normal(mu=0, sigma=1)
53- # priors["sky_z"] = bilby.core.prior.Normal(mu=0, sigma=1)
54- # priors["delta_phase"] = priors.pop("phase")
55- # del priors["tilt_1"], priors["tilt_2"], priors["phi_12"], priors["phi_jl"]
56- # priors["spin_1_x"] = bilby.core.prior.Normal(mu=0, sigma=1)
57- # priors["spin_1_y"] = bilby.core.prior.Normal(mu=0, sigma=1)
58- # priors["spin_1_z"] = bilby.core.prior.Normal(mu=0, sigma=1)
59- # priors["spin_2_x"] = bilby.core.prior.Normal(mu=0, sigma=1)
60- # priors["spin_2_y"] = bilby.core.prior.Normal(mu=0, sigma=1)
61- # priors["spin_2_z"] = bilby.core.prior.Normal(mu=0, sigma=1)
62- # # del priors["a_1"], priors["a_2"]
63- # # priors["chi_1"] = bilby.core.prior.Uniform(-0.05, 0.05)
64- # # priors["chi_2"] = bilby.core.prior.Uniform(-0.05, 0.05)
65- # del priors["theta_jn"], priors["psi"], priors["delta_phase"]
66- # priors["orientation_w"] = bilby.core.prior.Normal(mu=0, sigma=1)
67- # priors["orientation_x"] = bilby.core.prior.Normal(mu=0, sigma=1)
68- # priors["orientation_y"] = bilby.core.prior.Normal(mu=0, sigma=1)
69- # priors["orientation_z"] = bilby.core.prior.Normal(mu=0, sigma=1)
70- return priors
71-
72-
73- def original_to_sampling_priors (priors , truth ):
74- del priors ["ra" ], priors ["dec" ]
75- priors ["zenith" ] = bilby .core .prior .Cosine ()
76- priors ["azimuth" ] = bilby .core .prior .Uniform (minimum = 0 , maximum = 2 * np .pi )
77- priors ["L1_time" ] = bilby .core .prior .Uniform (truth ["geocent_time" ] - 0.1 , truth ["geocent_time" ] + 0.1 )
27+ def bilby_to_ripple_spins (
28+ theta_jn ,
29+ phi_jl ,
30+ tilt_1 ,
31+ tilt_2 ,
32+ phi_12 ,
33+ a_1 ,
34+ a_2 ,
35+ ):
36+ """
37+ A simplified spherical to cartesian spin conversion function.
38+ This is not equivalent to the method used in `bilby.gw.conversion`
39+ which comes from `lalsimulation` and is not `JAX` compatible.
40+ """
41+ iota = theta_jn
42+ spin_1x = a_1 * jnp .sin (tilt_1 ) * jnp .cos (phi_jl )
43+ spin_1y = a_1 * jnp .sin (tilt_1 ) * jnp .sin (phi_jl )
44+ spin_1z = a_1 * jnp .cos (tilt_1 )
45+ spin_2x = a_2 * jnp .sin (tilt_2 ) * jnp .cos (phi_jl + phi_12 )
46+ spin_2y = a_2 * jnp .sin (tilt_2 ) * jnp .sin (phi_jl + phi_12 )
47+ spin_2z = a_2 * jnp .cos (tilt_2 )
48+ return iota , spin_1x , spin_1y , spin_1z , spin_2x , spin_2y , spin_2z
49+
50+
51+ def ripple_bbh (
52+ frequency , mass_1 , mass_2 , luminosity_distance , theta_jn , phase ,
53+ a_1 , a_2 , tilt_1 , tilt_2 , phi_12 , phi_jl , ** kwargs ,
54+ ):
55+ """
56+ Source function wrapper to ripple's IMRPhenomPv2 waveform generator.
57+ This function cannot be jitted directly as the Bilby waveform generator
58+ relies on inspecting the function signature.
59+
60+ Parameters
61+ ----------
62+ frequency: jnp.ndarray
63+ Frequencies at which to compute the waveform.
64+ mass_1: float | jnp.ndarray
65+ Mass of the primary component in solar masses.
66+ mass_2: float | jnp.ndarray
67+ Mass of the secondary component in solar masses.
68+ luminosity_distance: float | jnp.ndarray
69+ Luminosity distance to the source in Mpc.
70+ theta_jn: float | jnp.ndarray
71+ Angle between total angular momentum and line of sight in radians.
72+ phase: float | jnp.ndarray
73+ Phase at coalescence in radians.
74+ a_1: float | jnp.ndarray
75+ Dimensionless spin magnitude of the primary component.
76+ a_2: float | jnp.ndarray
77+ Dimensionless spin magnitude of the secondary component.
78+ tilt_1: float | jnp.ndarray
79+ Tilt angle of the primary component spin in radians.
80+ tilt_2: float | jnp.ndarray
81+ Tilt angle of the secondary component spin in radians.
82+ phi_12: float | jnp.ndarray
83+ Azimuthal angle between the two spin vectors in radians.
84+ phi_jl: float | jnp.ndarray
85+ Azimuthal angle of the total angular momentum vector in radians.
86+ **kwargs
87+ Additional keyword arguments. Must include 'minimum_frequency'.
88+
89+ Returns
90+ -------
91+ dict
92+ Dictionary containing the plus and cross polarizations of the waveform.
93+ """
94+ iota , * cartesian_spins = bilby_to_ripple_spins (
95+ # iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_ripple_spins(
96+ theta_jn , phi_jl , tilt_1 , tilt_2 , phi_12 , a_1 , a_2
97+ )
98+ frequencies = jnp .maximum (frequency , kwargs ["minimum_frequency" ])
99+ theta = jnp .array ([
100+ mass_1 , mass_2 , * cartesian_spins ,
101+ luminosity_distance , jnp .array (0.0 ), phase , iota
102+ ])
103+ wf_func = jax .jit (IMRPhenomPv2 .gen_IMRPhenomPv2 )
104+ hp , hc = wf_func (frequencies , theta , jnp .array (20.0 ))
105+ return dict (plus = hp , cross = hc )
78106
79107
80- def main (use_jax , model , idx ):
108+ def main ():
81109 # Set the duration and sampling frequency of the data segment that we're
82110 # going to inject the signal into
83111 duration = 64.0
84112 sampling_frequency = 2048.0
85113 minimum_frequency = 20.0
86- if use_jax :
87- duration = jax .numpy .array (duration )
88- sampling_frequency = jax .numpy .array (sampling_frequency )
89- minimum_frequency = jax .numpy .array (minimum_frequency )
114+ duration = jnp .array (duration )
115+ sampling_frequency = jnp .array (sampling_frequency )
116+ minimum_frequency = jnp .array (minimum_frequency )
90117
91118 # Specify the output directory and the name of the simulation.
92- outdir = "pp-test-2 "
93- label = f"{ model } _ { 'jax' if use_jax else 'numpy' } _ { idx } "
119+ outdir = "outdir "
120+ label = f"jax_fast_tutorial "
94121
95122 # Set up a random seed for result reproducibility. This is optional!
96- bilby .core .utils .random .seed (88170235 + idx * 1000 )
123+ bilby .core .utils .random .seed (88170235 )
97124
98- priors = setup_prior ()
125+ priors = bilby . gw . prior . BBHPriorDict ()
99126 injection_parameters = priors .sample ()
100- if model == "relbin" :
101- injection_parameters ["fiducial" ] = 1
102- original_to_sampling_priors (priors , injection_parameters )
127+ injection_parameters ["geocent_time" ] = 1000000000.0
128+ injection_parameters ["luminosity_distance" ] = 400.0
129+ del priors ["ra" ], priors ["dec" ]
130+ priors ["zenith" ] = bilby .core .prior .Cosine ()
131+ priors ["azimuth" ] = bilby .core .prior .Uniform (minimum = 0 , maximum = 2 * np .pi )
132+ priors ["L1_time" ] = bilby .core .prior .Uniform (
133+ injection_parameters ["geocent_time" ] - 0.1 ,
134+ injection_parameters ["geocent_time" ] + 0.1 ,
135+ )
103136
104137 # Fixed arguments passed into the source model
105138 waveform_arguments = dict (
@@ -108,28 +141,13 @@ def main(use_jax, model, idx):
108141 minimum_frequency = minimum_frequency ,
109142 )
110143
111- if use_jax :
112- match model :
113- case "relbin" :
114- fdsm = bilby .gw .jaxstuff .ripple_bbh_relbin
115- case _:
116- fdsm = bilby .gw .jaxstuff .ripple_bbh
117- else :
118- match model :
119- case "relbin" :
120- fdsm = bilby .gw .source .lal_binary_black_hole_relative_binning
121- case _:
122- fdsm = bilby .gw .source .lal_binary_black_hole
123- # fdsm = bilby.gw.source.sinegaussian
124-
125144 # Create the waveform_generator using a LAL BinaryBlackHole source function
126145 waveform_generator = bilby .gw .WaveformGenerator (
127146 duration = duration ,
128147 sampling_frequency = sampling_frequency ,
129- frequency_domain_source_model = fdsm ,
130- # parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
148+ frequency_domain_source_model = ripple_bbh ,
131149 waveform_arguments = waveform_arguments ,
132- use_cache = not use_jax ,
150+ use_cache = False ,
133151 )
134152
135153 # Set up interferometers. In this case we'll use two interferometers
@@ -145,74 +163,49 @@ def main(use_jax, model, idx):
145163 waveform_generator = waveform_generator , parameters = injection_parameters ,
146164 raise_error = False ,
147165 )
148- if use_jax :
149- ifos .set_array_backend (jax .numpy )
150-
151- if model == "mb" :
152- if use_jax :
153- pass
154- else :
155- waveform_generator .frequency_domain_source_model = (
156- bilby .gw .source .binary_black_hole_frequency_sequence
157- )
158- del waveform_generator .waveform_arguments ["minimum_frequency" ]
166+ ifos .set_array_backend (jnp )
159167
160168 # Initialise the likelihood by passing in the interferometer data (ifos) and
161169 # the waveform generator
162- match model :
163- case "relbin" :
164- likelihood_class = (
165- bilby .gw .likelihood .RelativeBinningGravitationalWaveTransient
166- )
167- case "mb" :
168- likelihood_class = bilby .gw .likelihood .MBGravitationalWaveTransient
169- case _:
170- likelihood_class = bilby .gw .likelihood .GravitationalWaveTransient
171- likelihood = likelihood_class (
170+ likelihood = bilby .gw .likelihood .GravitationalWaveTransient (
172171 interferometers = ifos ,
173172 waveform_generator = waveform_generator ,
174173 priors = priors ,
175174 phase_marginalization = True ,
176175 distance_marginalization = True ,
177176 reference_frame = ifos ,
178177 time_reference = "L1" ,
179- # epsilon=0.1,
180- # update_fiducial_parameters=True,
181178 )
179+ # Do an initial likelihood evaluation to trigger any internal setup
180+ likelihood .log_likelihood_ratio (priors .sample ())
181+ # Wrap the likelihood with the JittedLikelihood to JIT compile the likelihood
182+ # evaluation
183+ likelihood = JittedLikelihood (likelihood )
184+ # Evaluate the likelihood once to trigger the JIT compilation, this will take
185+ # a few seconds as compiling the waveform takes some time
186+ likelihood .log_likelihood_ratio (priors .sample ())
182187
183188 # use the log_compiles context so we can make sure there aren't recompilations
184189 # inside the sampling loop
185- if True :
186- # with jax.log_compiles():
190+ with jax .log_compiles ():
187191 result = bilby .run_sampler (
188192 likelihood = likelihood ,
189193 priors = priors ,
190- sampler = "jaxted" if use_jax else " dynesty" ,
191- nlive = 1000 ,
194+ sampler = "dynesty" ,
195+ nlive = 100 ,
192196 sample = "acceptance-walk" ,
193- method = "nest" ,
194- nsteps = 100 ,
195- naccept = 30 ,
197+ naccept = 5 ,
196198 injection_parameters = injection_parameters ,
197199 outdir = outdir ,
198200 label = label ,
199- npool = None if use_jax else 16 ,
200- # save="hdf5",
201- save = False ,
201+ npool = None ,
202+ save = "hdf5" ,
202203 rseed = np .random .randint (0 , 100000 ),
203204 )
204205
205206 # Make a corner plot.
206- # result.plot_corner()
207- import IPython ; IPython .embed ()
208- return result .sampling_time
207+ result .plot_corner ()
209208
210209
211210if __name__ == "__main__" :
212- times = dict ()
213- # for arg in product([True, False][:], ["relbin", "mb", "regular"][2:3]):
214- # times[arg] = main(*arg)
215- with jax .log_compiles ():
216- for idx in np .arange (100 ):
217- times [idx ] = main (True , "mb" , idx )
218- print (times )
211+ main ()
0 commit comments