@@ -5777,36 +5777,49 @@ def __add__(
57775777 }
57785778
57795779 history_update_string = ""
5780- # TODO do this programmatically for multidimensional parameters
5781- if not self.metadata_only and np.all(
5782- [len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]
5783- ):
5780+
5781+ if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
57845782 # We have overlaps, check that overlapping data is not valid
5785- this_inds = np.ravel_multi_index(
5786- (
5787- axis_inds["Nblts"]["this"][:, np.newaxis, np.newaxis],
5788- axis_inds["Nfreqs"]["this"][np.newaxis, :, np.newaxis],
5789- axis_inds["Npols"]["this"][np.newaxis, np.newaxis, :],
5790- ),
5791- this.data_array.shape,
5792- ).flatten()
5793- other_inds = np.ravel_multi_index(
5794- (
5795- axis_inds["Nblts"]["other"][:, np.newaxis, np.newaxis],
5796- axis_inds["Nfreqs"]["other"][np.newaxis, :, np.newaxis],
5797- axis_inds["Npols"]["other"][np.newaxis, np.newaxis, :],
5798- ),
5799- other.data_array.shape,
5800- ).flatten()
5801- this_all_zero = np.all(this.data_array.flatten()[this_inds] == 0)
5802- this_all_flag = np.all(this.flag_array.flatten()[this_inds])
5803- other_all_zero = np.all(other.data_array.flatten()[other_inds] == 0)
5804- other_all_flag = np.all(other.flag_array.flatten()[other_inds])
5805-
5806- if this_all_zero and this_all_flag:
5783+ multi_axis_params = this._get_multi_axis_params()
5784+ this_test = []
5785+ other_test = []
5786+ for param in multi_axis_params:
5787+ form = getattr(this, "_" + param).form
5788+ this_shape = getattr(this, param).shape
5789+ other_shape = getattr(other, param).shape
5790+ this_param_type = getattr(this, "_" + param).expected_type
5791+ bool_type = this_param_type is bool or bool in this_param_type
5792+
5793+ this_index_list = []
5794+ other_index_list = []
5795+ for ax_ind, axis in enumerate(form):
5796+ expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
5797+ this_index_list.append(
5798+ np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5799+ )
5800+ other_index_list.append(
5801+ np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5802+ )
5803+ this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
5804+
5805+ other_inds = np.ravel_multi_index(
5806+ other_index_list, other_shape
5807+ ).flatten()
5808+
5809+ this_arr = getattr(this, param).flatten()[this_inds]
5810+ other_arr = getattr(other, param).flatten()[other_inds]
5811+
5812+ if bool_type:
5813+ this_test.append(np.all(this_arr))
5814+ other_test.append(np.all(other_arr))
5815+ else:
5816+ this_test.append(np.all(this_arr == 0))
5817+ other_test.append(np.all(other_arr == 0))
5818+
5819+ if np.all(this_test):
58075820 # we're fine to overwrite; update history accordingly
58085821 history_update_string = " Overwrote invalid data using pyuvdata."
5809- elif other_all_zero and other_all_flag :
5822+ elif np.all(other_test) :
58105823 raise ValueError(
58115824 "To combine these data, please run the add operation again, "
58125825 "but with the object whose data is to be overwritten as the "
0 commit comments