Skip to content

Commit 171974c

Browse files
authored
gh-395: Theta max (#394)
1 parent eec4980 commit 171974c

4 files changed

Lines changed: 104 additions & 177 deletions

File tree

examples/unmixing.ipynb

Lines changed: 80 additions & 139 deletions
Large diffs are not rendered by default.

heracles/transforms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
# Python < 3.13
88
from dataclasses import replace
99

10+
from .result import get_result_array
11+
1012
gauss_legendre = None
1113
_gauss_legendre_cache = {}
1214

@@ -203,6 +205,9 @@ def cl2corr(cls):
203205
s1, s2 = cl.spin
204206
# Grab metadata
205207
dtype = cl.array.dtype
208+
# Determine lmax from ell field or shape along ell axis
209+
lmax = len(get_result_array(cl, "ell")[0]) - 1
210+
xvals, _ = _cached_gauss_legendre(lmax + 1)
206211
# Initialize wd
207212
wd = np.zeros_like(cl)
208213
if (s1 != 0) and (s2 != 0):
@@ -259,6 +264,7 @@ def cl2corr(cls):
259264
wd = np.array(list(wd), dtype=dtype)
260265
wds[key] = replace(
261266
cls[key],
267+
ell=xvals,
262268
array=wd,
263269
)
264270
return wds
@@ -278,6 +284,9 @@ def corr2cl(wds):
278284
s1, s2 = wd.spin
279285
# Grab metadata
280286
dtype = wd.array.dtype
287+
# Derive lmax from xvals stored in the correlation's ell field
288+
xvals = get_result_array(wd, "ell")[0]
289+
lmax = len(xvals) - 1
281290
# initialize cl
282291
cl = np.zeros_like(wd)
283292
if (s1 != 0) and (s2 != 0):
@@ -335,6 +344,7 @@ def corr2cl(wds):
335344
cl = np.array(list(cl), dtype=dtype)
336345
cls[key] = replace(
337346
wds[key],
347+
ell=np.arange(lmax + 1),
338348
array=cl,
339349
)
340350
return cls

heracles/unmixing.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
# You should have received a copy of the GNU Lesser General Public
1818
# License along with Heracles. If not, see <https://www.gnu.org/licenses/>.
1919
import numpy as np
20-
from collections.abc import Mapping
2120
from .result import binned
2221
from .transforms import cl2corr, corr2cl
2322
from .utils import get_cl
23+
from .transforms import _cached_gauss_legendre
2424

2525
try:
2626
from copy import replace
@@ -29,38 +29,36 @@
2929
from dataclasses import replace
3030

3131

32-
def naturalspice(d, m, fields, rcond=0.01):
32+
def naturalspice(d, m, fields, theta_max=None):
3333
"""
3434
Natural unmixing of the data Cl.
3535
Args:
3636
d: Data Cl
3737
m: mask Cl
3838
fields: list of fields
39-
patch_hole: If True, apply the patch hole correction
39+
theta_max: maximum angle to use for the unmixing, in degrees. If None, use all angles.
4040
Returns:
4141
corr_d: Corrected Cl
4242
"""
4343
first_wd = list(d.values())[0]
4444
first_wm = list(m.values())[0]
4545
lmax = first_wd.shape[first_wd.axis[0]]
4646
lmax_mask = first_wm.shape[first_wm.axis[0]]
47+
xvals, _ = _cached_gauss_legendre(lmax_mask)
48+
theta = np.arccos(xvals) * 180 / np.pi
4749

4850
# pad correlation functions to lmax_mask
4951
d = binned(d, np.arange(0, lmax_mask + 1))
5052

5153
wd = cl2corr(d)
5254
wm = cl2corr(m)
53-
for m_key in list(wm.keys()):
54-
if isinstance(rcond, Mapping):
55-
if m_key not in rcond:
56-
raise KeyError(f"Missing rcond value for wm key: {m_key}")
57-
_rcond = rcond[m_key]
58-
else:
59-
_rcond = rcond
60-
_wm = wm[m_key].array
61-
_wm = _wm * logistic(np.log10(abs(_wm)), x0=np.log10(_rcond * np.max(_wm)))
62-
wm[m_key] = replace(wm[m_key], array=_wm)
63-
55+
if theta_max is not None:
56+
for m_key in list(wm.keys()):
57+
_wm = wm[m_key].array
58+
i_theta_max = np.abs(theta - theta_max).argmin()
59+
wm_at_theta_max = _wm[i_theta_max]
60+
_wm = _wm * logistic(np.log10(abs(_wm)), x0=np.log10(abs(wm_at_theta_max)))
61+
wm[m_key] = replace(wm[m_key], array=_wm)
6462
corr_wds = _naturalspice(wd, wm, fields)
6563

6664
# trnasform back to Cl
@@ -99,5 +97,5 @@ def _naturalspice(wd, wm, fields):
9997
return corr_wds
10098

10199

102-
def logistic(x, x0=-5, k=50):
100+
def logistic(x, x0=-2, k=50):
103101
return 1.0 + np.exp(-k * (x - x0))

tests/test_dices.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -117,23 +117,12 @@ def test_full_mask_correction(cls0, mls0, fields):
117117
assert np.isclose(cl[2:], _cl[2:]).all()
118118

119119
cls_alphas = heracles.corr2cl(alphas)
120-
__cls = heracles.unmixing.naturalspice(cls0, cls_alphas, fields, rcond=1.0)
120+
__cls = heracles.unmixing.naturalspice(cls0, cls_alphas, fields, theta_max=180)
121121
for key in list(cls0.keys()):
122122
cl = cls0[key].array
123123
_cl = __cls[key].array
124124
assert np.isclose(cl[2:], 2 * _cl[2:]).all()
125125

126-
wm_keys = heracles.transforms.cl2corr(cls_alphas).keys()
127-
rcond_by_key = {key: 1.0 for key in wm_keys}
128-
___cls = heracles.unmixing.naturalspice(
129-
cls0,
130-
cls_alphas,
131-
fields,
132-
rcond=rcond_by_key,
133-
)
134-
for key in list(cls0.keys()):
135-
np.testing.assert_allclose(__cls[key].array, ___cls[key].array)
136-
137126
_alphas = dices.get_mask_correlation_ratio(mls0, mls0, unmixed=True)
138127
for key in list(_alphas.keys()):
139128
wmls0 = heracles.transforms._cl2corr(mls0[key]).T[0]
@@ -142,17 +131,6 @@ def test_full_mask_correction(cls0, mls0, fields):
142131
assert np.isclose(alpha, _alpha).all()
143132

144133

145-
def test_full_mask_correction_rcond_missing_key(cls0, mls0, fields):
146-
cls_alphas = heracles.corr2cl(
147-
dices.get_mask_correlation_ratio(mls0, mls0, unmixed=False)
148-
)
149-
wm_keys = list(heracles.transforms.cl2corr(cls_alphas).keys())
150-
rcond_by_key = {key: 1.0 for key in wm_keys[1:]}
151-
152-
with pytest.raises(KeyError, match="Missing rcond value for wm key"):
153-
heracles.unmixing.naturalspice(cls0, cls_alphas, fields, rcond=rcond_by_key)
154-
155-
156134
def test_fast_mask_correction(cls0, fields, jk_maps):
157135
_cls0 = dices.correct_footprint_reduction(cls0, jk_maps, fields, 0, 0)
158136
for key in list(cls0.keys()):

0 commit comments

Comments
 (0)