Skip to content

Commit ceefc07

Browse files
authored
gh-397: Dices lmax (#398)
1 parent 171974c commit ceefc07

4 files changed

Lines changed: 71 additions & 62 deletions

File tree

heracles/dices/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
"jackknife_fsky",
2727
"jackknife_bias",
2828
"correct_bias",
29-
"get_mask_correlation_ratio",
30-
"correct_footprint_reduction",
29+
"correct_footprint_naturalspice",
30+
"correct_footprint_fsky",
3131
"jackknife_covariance",
3232
"debias_covariance",
3333
"delete2_correction",
@@ -46,8 +46,8 @@
4646
jackknife_fsky,
4747
jackknife_bias,
4848
correct_bias,
49-
get_mask_correlation_ratio,
50-
correct_footprint_reduction,
49+
correct_footprint_naturalspice,
50+
correct_footprint_fsky,
5151
jackknife_covariance,
5252
debias_covariance,
5353
delete2_correction,

heracles/dices/jackknife.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
from itertools import combinations
2323
from ..utils import add_to_Cls, sub_to_Cls
2424
from ..core import update_metadata
25-
from ..result import Result, get_result_array
25+
from ..result import Result, get_result_array, binned
2626
from ..mapping import transform
2727
from ..twopoint import angular_power_spectra
28-
from ..unmixing import _naturalspice, logistic
29-
from ..transforms import _cl2corr, cl2corr, corr2cl
28+
from ..unmixing import _naturalspice
29+
from ..transforms import cl2corr, corr2cl
3030

3131
try:
3232
from copy import replace
@@ -76,12 +76,11 @@ def jackknife_cls(
7676
if mask_correction == "Full":
7777
vis_alms_jk = _sum_alms_except(vis_alms_regions, regions)
7878
_cls_mm = angular_power_spectra(vis_alms_jk)
79-
alphas = get_mask_correlation_ratio(_cls_mm, mls0, unmixed=unmixed)
80-
_wcls = cl2corr(_cls)
81-
_wcls = _naturalspice(_wcls, alphas, fields)
82-
_cls = corr2cl(_wcls)
79+
_cls = correct_footprint_naturalspice(
80+
_cls, _cls_mm, mls0, fields, unmixed=unmixed
81+
)
8382
elif mask_correction == "Fast":
84-
_cls = correct_footprint_reduction(
83+
_cls = correct_footprint_fsky(
8584
_cls, jk_maps, fields, *regions, unmixed=unmixed
8685
)
8786
else:
@@ -220,7 +219,7 @@ def correct_bias(cls, jkmaps, fields, jk=0, jk2=0):
220219
return cls
221220

222221

223-
def correct_footprint_reduction(cls, jkmaps, fields, jk=0, jk2=0, unmixed=False):
222+
def correct_footprint_fsky(cls, jkmaps, fields, jk=0, jk2=0, unmixed=False):
224223
"""
225224
Corrects the Cls for the footprint reduction due to taking out a region.
226225
inputs:
@@ -249,33 +248,42 @@ def correct_footprint_reduction(cls, jkmaps, fields, jk=0, jk2=0, unmixed=False)
249248
return _cls
250249

251250

252-
def get_mask_correlation_ratio(Mljk, Mls0, unmixed=False):
251+
def _mask_correlation_ratio(mljk, mls0, unmixed=False):
252+
alphas = {}
253+
wmls0 = cl2corr(mls0)
254+
wmljk = cl2corr(mljk)
255+
for key in list(wmljk.keys()):
256+
_wmljk = wmljk[key].array
257+
_wmls0 = wmls0[key].array
258+
alpha = _wmljk
259+
if not unmixed:
260+
alpha = alpha / _wmls0
261+
alphas[key] = replace(mls0[key], array=alpha)
262+
return alphas
263+
264+
265+
def correct_footprint_naturalspice(cls, cls_mm, mls0, fields, unmixed=False):
253266
"""
254-
Computes the ratio of the correlation
255-
functions of the masks Cls.
256-
input:
257-
Mljk (np.array): mask of delete1 Cls
258-
Mls0 (np.array): mask Cls
267+
Corrects the Cls for footprint reduction using the full NaMaster/naturalspice approach.
268+
inputs:
269+
cls (dict): Dictionary of data Cls
270+
cls_mm (dict): Dictionary of jackknife mask Cls
271+
mls0 (dict): Dictionary of full mask Cls
272+
fields (dict): Dictionary of fields
259273
unmixed (bool): unmix the Cls
260274
returns:
261-
alpha (Float64): Mask correction factor
275+
cls (dict): Corrected Cls
262276
"""
263-
alphas = {}
264-
for key in list(Mljk.keys()):
265-
mljk = Mljk[key]
266-
mls0 = Mls0[key]
267-
# Transform to real space
268-
wmljk = _cl2corr(mljk)
269-
wmljk = wmljk.T[0]
270-
wmljk *= logistic(np.log10(abs(wmljk)))
271-
# Compute alpha
272-
alpha = wmljk
273-
if not unmixed:
274-
wmls0 = _cl2corr(mls0)
275-
wmls0 = wmls0.T[0]
276-
alpha /= wmls0
277-
alphas[key] = replace(Mls0[key], array=alpha)
278-
return alphas
277+
alphas = _mask_correlation_ratio(cls_mm, mls0, unmixed=unmixed)
278+
first_cls = list(cls.values())[0]
279+
first_mls = list(mls0.values())[0]
280+
lmax = first_cls.shape[first_cls.axis[0]]
281+
lmax_mask = first_mls.shape[first_mls.axis[0]]
282+
cls = binned(cls, np.arange(0, lmax_mask + 1))
283+
wcls = cl2corr(cls)
284+
wcls = _naturalspice(wcls, alphas, fields)
285+
cls = corr2cl(wcls)
286+
return binned(cls, np.arange(0, lmax + 1))
279287

280288

281289
def jackknife_covariance(dict, nd=1):

heracles/unmixing.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
from dataclasses import replace
3030

3131

32+
def logistic(x, x0=-2, k=50):
33+
return 1.0 + np.exp(-k * (x - x0))
34+
35+
3236
def naturalspice(d, m, fields, theta_max=None):
3337
"""
3438
Natural unmixing of the data Cl.
@@ -44,22 +48,13 @@ def naturalspice(d, m, fields, theta_max=None):
4448
first_wm = list(m.values())[0]
4549
lmax = first_wd.shape[first_wd.axis[0]]
4650
lmax_mask = first_wm.shape[first_wm.axis[0]]
47-
xvals, _ = _cached_gauss_legendre(lmax_mask)
48-
theta = np.arccos(xvals) * 180 / np.pi
4951

5052
# pad correlation functions to lmax_mask
5153
d = binned(d, np.arange(0, lmax_mask + 1))
5254

5355
wd = cl2corr(d)
5456
wm = cl2corr(m)
55-
if theta_max is not None:
56-
for m_key in list(wm.keys()):
57-
_wm = wm[m_key].array
58-
i_theta_max = np.abs(theta - theta_max).argmin()
59-
wm_at_theta_max = _wm[i_theta_max]
60-
_wm = _wm * logistic(np.log10(abs(_wm)), x0=np.log10(abs(wm_at_theta_max)))
61-
wm[m_key] = replace(wm[m_key], array=_wm)
62-
corr_wds = _naturalspice(wd, wm, fields)
57+
corr_wds = _naturalspice(wd, wm, fields, theta_max=theta_max)
6358

6459
# trnasform back to Cl
6560
corr_d = corr2cl(corr_wds)
@@ -69,14 +64,14 @@ def naturalspice(d, m, fields, theta_max=None):
6964
return corr_d
7065

7166

72-
def _naturalspice(wd, wm, fields):
67+
def _naturalspice(wd, wm, fields, theta_max=None):
7368
"""
7469
Natural unmixing of the data correlation function.
7570
Args:
7671
wd: data correlation function
7772
wm: mask correlation function
7873
fields: list of fields
79-
patch_hole: If True, apply the patch hole correction
74+
theta_max: maximum angle in degrees for the logistic cutoff. If None, uses default x0=-2.
8075
Returns:
8176
corr_d: Corrected Cl
8277
"""
@@ -85,17 +80,22 @@ def _naturalspice(wd, wm, fields):
8580
if field.mask is not None:
8681
masks[key] = field.mask
8782

83+
if theta_max is not None:
84+
first_wm = list(wm.values())[0]
85+
lmax_mask = first_wm.shape[first_wm.axis[0]]
86+
xvals, _ = _cached_gauss_legendre(lmax_mask)
87+
theta = np.arccos(xvals) * 180 / np.pi
88+
i_theta_max = np.abs(theta - theta_max).argmin()
89+
8890
corr_wds = {}
8991
for key in wd.keys():
9092
a, b, i, j = key
9193
m_key = (masks[a], masks[b], i, j)
92-
_wm = get_cl(m_key, wm)
93-
_wd = wd[key]
94-
# divide by the mask correlation function
95-
corr_wds[key] = replace(wd[key], array=(_wd.array / _wm.array))
94+
_wm = get_cl(m_key, wm).array
95+
_wd = wd[key].array
96+
if theta_max is not None:
97+
x0 = np.log10(abs(_wm[i_theta_max]))
98+
_wm *= logistic(np.log10(abs(_wm)), x0=x0)
99+
corr_wds[key] = replace(wd[key], array=(_wd / _wm))
96100

97101
return corr_wds
98-
99-
100-
def logistic(x, x0=-2, k=50):
101-
return 1.0 + np.exp(-k * (x - x0))

tests/test_dices.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,23 +107,24 @@ def test_get_delete2_fsky(jk_maps, njk):
107107

108108

109109
def test_full_mask_correction(cls0, mls0, fields):
110-
alphas = dices.get_mask_correlation_ratio(mls0, mls0, unmixed=False)
111-
wcls0 = heracles.transforms.cl2corr(cls0)
112-
_wcls = heracles.unmixing._naturalspice(wcls0, alphas, fields)
113-
_cls = heracles.transforms.corr2cl(_wcls)
110+
from heracles.dices.jackknife import _mask_correlation_ratio
111+
112+
# When mljk == mls0, correct_footprint_naturalspice should recover the original cls
113+
_cls = dices.correct_footprint_naturalspice(cls0, mls0, mls0, fields, unmixed=False)
114114
for key in list(cls0.keys()):
115115
cl = cls0[key].array
116116
_cl = _cls[key].array
117117
assert np.isclose(cl[2:], _cl[2:]).all()
118118

119+
alphas = _mask_correlation_ratio(mls0, mls0, unmixed=False)
119120
cls_alphas = heracles.corr2cl(alphas)
120121
__cls = heracles.unmixing.naturalspice(cls0, cls_alphas, fields, theta_max=180)
121122
for key in list(cls0.keys()):
122123
cl = cls0[key].array
123124
_cl = __cls[key].array
124125
assert np.isclose(cl[2:], 2 * _cl[2:]).all()
125126

126-
_alphas = dices.get_mask_correlation_ratio(mls0, mls0, unmixed=True)
127+
_alphas = _mask_correlation_ratio(mls0, mls0, unmixed=True)
127128
for key in list(_alphas.keys()):
128129
wmls0 = heracles.transforms._cl2corr(mls0[key]).T[0]
129130
alpha = alphas[key].array
@@ -132,7 +133,7 @@ def test_full_mask_correction(cls0, mls0, fields):
132133

133134

134135
def test_fast_mask_correction(cls0, fields, jk_maps):
135-
_cls0 = dices.correct_footprint_reduction(cls0, jk_maps, fields, 0, 0)
136+
_cls0 = dices.correct_footprint_fsky(cls0, jk_maps, fields, 0, 0)
136137
for key in list(cls0.keys()):
137138
cl = cls0[key].array
138139
_cl = _cls0[key].array

0 commit comments

Comments
 (0)