@@ -195,9 +195,7 @@ def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
195195 setattr(this, param, new_array)
196196
197197
198- def _fill_multi_helper(
199- this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
200- ):
198+ def _fill_multi_helper(this: UVData, other: UVData, t2o_dict: dict, order_dict: dict):
201199 """
202200 Fill UVParameter objects with multiple dimensions from the right side object.
203201
@@ -210,8 +208,6 @@ def _fill_multi_helper(
210208 t2o_dict : dict
211209 dict giving the indices in the left object to be filled from the right
212210 object for each axis (keys are axes, values are index arrays).
213- sort_axes : list of str
214- The axes that need to be sorted along.
215211 order_dict : dict
216212 dict giving the final sort indices for each axis (keys are axes, values
217213 are index arrays for sorting).
@@ -229,7 +225,7 @@ def _fill_multi_helper(
229225
230226 # Fix ordering
231227 for axis_ind, axis in enumerate(form):
232- if axis in sort_axes :
228+ if order_dict[ axis] is not None :
233229 unique_order_diffs = np.unique(np.diff(order_dict[axis]))
234230 if np.array_equal(unique_order_diffs, np.array([1])):
235231 # everything is already in order
@@ -5728,59 +5724,72 @@ def __add__(
57285724 # Define parameters that must be the same to add objects
57295725 compatibility_params = ["_vis_units"]
57305726
5731- # identify params that are not explicitly included in overlap calc per axis
57325727 axes = ["Nblts", "Nfreqs", "Npols"]
5733- axis_params_check = {}
5734- axis_overlap_params = {
5728+ # axis_key_params defines which parameters to use as the defining
5729+ # parameters along each axis. These are used to identify overlapping data.
5730+ axis_key_params = {
57355731 "Nblts": ["time_array", "baseline_array"],
57365732 "Nfreqs": ["freq_array", "flex_spw_id_array"],
57375733 "Npols": ["polarization_array"],
57385734 }
5739- axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5740- axis_dict = {}
5741- for axis in axes:
5742- axis_dict[axis] = this._get_param_axis(axis)
5743- axis_params_check[axis] = []
5744- for param in axis_dict[axis]:
5735+ # specify a function to form a combined string if there are multiple
5736+ # key arrays (e.g. baseline-time, spw-freq)
5737+ axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5738+ multi_axis_params = this._get_multi_axis_params()
5739+ # axis_parameters gives parameters whose form contains each axis
5740+ axis_parameters = {}
5741+ # axis_check_params gives parameters that should be checked if adding
5742+ # along other axes
5743+ axis_check_params = {}
5744+ # axis_key_arrays gives the arrays to use for checking for overlap per axis
5745+ axis_key_arrays = {}
5746+ # axis_overlap_inds has the outcomes of np.intersect1d on the
5747+ # axis_key_arrays per axis. So it has the both/this inds/other inds
5748+ # for any overlaps.
5749+ axis_overlap_inds = {}
5750+ for axis, overlap_params in axis_key_params.items():
5751+ axis_parameters[axis] = this._get_param_axis(axis)
5752+ axis_check_params[axis] = []
5753+ for param in axis_parameters[axis]:
5754+ # get parameters for compatibility checking. Exclude parameters
5755+ # that define overlap and multidimensional parameters which are
5756+ # handled separately later.
57455757 if (
5746- param not in this._data_params
5747- and param not in axis_overlap_params [axis]
5758+ param not in multi_axis_params
5759+ and param not in axis_key_params [axis]
57485760 ):
5749- axis_params_check [axis].append("_" + param)
5761+ axis_check_params [axis].append("_" + param)
57505762
5751- # build this/other arrays for checking for overlap.
5752- # Use a combined string if there are multiple arrays defining overlap
5753- # (e.g. baseline-time, spw-freq)
5754- axis_vals = {}
5755- for axis, overlap_params in axis_overlap_params.items():
5763+ # build this/other arrays for checking for overlap.
57565764 if len(overlap_params) > 1:
5757- axis_vals [axis] = {
5758- "this": getattr(this, axis_combined_func [axis])(),
5759- "other": getattr(other, axis_combined_func [axis])(),
5765+ axis_key_arrays [axis] = {
5766+ "this": getattr(this, axis_key_func [axis])(),
5767+ "other": getattr(other, axis_key_func [axis])(),
57605768 }
57615769 else:
5762- axis_vals [axis] = {
5770+ axis_key_arrays [axis] = {
57635771 "this": getattr(this, overlap_params[0]),
57645772 "other": getattr(other, overlap_params[0]),
57655773 }
57665774
5767- # Check if we have overlapping data
5768- axis_inds = {}
5769- for axis, val_arr in axis_vals.items():
5775+ # Check if we have overlapping data
57705776 both_inds, this_inds, other_inds = np.intersect1d(
5771- val_arr["this"], val_arr["other"], return_indices=True
5777+ axis_key_arrays[axis]["this"],
5778+ axis_key_arrays[axis]["other"],
5779+ return_indices=True,
57725780 )
5773- axis_inds [axis] = {
5781+ axis_overlap_inds [axis] = {
57745782 "this": this_inds,
57755783 "other": other_inds,
57765784 "both": both_inds,
57775785 }
57785786
57795787 history_update_string = ""
57805788
5781- if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
5789+ if np.all(
5790+ [len(axis_overlap_inds[axis]["both"]) > 0 for axis in axis_overlap_inds]
5791+ ):
57825792 # We have overlaps, check that overlapping data is not valid
5783- multi_axis_params = this._get_multi_axis_params()
57845793 this_test = []
57855794 other_test = []
57865795 for param in multi_axis_params:
@@ -5795,10 +5804,14 @@ def __add__(
57955804 for ax_ind, axis in enumerate(form):
57965805 expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
57975806 this_index_list.append(
5798- np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5807+ np.expand_dims(
5808+ axis_overlap_inds[axis]["this"], axis=expand_axes
5809+ )
57995810 )
58005811 other_index_list.append(
5801- np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5812+ np.expand_dims(
5813+ axis_overlap_inds[axis]["other"], axis=expand_axes
5814+ )
58025815 )
58035816 this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
58045817
@@ -5830,7 +5843,9 @@ def __add__(
58305843 "These objects have overlapping data and cannot be combined."
58315844 )
58325845
5833- new_inds = {}
5846+ # Now actually find which axes are going to be added along
5847+ # other_inds_use have the indices in other that will be added to this
5848+ other_inds_use = {}
58345849 additions = []
58355850 axis_descriptions = {
58365851 "Nblts": "baseline-time",
@@ -5840,30 +5855,30 @@ def __add__(
58405855 # find the indices in "other" but not in "this"
58415856 for axis in axes:
58425857 temp = np.nonzero(
5843- ~np.isin(axis_vals [axis]["other"], axis_vals [axis]["this"])
5858+ ~np.isin(axis_key_arrays [axis]["other"], axis_key_arrays [axis]["this"])
58445859 )[0]
58455860 if len(temp) > 0:
5846- new_inds [axis] = temp
5861+ other_inds_use [axis] = temp
58475862 # add params associated with the other axes to compatibility_params
58485863 for axis2 in axes:
58495864 if axis2 != axis:
5850- compatibility_params.extend(axis_params_check [axis2])
5865+ compatibility_params.extend(axis_check_params [axis2])
58515866 additions.append(axis_descriptions[axis])
58525867 else:
5853- new_inds [axis] = []
5868+ other_inds_use [axis] = []
58545869
58555870 # Actually check compatibility parameters
58565871 for cp in compatibility_params:
58575872 params_match = None
5858- for axis, check_list in axis_params_check .items():
5873+ for axis, check_list in axis_check_params .items():
58595874 if cp in check_list:
58605875 # only check that overlapping indices match
58615876 this_param = getattr(this, cp)
58625877 this_param_overlap = this_param.get_from_form(
5863- {axis: axis_inds [axis]["this"]}
5878+ {axis: axis_overlap_inds [axis]["this"]}
58645879 )
58655880 other_param_overlap = getattr(other, cp).get_from_form(
5866- {axis: axis_inds [axis]["other"]}
5881+ {axis: axis_overlap_inds [axis]["other"]}
58675882 )
58685883 params_match = np.allclose(
58695884 this_param_overlap,
@@ -5895,7 +5910,7 @@ def __add__(
58955910 "Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"},
58965911 "Npols": {"method": "reorder_pols", "parameter": "order"},
58975912 }
5898- for axis, ind_dict in axis_inds .items():
5913+ for axis, ind_dict in axis_overlap_inds .items():
58995914 if len(ind_dict["this"]) != 0:
59005915 # there is some overlap, so check sorting
59015916 this_argsort = np.argsort(ind_dict["this"])
@@ -5910,16 +5925,22 @@ def __add__(
59105925
59115926 getattr(this, reorder_method[axis]["method"])(**kwargs)
59125927
5913- # start updating parameters
5914- new_axis_inds = {}
5928+ # checks are all done, start updating parameters
5929+ # combined_key_arrays has the final key arrays after adding.
5930+ combined_key_arrays = {}
5931+ # order_dict has info about how to sort each axis. Initialize to None
5932+ # for axes that are not added along (so do not need sorting)
59155933 order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
59165934 for axis in axes:
5917- if len(new_inds[axis]) > 0:
5918- new_axis_inds[axis] = np.concatenate(
5919- (axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
5935+ if len(other_inds_use[axis]) > 0:
5936+ combined_key_arrays[axis] = np.concatenate(
5937+ (
5938+ axis_key_arrays[axis]["this"],
5939+ axis_key_arrays[axis]["other"][other_inds_use[axis]],
5940+ )
59205941 )
59215942 if axis == "Npols":
5922- order_dict[axis] = np.argsort(np.abs(new_axis_inds [axis]))
5943+ order_dict[axis] = np.argsort(np.abs(combined_key_arrays [axis]))
59235944 elif axis == "Nfreqs" and (
59245945 np.any(np.diff(this.freq_array) < 0)
59255946 or np.any(np.diff(other.freq_array) < 0)
@@ -5930,38 +5951,39 @@ def __add__(
59305951 np.concatenate(
59315952 (
59325953 this.flex_spw_id_array,
5933- other.flex_spw_id_array[new_inds [axis]],
5954+ other.flex_spw_id_array[other_inds_use [axis]],
59345955 )
59355956 ),
59365957 np.concatenate(
5937- (this.freq_array, other.freq_array[new_inds [axis]])
5958+ (this.freq_array, other.freq_array[other_inds_use [axis]])
59385959 ),
59395960 )
59405961 else:
5941- order_dict[axis] = np.argsort(new_axis_inds [axis])
5962+ order_dict[axis] = np.argsort(combined_key_arrays [axis])
59425963
5943- # first handle parameters with a single axis
5944- _axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
5964+ # first handle parameters with a single named axis
5965+ _axis_add_helper(
5966+ this, other, axis, other_inds_use[axis], order_dict[axis]
5967+ )
59455968
59465969 # then pad out parameters with multiple axes
5947- _axis_pad_helper(this, axis, len(new_inds [axis]))
5970+ _axis_pad_helper(this, axis, len(other_inds_use [axis]))
59485971 else:
5949- new_axis_inds[axis] = axis_vals[axis]["this"]
5972+ # no add along this axis, so it's the same as what's already on this
5973+ combined_key_arrays[axis] = axis_key_arrays[axis]["this"]
59505974
59515975 # Now fill in multidimensional parameters
5976+ # t2o_dict has the mapping of where arrays on other get mapped into
5977+ # this after padding
59525978 t2o_dict = {}
5953- for axis, inds_dict in axis_vals .items():
5979+ for axis, inds_dict in axis_key_arrays .items():
59545980 t2o_dict[axis] = np.nonzero(
5955- np.isin(new_axis_inds [axis], inds_dict["other"])
5981+ np.isin(combined_key_arrays [axis], inds_dict["other"])
59565982 )[0]
59575983
5958- sort_axes = []
5959- for axis in axes:
5960- if len(new_inds[axis]) > 0:
5961- sort_axes.append(axis)
5962- _fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5984+ _fill_multi_helper(this, other, t2o_dict, order_dict)
59635985
5964- if len(new_inds ["Nfreqs"]) > 0:
5986+ if len(other_inds_use ["Nfreqs"]) > 0:
59655987 # We want to preserve per-spw information based on first appearance
59665988 # in the concatenated array.
59675989 unique_index = np.sort(
@@ -6015,7 +6037,7 @@ def __add__(
60156037 )
60166038
60176039 # Reset blt_order if blt axis was added to
6018- if "Nblts" in sort_axes :
6040+ if order_dict[ "Nblts"] is not None :
60196041 this.blt_order = ("time", "baseline")
60206042
60216043 this.set_rectangularity(force=True)
@@ -6231,18 +6253,18 @@ def fast_concat(
62316253
62326254 # identify params that are not explicitly included in overlap calc per axis
62336255 axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"}
6234- axis_params_check = {}
6235- axis_dict = {}
6256+ axis_check_params = {}
6257+ axis_parameters = {}
62366258 for axis2, ax_shape in axis_shape.items():
6237- axis_dict [axis2] = this._get_param_axis(ax_shape)
6238- axis_params_check [axis2] = []
6239- for param in axis_dict [axis2]:
6259+ axis_parameters [axis2] = this._get_param_axis(ax_shape)
6260+ axis_check_params [axis2] = []
6261+ for param in axis_parameters [axis2]:
62406262 if param not in this._data_params:
6241- axis_params_check [axis2].append("_" + param)
6263+ axis_check_params [axis2].append("_" + param)
62426264
62436265 for axis2 in axis_shape:
62446266 if axis2 != axis:
6245- compatibility_params.extend(axis_params_check [axis2])
6267+ compatibility_params.extend(axis_check_params [axis2])
62466268
62476269 axis_descriptions = {
62486270 "blt": "baseline-time",
0 commit comments