Skip to content

Commit 56724c4

Browse files
committed
handle spw, freq in the same way as time, bl
1 parent 2202e33 commit 56724c4

2 files changed

Lines changed: 81 additions & 130 deletions

File tree

src/pyuvdata/uvdata/uvdata.py

Lines changed: 78 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,19 @@
4040
)
4141

4242

43+
def flt_ind_str_arr(*, fltarr, intarr, flt_tols, flt_first=True):
44+
"""Create a string array built from float and integer arrays for matching."""
45+
prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int)
46+
prec_int = 8
47+
flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr]
48+
int_str_list = [str(intv).zfill(prec_int) for intv in intarr]
49+
if flt_first:
50+
zipped = zip(flt_str_list, int_str_list, strict=True)
51+
else:
52+
zipped = zip(int_str_list, flt_str_list, strict=True)
53+
return np.array(["_".join(zpval) for zpval in zipped])
54+
55+
4356
def _axis_add_helper(this, other, axis_name: str, other_inds, final_order=None):
4457
update_params = this._get_param_axis(axis_name, single_named_axis=True)
4558
other_form_dict = {axis_name: other_inds}
@@ -5378,15 +5391,20 @@ def fix_phase(self, *, use_ant_pos=True):
53785391

53795392
def blt_str_arr(self):
53805393
"""Create a string array with baseline and time info for matching purposes."""
5381-
prec_t = -2 * np.floor(np.log10(self._time_array.tols[-1])).astype(int)
5382-
prec_b = 8
5383-
return np.array(
5384-
[
5385-
"_".join(
5386-
["{1:.{0}f}".format(prec_t, blt[0]), str(blt[1]).zfill(prec_b)]
5387-
)
5388-
for blt in zip(self.time_array, self.baseline_array, strict=True)
5389-
]
5394+
return flt_ind_str_arr(
5395+
fltarr=self.time_array,
5396+
intarr=self.baseline_array,
5397+
flt_tols=self._time_array.tols,
5398+
flt_first=True,
5399+
)
5400+
5401+
def spw_freq_str_arr(self):
5402+
"""Create a string array with spw and freq info for matching purposes."""
5403+
return flt_ind_str_arr(
5404+
fltarr=self.freq_array,
5405+
intarr=self.flex_spw_id_array,
5406+
flt_tols=self._freq_array.tols,
5407+
flt_first=False,
53905408
)
53915409

53925410
def flexpol_dict(self):
@@ -5485,9 +5503,10 @@ def __add__(
54855503
axis_params_check = {}
54865504
axis_overlap_params = {
54875505
"Nblts": ["time_array", "baseline_array"],
5488-
"Nfreqs": ["freq_array"],
5506+
"Nfreqs": ["freq_array", "flex_spw_id_array"],
54895507
"Npols": ["polarization_array"],
54905508
}
5509+
axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
54915510
axis_dict = {}
54925511
for axis in axes:
54935512
axis_dict[axis] = this._get_param_axis(axis)
@@ -5503,56 +5522,28 @@ def __add__(
55035522
# the blt axis because we need a combo of time and baseline.
55045523
axis_vals = {}
55055524
for axis, overlap_params in axis_overlap_params.items():
5506-
if axis == "Nblts":
5507-
# Create combined arrays for convenience
5508-
this_blts = this.blt_str_arr()
5509-
other_blts = other.blt_str_arr()
5510-
axis_vals[axis] = {"this": this_blts, "other": other_blts}
5525+
if len(overlap_params) > 1:
5526+
axis_vals[axis] = {
5527+
"this": getattr(this, axis_combined_func[axis])(),
5528+
"other": getattr(other, axis_combined_func[axis])(),
5529+
}
55115530
else:
55125531
axis_vals[axis] = {
55135532
"this": getattr(this, overlap_params[0]),
55145533
"other": getattr(other, overlap_params[0]),
55155534
}
5535+
55165536
# Check we don't have overlapping data
55175537
axis_inds = {}
55185538
for axis, val_arr in axis_vals.items():
5519-
if axis == "Nfreqs":
5520-
# This is more complicated because we are allowed to have channels
5521-
# with the same frequency *if* they belong to different spectral
5522-
# windows (one real-life example: you might want to preserve
5523-
# guard bands in the correlator, which can have overlaping RF
5524-
# frequency channels)
5525-
this_freq_ind = np.array([], dtype=np.int64)
5526-
other_freq_ind = np.array([], dtype=np.int64)
5527-
both_freq_ind = np.array([], dtype=float)
5528-
both_spw = np.intersect1d(this.spw_array, other.spw_array)
5529-
for idx in both_spw:
5530-
this_mask = np.where(this.flex_spw_id_array == idx)[0]
5531-
other_mask = np.where(other.flex_spw_id_array == idx)[0]
5532-
both_spw_freq, this_spw_ind, other_spw_ind = np.intersect1d(
5533-
this.freq_array[this_mask],
5534-
other.freq_array[other_mask],
5535-
return_indices=True,
5536-
)
5537-
this_freq_ind = np.append(this_freq_ind, this_mask[this_spw_ind])
5538-
other_freq_ind = np.append(
5539-
other_freq_ind, other_mask[other_spw_ind]
5540-
)
5541-
both_freq_ind = np.append(both_freq_ind, both_spw_freq)
5542-
axis_inds["Nfreqs"] = {
5543-
"this": this_freq_ind,
5544-
"other": other_freq_ind,
5545-
"both": both_freq_ind,
5546-
}
5547-
else:
5548-
both_inds, this_inds, other_inds = np.intersect1d(
5549-
val_arr["this"], val_arr["other"], return_indices=True
5550-
)
5551-
axis_inds[axis] = {
5552-
"this": this_inds,
5553-
"other": other_inds,
5554-
"both": both_inds,
5555-
}
5539+
both_inds, this_inds, other_inds = np.intersect1d(
5540+
val_arr["this"], val_arr["other"], return_indices=True
5541+
)
5542+
axis_inds[axis] = {
5543+
"this": this_inds,
5544+
"other": other_inds,
5545+
"both": both_inds,
5546+
}
55565547

55575548
history_update_string = ""
55585549
if not self.metadata_only and np.all(
@@ -5603,12 +5594,11 @@ def __add__(
56035594
}
56045595
# find the indices in "other" but not in "this"
56055596
for axis in axes:
5606-
if axis != "Nfreqs":
5607-
temp = np.nonzero(
5608-
~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"])
5609-
)[0]
5610-
else:
5611-
# more complicated because of spws
5597+
if axis == "Nfreqs" and (
5598+
this.flex_spw_polarization_array is not None
5599+
or other.flex_spw_polarization_array is not None
5600+
):
5601+
# special checking for flex_spw
56125602
if (this.flex_spw_polarization_array is None) != (
56135603
other.flex_spw_polarization_array is None
56145604
):
@@ -5633,27 +5623,18 @@ def __add__(
56335623
except KeyError:
56345624
this_flexpol_dict[key] = other_flexpol_dict[key]
56355625

5636-
other_mask = np.ones_like(other.flex_spw_id_array, dtype=bool)
5637-
for idx in np.intersect1d(this.spw_array, other.spw_array):
5638-
other_mask[other.flex_spw_id_array == idx] = np.isin(
5639-
other.freq_array[other.flex_spw_id_array == idx],
5640-
this.freq_array[this.flex_spw_id_array == idx],
5641-
invert=True,
5642-
)
5643-
temp = np.where(other_mask)[0]
5626+
temp = np.nonzero(
5627+
~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"])
5628+
)[0]
56445629
if len(temp) > 0:
56455630
new_inds[axis] = temp
56465631
# add params associated with the other axes to compatibility_params
56475632
for axis2 in axes:
56485633
if axis2 != axis:
56495634
compatibility_params.extend(axis_params_check[axis2])
5650-
if axis == "Nblts":
5651-
new_blts = other_blts[temp]
56525635
additions.append(axis_descriptions[axis])
56535636
else:
56545637
new_inds[axis] = []
5655-
if axis == "Nblts":
5656-
new_blts = ([], [])
56575638

56585639
# Actually check compatibility parameters
56595640
for cp in compatibility_params:
@@ -5701,6 +5682,7 @@ def __add__(
57015682
order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
57025683
for axis, ind_dict in axis_inds.items():
57035684
if len(ind_dict["this"]) != 0:
5685+
# there is some overlap, so sorting matters
57045686
this_argsort = np.argsort(ind_dict["this"])
57055687
other_argsort = np.argsort(ind_dict["other"])
57065688

@@ -5714,16 +5696,18 @@ def __add__(
57145696
getattr(this, reorder_method[axis]["method"])(**kwargs)
57155697

57165698
# Pad out self to accommodate new data
5699+
new_axis_inds = {}
57175700
for axis_ind, axis in enumerate(axes):
57185701
if len(new_inds[axis]) > 0:
5719-
if axis == "Nblts":
5720-
this_blts = np.concatenate((this_blts, new_blts))
5721-
order_dict["Nblts"] = np.argsort(this_blts)
5722-
order_use = order_dict["Nblts"]
5702+
new_axis_inds[axis] = np.concatenate(
5703+
(axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
5704+
)
5705+
if axis == "Npols":
5706+
order_dict[axis] = np.argsort(np.abs(new_axis_inds[axis]))
57235707
else:
5724-
order_use = None
5708+
order_dict[axis] = np.argsort(new_axis_inds[axis])
57255709

5726-
_axis_add_helper(this, other, axis, new_inds[axis], order_use)
5710+
_axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
57275711

57285712
if not self.metadata_only:
57295713
pad_shape = list(this.data_array.shape)
@@ -5738,6 +5722,8 @@ def __add__(
57385722
this.flag_array = np.concatenate(
57395723
[this.flag_array, 1 - zero_pad], axis=axis_ind
57405724
).astype(np.bool_)
5725+
else:
5726+
new_axis_inds[axis] = axis_vals[axis]["this"]
57415727

57425728
if len(new_inds["Nfreqs"]) > 0:
57435729
# We want to preserve per-spw information based on first appearance
@@ -5752,51 +5738,24 @@ def __add__(
57525738
this.flex_spw_polarization_array = np.array(
57535739
[this_flexpol_dict[key] for key in this.spw_array]
57545740
)
5755-
# Need to sort out the order of the individual windows first.
5756-
order_dict["Nfreqs"] = np.concatenate(
5757-
[
5758-
np.where(this.flex_spw_id_array == idx)[0]
5759-
for idx in sorted(this.spw_array)
5760-
]
5761-
)
5762-
5763-
# With spectral windows sorted, check and see if channels within
5764-
# windows need sorting. If they are ordered in ascending or descending
5765-
# fashion, leave them be. If not, sort in ascending order
5766-
for idx in this.spw_array:
5767-
select_mask = this.flex_spw_id_array[order_dict["Nfreqs"]] == idx
5768-
check_freqs = this.freq_array[order_dict["Nfreqs"][select_mask]]
5769-
if (not np.all(check_freqs[1:] > check_freqs[:-1])) and (
5770-
not np.all(check_freqs[1:] < check_freqs[:-1])
5771-
):
5772-
subsort_order = order_dict["Nfreqs"][select_mask]
5773-
order_dict["Nfreqs"][select_mask] = subsort_order[
5774-
np.argsort(check_freqs)
5775-
]
5776-
5777-
if len(new_inds["Npols"]) > 0:
5778-
order_dict["Npols"] = np.argsort(np.abs(this.polarization_array))
57795741

57805742
# Now populate the data
5781-
pol_t2o = np.nonzero(
5782-
np.isin(this.polarization_array, other.polarization_array)
5783-
)[0]
5784-
this_freqs = this.freq_array
5785-
other_freqs = other.freq_array
5786-
5787-
freq_t2o = np.zeros(this_freqs.shape, dtype=bool)
5788-
for spw_id in set(this.spw_array).intersection(other.spw_array):
5789-
mask = this.flex_spw_id_array == spw_id
5790-
freq_t2o[mask] |= np.isin(
5791-
this_freqs[mask], other_freqs[other.flex_spw_id_array == spw_id]
5792-
)
5793-
freq_t2o = np.nonzero(freq_t2o)[0]
5794-
blt_t2o = np.nonzero(np.isin(this_blts, other_blts))[0]
5743+
t2o_dict = {}
5744+
for axis, inds_dict in axis_vals.items():
5745+
t2o_dict[axis] = np.nonzero(
5746+
np.isin(new_axis_inds[axis], inds_dict["other"])
5747+
)[0]
57955748

57965749
if not self.metadata_only:
5797-
this.data_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.data_array
5798-
this.nsample_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.nsample_array
5799-
this.flag_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.flag_array
5750+
this.data_array[
5751+
np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5752+
] = other.data_array
5753+
this.nsample_array[
5754+
np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5755+
] = other.nsample_array
5756+
this.flag_array[
5757+
np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5758+
] = other.flag_array
58005759

58015760
# Fix ordering
58025761
for axis_ind, axis in enumerate(axes):
@@ -5812,16 +5771,6 @@ def __add__(
58125771
this, name, np.take(param, order_dict[axis], axis=axis_ind)
58135772
)
58145773

5815-
# reorder freq, pol axes but not blt axis because that was already done.
5816-
for axis in axes[1:]:
5817-
params_to_update = axis_params_check[axis] + [
5818-
"_" + param for param in axis_overlap_params[axis]
5819-
]
5820-
if len(new_inds[axis]) > 0:
5821-
for param in params_to_update:
5822-
this_param = getattr(this, param)
5823-
this_param.value = this_param.value[order_dict[axis]]
5824-
58255774
# Update N parameters (e.g. Npols)
58265775
this.Ntimes = len(np.unique(this.time_array))
58275776
this.Nbls = len(np.unique(this.baseline_array))
@@ -5862,8 +5811,8 @@ def __add__(
58625811
)
58635812

58645813
# Reset blt_order if blt axis was added to and it is set
5865-
if len(blt_t2o) > 0:
5866-
this.blt_order = None
5814+
if len(t2o_dict["Nblts"]) > 0:
5815+
this.blt_order = ("time", "baseline")
58675816

58685817
this.set_rectangularity(force=True)
58695818

tests/uvdata/test_mir.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,9 +592,11 @@ def test_flex_pol_add(sma_mir_filt):
592592
sma_yy_copy._make_flex_pol()
593593

594594
# Add the two back together here, and make sure we can the same value out,
595-
# modulo the history.
595+
# modulo the history and sorting.
596596
sma_check = sma_yy_copy + sma_xx_copy
597597

598+
sma_mir_filt.reorder_freqs(channel_order="freq")
599+
598600
assert sma_check.history != sma_mir_filt.history
599601
sma_check.history = sma_mir_filt.history = None
600602

0 commit comments

Comments
 (0)