Skip to content

Commit 5c5fa70

Browse files
committed
check that overlap is not valid data programmatically
1 parent f6bd529 commit 5c5fa70

2 files changed

Lines changed: 41 additions & 27 deletions

File tree

src/pyuvdata/uvdata/uvdata.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5777,36 +5777,49 @@ def __add__(
57775777
}
57785778

57795779
history_update_string = ""
5780-
# TODO do this programmatically for multidimensional parameters
5781-
if not self.metadata_only and np.all(
5782-
[len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]
5783-
):
5780+
5781+
if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
57845782
# We have overlaps, check that overlapping data is not valid
5785-
this_inds = np.ravel_multi_index(
5786-
(
5787-
axis_inds["Nblts"]["this"][:, np.newaxis, np.newaxis],
5788-
axis_inds["Nfreqs"]["this"][np.newaxis, :, np.newaxis],
5789-
axis_inds["Npols"]["this"][np.newaxis, np.newaxis, :],
5790-
),
5791-
this.data_array.shape,
5792-
).flatten()
5793-
other_inds = np.ravel_multi_index(
5794-
(
5795-
axis_inds["Nblts"]["other"][:, np.newaxis, np.newaxis],
5796-
axis_inds["Nfreqs"]["other"][np.newaxis, :, np.newaxis],
5797-
axis_inds["Npols"]["other"][np.newaxis, np.newaxis, :],
5798-
),
5799-
other.data_array.shape,
5800-
).flatten()
5801-
this_all_zero = np.all(this.data_array.flatten()[this_inds] == 0)
5802-
this_all_flag = np.all(this.flag_array.flatten()[this_inds])
5803-
other_all_zero = np.all(other.data_array.flatten()[other_inds] == 0)
5804-
other_all_flag = np.all(other.flag_array.flatten()[other_inds])
5805-
5806-
if this_all_zero and this_all_flag:
5783+
multi_axis_params = this._get_multi_axis_params()
5784+
this_test = []
5785+
other_test = []
5786+
for param in multi_axis_params:
5787+
form = getattr(this, "_" + param).form
5788+
this_shape = getattr(this, param).shape
5789+
other_shape = getattr(other, param).shape
5790+
this_param_type = getattr(this, "_" + param).expected_type
5791+
bool_type = this_param_type is bool or bool in this_param_type
5792+
5793+
this_index_list = []
5794+
other_index_list = []
5795+
for ax_ind, axis in enumerate(form):
5796+
expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
5797+
this_index_list.append(
5798+
np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5799+
)
5800+
other_index_list.append(
5801+
np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5802+
)
5803+
this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
5804+
5805+
other_inds = np.ravel_multi_index(
5806+
other_index_list, other_shape
5807+
).flatten()
5808+
5809+
this_arr = getattr(this, param).flatten()[this_inds]
5810+
other_arr = getattr(other, param).flatten()[other_inds]
5811+
5812+
if bool_type:
5813+
this_test.append(np.all(this_arr))
5814+
other_test.append(np.all(other_arr))
5815+
else:
5816+
this_test.append(np.all(this_arr == 0))
5817+
other_test.append(np.all(other_arr == 0))
5818+
5819+
if np.all(this_test):
58075820
# we're fine to overwrite; update history accordingly
58085821
history_update_string = " Overwrote invalid data using pyuvdata."
5809-
elif other_all_zero and other_all_flag:
5822+
elif np.all(other_test):
58105823
raise ValueError(
58115824
"To combine these data, please run the add operation again, "
58125825
"but with the object whose data is to be overwritten as the "

tests/uvdata/test_uvdata.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3983,6 +3983,7 @@ def test_flex_spw_add_concat(sma_mir, add_method, screen1, screen2):
39833983
if np.any(np.logical_and(screen1, screen2)):
39843984
flag_screen = screen2[screen1]
39853985
uv1.data_array[:, flag_screen] = 0.0
3986+
uv1.nsample_array[:, flag_screen] = 0.0
39863987
uv1.flag_array[:, flag_screen] = True
39873988

39883989
uv_recomb = getattr(uv1, add_method[0])(uv2, **add_method[1])

0 commit comments

Comments
 (0)