@@ -197,9 +197,7 @@ def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
197197 setattr(this, param, new_array)
198198
199199
200- def _fill_multi_helper(
201- this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
202- ):
200+ def _fill_multi_helper(this: UVData, other: UVData, t2o_dict: dict, order_dict: dict):
203201 """
204202 Fill UVParameter objects with multiple dimensions from the right side object.
205203
@@ -212,8 +210,6 @@ def _fill_multi_helper(
212210 t2o_dict : dict
213211 dict giving the indices in the left object to be filled from the right
214212 object for each axis (keys are axes, values are index arrays).
215- sort_axes : list of str
216- The axes that need to be sorted along.
217213 order_dict : dict
218214 dict giving the final sort indices for each axis (keys are axes, values
219215 are index arrays for sorting).
@@ -231,7 +227,7 @@ def _fill_multi_helper(
231227
232228 # Fix ordering
233229 for axis_ind, axis in enumerate(form):
234- if axis in sort_axes :
230+ if order_dict[ axis] is not None :
235231 unique_order_diffs = np.unique(np.diff(order_dict[axis]))
236232 if np.array_equal(unique_order_diffs, np.array([1])):
237233 # everything is already in order
@@ -5711,59 +5707,72 @@ def __add__(
57115707 # Define parameters that must be the same to add objects
57125708 compatibility_params = ["_vis_units"]
57135709
5714- # identify params that are not explicitly included in overlap calc per axis
57155710 axes = ["Nblts", "Nfreqs", "Npols"]
5716- axis_params_check = {}
5717- axis_overlap_params = {
5711+ # axis_key_params defines which parameters to use as the defining
5712+ # parameters along each axis. These are used to identify overlapping data.
5713+ axis_key_params = {
57185714 "Nblts": ["time_array", "baseline_array"],
57195715 "Nfreqs": ["freq_array", "flex_spw_id_array"],
57205716 "Npols": ["polarization_array"],
57215717 }
5722- axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5723- axis_dict = {}
5724- for axis in axes:
5725- axis_dict[axis] = this._get_param_axis(axis)
5726- axis_params_check[axis] = []
5727- for param in axis_dict[axis]:
5718+ # specify a function to form a combined string if there are multiple
5719+ # key arrays (e.g. baseline-time, spw-freq)
5720+ axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5721+ multi_axis_params = this._get_multi_axis_params()
5722+ # axis_parameters gives parameters whose form contains each axis
5723+ axis_parameters = {}
5724+ # axis_check_params gives parameters that should be checked if adding
5725+ # along other axes
5726+ axis_check_params = {}
5727+ # axis_key_arrays gives the arrays to use for checking for overlap per axis
5728+ axis_key_arrays = {}
5729+ # axis_overlap_inds has the outcomes of np.intersect1d on the
5730+ # axis_key_arrays per axis. So it has the both/this inds/other inds
5731+ # for any overlaps.
5732+ axis_overlap_inds = {}
5733+ for axis, overlap_params in axis_key_params.items():
5734+ axis_parameters[axis] = this._get_param_axis(axis)
5735+ axis_check_params[axis] = []
5736+ for param in axis_parameters[axis]:
5737+ # get parameters for compatibility checking. Exclude parameters
5738+ # that define overlap and multidimensional parameters which are
5739+ # handled separately later.
57285740 if (
5729- param not in this._data_params
5730- and param not in axis_overlap_params [axis]
5741+ param not in multi_axis_params
5742+ and param not in axis_key_params [axis]
57315743 ):
5732- axis_params_check [axis].append("_" + param)
5744+ axis_check_params [axis].append("_" + param)
57335745
5734- # build this/other arrays for checking for overlap.
5735- # Use a combined string if there are multiple arrays defining overlap
5736- # (e.g. baseline-time, spw-freq)
5737- axis_vals = {}
5738- for axis, overlap_params in axis_overlap_params.items():
5746+ # build this/other arrays for checking for overlap.
57395747 if len(overlap_params) > 1:
5740- axis_vals [axis] = {
5741- "this": getattr(this, axis_combined_func [axis])(),
5742- "other": getattr(other, axis_combined_func [axis])(),
5748+ axis_key_arrays [axis] = {
5749+ "this": getattr(this, axis_key_func [axis])(),
5750+ "other": getattr(other, axis_key_func [axis])(),
57435751 }
57445752 else:
5745- axis_vals [axis] = {
5753+ axis_key_arrays [axis] = {
57465754 "this": getattr(this, overlap_params[0]),
57475755 "other": getattr(other, overlap_params[0]),
57485756 }
57495757
5750- # Check if we have overlapping data
5751- axis_inds = {}
5752- for axis, val_arr in axis_vals.items():
5758+ # Check if we have overlapping data
57535759 both_inds, this_inds, other_inds = np.intersect1d(
5754- val_arr["this"], val_arr["other"], return_indices=True
5760+ axis_key_arrays[axis]["this"],
5761+ axis_key_arrays[axis]["other"],
5762+ return_indices=True,
57555763 )
5756- axis_inds [axis] = {
5764+ axis_overlap_inds [axis] = {
57575765 "this": this_inds,
57585766 "other": other_inds,
57595767 "both": both_inds,
57605768 }
57615769
57625770 history_update_string = ""
57635771
5764- if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
5772+ if np.all(
5773+ [len(axis_overlap_inds[axis]["both"]) > 0 for axis in axis_overlap_inds]
5774+ ):
57655775 # We have overlaps, check that overlapping data is not valid
5766- multi_axis_params = this._get_multi_axis_params()
57675776 this_test = []
57685777 other_test = []
57695778 for param in multi_axis_params:
@@ -5778,10 +5787,14 @@ def __add__(
57785787 for ax_ind, axis in enumerate(form):
57795788 expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
57805789 this_index_list.append(
5781- np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5790+ np.expand_dims(
5791+ axis_overlap_inds[axis]["this"], axis=expand_axes
5792+ )
57825793 )
57835794 other_index_list.append(
5784- np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5795+ np.expand_dims(
5796+ axis_overlap_inds[axis]["other"], axis=expand_axes
5797+ )
57855798 )
57865799 this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
57875800
@@ -5813,7 +5826,9 @@ def __add__(
58135826 "These objects have overlapping data and cannot be combined."
58145827 )
58155828
5816- new_inds = {}
5829+ # Now actually find which axes are going to be added along
5830+ # other_inds_use have the indices in other that will be added to this
5831+ other_inds_use = {}
58175832 additions = []
58185833 axis_descriptions = {
58195834 "Nblts": "baseline-time",
@@ -5823,30 +5838,30 @@ def __add__(
58235838 # find the indices in "other" but not in "this"
58245839 for axis in axes:
58255840 temp = np.nonzero(
5826- ~np.isin(axis_vals [axis]["other"], axis_vals [axis]["this"])
5841+ ~np.isin(axis_key_arrays [axis]["other"], axis_key_arrays [axis]["this"])
58275842 )[0]
58285843 if len(temp) > 0:
5829- new_inds [axis] = temp
5844+ other_inds_use [axis] = temp
58305845 # add params associated with the other axes to compatibility_params
58315846 for axis2 in axes:
58325847 if axis2 != axis:
5833- compatibility_params.extend(axis_params_check [axis2])
5848+ compatibility_params.extend(axis_check_params [axis2])
58345849 additions.append(axis_descriptions[axis])
58355850 else:
5836- new_inds [axis] = []
5851+ other_inds_use [axis] = []
58375852
58385853 # Actually check compatibility parameters
58395854 for cp in compatibility_params:
58405855 params_match = None
5841- for axis, check_list in axis_params_check .items():
5856+ for axis, check_list in axis_check_params .items():
58425857 if cp in check_list:
58435858 # only check that overlapping indices match
58445859 this_param = getattr(this, cp)
58455860 this_param_overlap = this_param.get_from_form(
5846- {axis: axis_inds [axis]["this"]}
5861+ {axis: axis_overlap_inds [axis]["this"]}
58475862 )
58485863 other_param_overlap = getattr(other, cp).get_from_form(
5849- {axis: axis_inds [axis]["other"]}
5864+ {axis: axis_overlap_inds [axis]["other"]}
58505865 )
58515866 params_match = np.allclose(
58525867 this_param_overlap,
@@ -5878,7 +5893,7 @@ def __add__(
58785893 "Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"},
58795894 "Npols": {"method": "reorder_pols", "parameter": "order"},
58805895 }
5881- for axis, ind_dict in axis_inds .items():
5896+ for axis, ind_dict in axis_overlap_inds .items():
58825897 if len(ind_dict["this"]) != 0:
58835898 # there is some overlap, so check sorting
58845899 this_argsort = np.argsort(ind_dict["this"])
@@ -5893,16 +5908,22 @@ def __add__(
58935908
58945909 getattr(this, reorder_method[axis]["method"])(**kwargs)
58955910
5896- # start updating parameters
5897- new_axis_inds = {}
5911+ # checks are all done, start updating parameters
5912+ # combined_key_arrays has the final key arrays after adding.
5913+ combined_key_arrays = {}
5914+ # order_dict has info about how to sort each axis. Initialize to None
5915+ # for axes that are not added along (so do not need sorting)
58985916 order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
58995917 for axis in axes:
5900- if len(new_inds[axis]) > 0:
5901- new_axis_inds[axis] = np.concatenate(
5902- (axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
5918+ if len(other_inds_use[axis]) > 0:
5919+ combined_key_arrays[axis] = np.concatenate(
5920+ (
5921+ axis_key_arrays[axis]["this"],
5922+ axis_key_arrays[axis]["other"][other_inds_use[axis]],
5923+ )
59035924 )
59045925 if axis == "Npols":
5905- order_dict[axis] = np.argsort(np.abs(new_axis_inds [axis]))
5926+ order_dict[axis] = np.argsort(np.abs(combined_key_arrays [axis]))
59065927 elif axis == "Nfreqs" and (
59075928 np.any(np.diff(this.freq_array) < 0)
59085929 or np.any(np.diff(other.freq_array) < 0)
@@ -5913,38 +5934,39 @@ def __add__(
59135934 np.concatenate(
59145935 (
59155936 this.flex_spw_id_array,
5916- other.flex_spw_id_array[new_inds [axis]],
5937+ other.flex_spw_id_array[other_inds_use [axis]],
59175938 )
59185939 ),
59195940 np.concatenate(
5920- (this.freq_array, other.freq_array[new_inds [axis]])
5941+ (this.freq_array, other.freq_array[other_inds_use [axis]])
59215942 ),
59225943 )
59235944 else:
5924- order_dict[axis] = np.argsort(new_axis_inds [axis])
5945+ order_dict[axis] = np.argsort(combined_key_arrays [axis])
59255946
5926- # first handle parameters with a single axis
5927- _axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
5947+ # first handle parameters with a single named axis
5948+ _axis_add_helper(
5949+ this, other, axis, other_inds_use[axis], order_dict[axis]
5950+ )
59285951
59295952 # then pad out parameters with multiple axes
5930- _axis_pad_helper(this, axis, len(new_inds [axis]))
5953+ _axis_pad_helper(this, axis, len(other_inds_use [axis]))
59315954 else:
5932- new_axis_inds[axis] = axis_vals[axis]["this"]
5955+ # no add along this axis, so it's the same as what's already on this
5956+ combined_key_arrays[axis] = axis_key_arrays[axis]["this"]
59335957
59345958 # Now fill in multidimensional parameters
5959+ # t2o_dict has the mapping of where arrays on other get mapped into
5960+ # this after padding
59355961 t2o_dict = {}
5936- for axis, inds_dict in axis_vals .items():
5962+ for axis, inds_dict in axis_key_arrays .items():
59375963 t2o_dict[axis] = np.nonzero(
5938- np.isin(new_axis_inds [axis], inds_dict["other"])
5964+ np.isin(combined_key_arrays [axis], inds_dict["other"])
59395965 )[0]
59405966
5941- sort_axes = []
5942- for axis in axes:
5943- if len(new_inds[axis]) > 0:
5944- sort_axes.append(axis)
5945- _fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5967+ _fill_multi_helper(this, other, t2o_dict, order_dict)
59465968
5947- if len(new_inds ["Nfreqs"]) > 0:
5969+ if len(other_inds_use ["Nfreqs"]) > 0:
59485970 # We want to preserve per-spw information based on first appearance
59495971 # in the concatenated array.
59505972 unique_index = np.sort(
@@ -5998,7 +6020,7 @@ def __add__(
59986020 )
59996021
60006022 # Reset blt_order if blt axis was added to
6001- if "Nblts" in sort_axes :
6023+ if order_dict[ "Nblts"] is not None :
60026024 this.blt_order = ("time", "baseline")
60036025
60046026 this.set_rectangularity(force=True)
@@ -6214,18 +6236,18 @@ def fast_concat(
62146236
62156237 # identify params that are not explicitly included in overlap calc per axis
62166238 axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"}
6217- axis_params_check = {}
6218- axis_dict = {}
6239+ axis_check_params = {}
6240+ axis_parameters = {}
62196241 for axis2, ax_shape in axis_shape.items():
6220- axis_dict [axis2] = this._get_param_axis(ax_shape)
6221- axis_params_check [axis2] = []
6222- for param in axis_dict [axis2]:
6242+ axis_parameters [axis2] = this._get_param_axis(ax_shape)
6243+ axis_check_params [axis2] = []
6244+ for param in axis_parameters [axis2]:
62236245 if param not in this._data_params:
6224- axis_params_check [axis2].append("_" + param)
6246+ axis_check_params [axis2].append("_" + param)
62256247
62266248 for axis2 in axis_shape:
62276249 if axis2 != axis:
6228- compatibility_params.extend(axis_params_check [axis2])
6250+ compatibility_params.extend(axis_check_params [axis2])
62296251
62306252 axis_descriptions = {
62316253 "blt": "baseline-time",
0 commit comments