@@ -5758,36 +5758,49 @@ def __add__(
57585758 }
57595759
57605760 history_update_string = ""
5761- # TODO do this programmatically for multidimensional parameters
5762- if not self.metadata_only and np.all(
5763- [len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]
5764- ):
5761+
5762+ if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
57655763 # We have overlaps, check that overlapping data is not valid
5766- this_inds = np.ravel_multi_index(
5767- (
5768- axis_inds["Nblts"]["this"][:, np.newaxis, np.newaxis],
5769- axis_inds["Nfreqs"]["this"][np.newaxis, :, np.newaxis],
5770- axis_inds["Npols"]["this"][np.newaxis, np.newaxis, :],
5771- ),
5772- this.data_array.shape,
5773- ).flatten()
5774- other_inds = np.ravel_multi_index(
5775- (
5776- axis_inds["Nblts"]["other"][:, np.newaxis, np.newaxis],
5777- axis_inds["Nfreqs"]["other"][np.newaxis, :, np.newaxis],
5778- axis_inds["Npols"]["other"][np.newaxis, np.newaxis, :],
5779- ),
5780- other.data_array.shape,
5781- ).flatten()
5782- this_all_zero = np.all(this.data_array.flatten()[this_inds] == 0)
5783- this_all_flag = np.all(this.flag_array.flatten()[this_inds])
5784- other_all_zero = np.all(other.data_array.flatten()[other_inds] == 0)
5785- other_all_flag = np.all(other.flag_array.flatten()[other_inds])
5786-
5787- if this_all_zero and this_all_flag:
5764+ multi_axis_params = this._get_multi_axis_params()
5765+ this_test = []
5766+ other_test = []
5767+ for param in multi_axis_params:
5768+ form = getattr(this, "_" + param).form
5769+ this_shape = getattr(this, param).shape
5770+ other_shape = getattr(other, param).shape
5771+ this_param_type = getattr(this, "_" + param).expected_type
5772+ bool_type = this_param_type is bool or bool in this_param_type
5773+
5774+ this_index_list = []
5775+ other_index_list = []
5776+ for ax_ind, axis in enumerate(form):
5777+ expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
5778+ this_index_list.append(
5779+ np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5780+ )
5781+ other_index_list.append(
5782+ np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5783+ )
5784+ this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
5785+
5786+ other_inds = np.ravel_multi_index(
5787+ other_index_list, other_shape
5788+ ).flatten()
5789+
5790+ this_arr = getattr(this, param).flatten()[this_inds]
5791+ other_arr = getattr(other, param).flatten()[other_inds]
5792+
5793+ if bool_type:
5794+ this_test.append(np.all(this_arr))
5795+ other_test.append(np.all(other_arr))
5796+ else:
5797+ this_test.append(np.all(this_arr == 0))
5798+ other_test.append(np.all(other_arr == 0))
5799+
5800+ if np.all(this_test):
57885801 # we're fine to overwrite; update history accordingly
57895802 history_update_string = " Overwrote invalid data using pyuvdata."
5790- elif other_all_zero and other_all_flag :
5803+ elif np.all(other_test) :
57915804 raise ValueError(
57925805 "To combine these data, please run the add operation again, "
57935806 "but with the object whose data is to be overwritten as the "
0 commit comments