Skip to content

Commit 572a748

Browse files
committed
check that overlap is not valid data programmatically
1 parent 17a5294 commit 572a748

File tree

2 files changed

+41
-27
lines changed

2 files changed

+41
-27
lines changed

src/pyuvdata/uvdata/uvdata.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5758,36 +5758,49 @@ def __add__(
57585758
}
57595759

57605760
history_update_string = ""
5761-
# TODO do this programmatically for multidimensional parameters
5762-
if not self.metadata_only and np.all(
5763-
[len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]
5764-
):
5761+
5762+
if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
57655763
# We have overlaps, check that overlapping data is not valid
5766-
this_inds = np.ravel_multi_index(
5767-
(
5768-
axis_inds["Nblts"]["this"][:, np.newaxis, np.newaxis],
5769-
axis_inds["Nfreqs"]["this"][np.newaxis, :, np.newaxis],
5770-
axis_inds["Npols"]["this"][np.newaxis, np.newaxis, :],
5771-
),
5772-
this.data_array.shape,
5773-
).flatten()
5774-
other_inds = np.ravel_multi_index(
5775-
(
5776-
axis_inds["Nblts"]["other"][:, np.newaxis, np.newaxis],
5777-
axis_inds["Nfreqs"]["other"][np.newaxis, :, np.newaxis],
5778-
axis_inds["Npols"]["other"][np.newaxis, np.newaxis, :],
5779-
),
5780-
other.data_array.shape,
5781-
).flatten()
5782-
this_all_zero = np.all(this.data_array.flatten()[this_inds] == 0)
5783-
this_all_flag = np.all(this.flag_array.flatten()[this_inds])
5784-
other_all_zero = np.all(other.data_array.flatten()[other_inds] == 0)
5785-
other_all_flag = np.all(other.flag_array.flatten()[other_inds])
5786-
5787-
if this_all_zero and this_all_flag:
5764+
multi_axis_params = this._get_multi_axis_params()
5765+
this_test = []
5766+
other_test = []
5767+
for param in multi_axis_params:
5768+
form = getattr(this, "_" + param).form
5769+
this_shape = getattr(this, param).shape
5770+
other_shape = getattr(other, param).shape
5771+
this_param_type = getattr(this, "_" + param).expected_type
5772+
bool_type = this_param_type is bool or bool in this_param_type
5773+
5774+
this_index_list = []
5775+
other_index_list = []
5776+
for ax_ind, axis in enumerate(form):
5777+
expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
5778+
this_index_list.append(
5779+
np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5780+
)
5781+
other_index_list.append(
5782+
np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5783+
)
5784+
this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
5785+
5786+
other_inds = np.ravel_multi_index(
5787+
other_index_list, other_shape
5788+
).flatten()
5789+
5790+
this_arr = getattr(this, param).flatten()[this_inds]
5791+
other_arr = getattr(other, param).flatten()[other_inds]
5792+
5793+
if bool_type:
5794+
this_test.append(np.all(this_arr))
5795+
other_test.append(np.all(other_arr))
5796+
else:
5797+
this_test.append(np.all(this_arr == 0))
5798+
other_test.append(np.all(other_arr == 0))
5799+
5800+
if np.all(this_test):
57885801
# we're fine to overwrite; update history accordingly
57895802
history_update_string = " Overwrote invalid data using pyuvdata."
5790-
elif other_all_zero and other_all_flag:
5803+
elif np.all(other_test):
57915804
raise ValueError(
57925805
"To combine these data, please run the add operation again, "
57935806
"but with the object whose data is to be overwritten as the "

tests/uvdata/test_uvdata.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3983,6 +3983,7 @@ def test_flex_spw_add_concat(sma_mir, add_method, screen1, screen2):
39833983
if np.any(np.logical_and(screen1, screen2)):
39843984
flag_screen = screen2[screen1]
39853985
uv1.data_array[:, flag_screen] = 0.0
3986+
uv1.nsample_array[:, flag_screen] = 0.0
39863987
uv1.flag_array[:, flag_screen] = True
39873988

39883989
uv_recomb = getattr(uv1, add_method[0])(uv2, **add_method[1])

0 commit comments

Comments
 (0)