@@ -195,9 +195,7 @@ def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
195195 setattr(this, param, new_array)
196196
197197
198- def _fill_multi_helper(
199- this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
200- ):
198+ def _fill_multi_helper(this: UVData, other: UVData, t2o_dict: dict, order_dict: dict):
201199 """
202200 Fill UVParameter objects with multiple dimensions from the right side object.
203201
@@ -210,8 +208,6 @@ def _fill_multi_helper(
210208 t2o_dict : dict
211209 dict giving the indices in the left object to be filled from the right
212210 object for each axis (keys are axes, values are index arrays).
213- sort_axes : list of str
214- The axes that need to be sorted along.
215211 order_dict : dict
216212 dict giving the final sort indices for each axis (keys are axes, values
217213 are index arrays for sorting).
@@ -229,7 +225,7 @@ def _fill_multi_helper(
229225
230226 # Fix ordering
231227 for axis_ind, axis in enumerate(form):
232- if axis in sort_axes :
228+ if order_dict[ axis] is not None :
233229 unique_order_diffs = np.unique(np.diff(order_dict[axis]))
234230 if np.array_equal(unique_order_diffs, np.array([1])):
235231 # everything is already in order
@@ -5709,59 +5705,72 @@ def __add__(
57095705 # Define parameters that must be the same to add objects
57105706 compatibility_params = ["_vis_units"]
57115707
5712- # identify params that are not explicitly included in overlap calc per axis
57135708 axes = ["Nblts", "Nfreqs", "Npols"]
5714- axis_params_check = {}
5715- axis_overlap_params = {
5709+ # axis_key_params defines which parameters to use as the defining
5710+ # parameters along each axis. These are used to identify overlapping data.
5711+ axis_key_params = {
57165712 "Nblts": ["time_array", "baseline_array"],
57175713 "Nfreqs": ["freq_array", "flex_spw_id_array"],
57185714 "Npols": ["polarization_array"],
57195715 }
5720- axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5721- axis_dict = {}
5722- for axis in axes:
5723- axis_dict[axis] = this._get_param_axis(axis)
5724- axis_params_check[axis] = []
5725- for param in axis_dict[axis]:
5716+ # specify a function to form a combined string if there are multiple
5717+ # key arrays (e.g. baseline-time, spw-freq)
5718+ axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5719+ multi_axis_params = this._get_multi_axis_params()
5720+ # axis_parameters gives parameters whose form contains each axis
5721+ axis_parameters = {}
5722+ # axis_check_params gives parameters that should be checked if adding
5723+ # along other axes
5724+ axis_check_params = {}
5725+ # axis_key_arrays gives the arrays to use for checking for overlap per axis
5726+ axis_key_arrays = {}
5727+ # axis_overlap_inds has the outcomes of np.intersect1d on the
5728+ # axis_key_arrays per axis. So it has the both/this inds/other inds
5729+ # for any overlaps.
5730+ axis_overlap_inds = {}
5731+ for axis, overlap_params in axis_key_params.items():
5732+ axis_parameters[axis] = this._get_param_axis(axis)
5733+ axis_check_params[axis] = []
5734+ for param in axis_parameters[axis]:
5735+ # get parameters for compatibility checking. Exclude parameters
5736+ # that define overlap and multidimensional parameters which are
5737+ # handled separately later.
57265738 if (
5727- param not in this._data_params
5728- and param not in axis_overlap_params [axis]
5739+ param not in multi_axis_params
5740+ and param not in axis_key_params [axis]
57295741 ):
5730- axis_params_check [axis].append("_" + param)
5742+ axis_check_params [axis].append("_" + param)
57315743
5732- # build this/other arrays for checking for overlap.
5733- # Use a combined string if there are multiple arrays defining overlap
5734- # (e.g. baseline-time, spw-freq)
5735- axis_vals = {}
5736- for axis, overlap_params in axis_overlap_params.items():
5744+ # build this/other arrays for checking for overlap.
57375745 if len(overlap_params) > 1:
5738- axis_vals [axis] = {
5739- "this": getattr(this, axis_combined_func [axis])(),
5740- "other": getattr(other, axis_combined_func [axis])(),
5746+ axis_key_arrays [axis] = {
5747+ "this": getattr(this, axis_key_func [axis])(),
5748+ "other": getattr(other, axis_key_func [axis])(),
57415749 }
57425750 else:
5743- axis_vals [axis] = {
5751+ axis_key_arrays [axis] = {
57445752 "this": getattr(this, overlap_params[0]),
57455753 "other": getattr(other, overlap_params[0]),
57465754 }
57475755
5748- # Check if we have overlapping data
5749- axis_inds = {}
5750- for axis, val_arr in axis_vals.items():
5756+ # Check if we have overlapping data
57515757 both_inds, this_inds, other_inds = np.intersect1d(
5752- val_arr["this"], val_arr["other"], return_indices=True
5758+ axis_key_arrays[axis]["this"],
5759+ axis_key_arrays[axis]["other"],
5760+ return_indices=True,
57535761 )
5754- axis_inds [axis] = {
5762+ axis_overlap_inds [axis] = {
57555763 "this": this_inds,
57565764 "other": other_inds,
57575765 "both": both_inds,
57585766 }
57595767
57605768 history_update_string = ""
57615769
5762- if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
5770+ if np.all(
5771+ [len(axis_overlap_inds[axis]["both"]) > 0 for axis in axis_overlap_inds]
5772+ ):
57635773 # We have overlaps, check that overlapping data is not valid
5764- multi_axis_params = this._get_multi_axis_params()
57655774 this_test = []
57665775 other_test = []
57675776 for param in multi_axis_params:
@@ -5776,10 +5785,14 @@ def __add__(
57765785 for ax_ind, axis in enumerate(form):
57775786 expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
57785787 this_index_list.append(
5779- np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5788+ np.expand_dims(
5789+ axis_overlap_inds[axis]["this"], axis=expand_axes
5790+ )
57805791 )
57815792 other_index_list.append(
5782- np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5793+ np.expand_dims(
5794+ axis_overlap_inds[axis]["other"], axis=expand_axes
5795+ )
57835796 )
57845797 this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
57855798
@@ -5811,7 +5824,9 @@ def __add__(
58115824 "These objects have overlapping data and cannot be combined."
58125825 )
58135826
5814- new_inds = {}
5827+ # Now actually find which axes are going to be added along
5828+ # other_inds_use have the indices in other that will be added to this
5829+ other_inds_use = {}
58155830 additions = []
58165831 axis_descriptions = {
58175832 "Nblts": "baseline-time",
@@ -5821,30 +5836,30 @@ def __add__(
58215836 # find the indices in "other" but not in "this"
58225837 for axis in axes:
58235838 temp = np.nonzero(
5824- ~np.isin(axis_vals [axis]["other"], axis_vals [axis]["this"])
5839+ ~np.isin(axis_key_arrays [axis]["other"], axis_key_arrays [axis]["this"])
58255840 )[0]
58265841 if len(temp) > 0:
5827- new_inds [axis] = temp
5842+ other_inds_use [axis] = temp
58285843 # add params associated with the other axes to compatibility_params
58295844 for axis2 in axes:
58305845 if axis2 != axis:
5831- compatibility_params.extend(axis_params_check [axis2])
5846+ compatibility_params.extend(axis_check_params [axis2])
58325847 additions.append(axis_descriptions[axis])
58335848 else:
5834- new_inds [axis] = []
5849+ other_inds_use [axis] = []
58355850
58365851 # Actually check compatibility parameters
58375852 for cp in compatibility_params:
58385853 params_match = None
5839- for axis, check_list in axis_params_check .items():
5854+ for axis, check_list in axis_check_params .items():
58405855 if cp in check_list:
58415856 # only check that overlapping indices match
58425857 this_param = getattr(this, cp)
58435858 this_param_overlap = this_param.get_from_form(
5844- {axis: axis_inds [axis]["this"]}
5859+ {axis: axis_overlap_inds [axis]["this"]}
58455860 )
58465861 other_param_overlap = getattr(other, cp).get_from_form(
5847- {axis: axis_inds [axis]["other"]}
5862+ {axis: axis_overlap_inds [axis]["other"]}
58485863 )
58495864 params_match = np.allclose(
58505865 this_param_overlap,
@@ -5876,7 +5891,7 @@ def __add__(
58765891 "Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"},
58775892 "Npols": {"method": "reorder_pols", "parameter": "order"},
58785893 }
5879- for axis, ind_dict in axis_inds .items():
5894+ for axis, ind_dict in axis_overlap_inds .items():
58805895 if len(ind_dict["this"]) != 0:
58815896 # there is some overlap, so check sorting
58825897 this_argsort = np.argsort(ind_dict["this"])
@@ -5891,16 +5906,22 @@ def __add__(
58915906
58925907 getattr(this, reorder_method[axis]["method"])(**kwargs)
58935908
5894- # start updating parameters
5895- new_axis_inds = {}
5909+ # checks are all done, start updating parameters
5910+ # combined_key_arrays has the final key arrays after adding.
5911+ combined_key_arrays = {}
5912+ # order_dict has info about how to sort each axis. Initialize to None
5913+ # for axes that are not added along (so do not need sorting)
58965914 order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
58975915 for axis in axes:
5898- if len(new_inds[axis]) > 0:
5899- new_axis_inds[axis] = np.concatenate(
5900- (axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
5916+ if len(other_inds_use[axis]) > 0:
5917+ combined_key_arrays[axis] = np.concatenate(
5918+ (
5919+ axis_key_arrays[axis]["this"],
5920+ axis_key_arrays[axis]["other"][other_inds_use[axis]],
5921+ )
59015922 )
59025923 if axis == "Npols":
5903- order_dict[axis] = np.argsort(np.abs(new_axis_inds [axis]))
5924+ order_dict[axis] = np.argsort(np.abs(combined_key_arrays [axis]))
59045925 elif axis == "Nfreqs" and (
59055926 np.any(np.diff(this.freq_array) < 0)
59065927 or np.any(np.diff(other.freq_array) < 0)
@@ -5911,38 +5932,39 @@ def __add__(
59115932 np.concatenate(
59125933 (
59135934 this.flex_spw_id_array,
5914- other.flex_spw_id_array[new_inds [axis]],
5935+ other.flex_spw_id_array[other_inds_use [axis]],
59155936 )
59165937 ),
59175938 np.concatenate(
5918- (this.freq_array, other.freq_array[new_inds [axis]])
5939+ (this.freq_array, other.freq_array[other_inds_use [axis]])
59195940 ),
59205941 )
59215942 else:
5922- order_dict[axis] = np.argsort(new_axis_inds [axis])
5943+ order_dict[axis] = np.argsort(combined_key_arrays [axis])
59235944
5924- # first handle parameters with a single axis
5925- _axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
5945+ # first handle parameters with a single named axis
5946+ _axis_add_helper(
5947+ this, other, axis, other_inds_use[axis], order_dict[axis]
5948+ )
59265949
59275950 # then pad out parameters with multiple axes
5928- _axis_pad_helper(this, axis, len(new_inds [axis]))
5951+ _axis_pad_helper(this, axis, len(other_inds_use [axis]))
59295952 else:
5930- new_axis_inds[axis] = axis_vals[axis]["this"]
5953+ # no add along this axis, so it's the same as what's already on this
5954+ combined_key_arrays[axis] = axis_key_arrays[axis]["this"]
59315955
59325956 # Now fill in multidimensional parameters
5957+ # t2o_dict has the mapping of where arrays on other get mapped into
5958+ # this after padding
59335959 t2o_dict = {}
5934- for axis, inds_dict in axis_vals .items():
5960+ for axis, inds_dict in axis_key_arrays .items():
59355961 t2o_dict[axis] = np.nonzero(
5936- np.isin(new_axis_inds [axis], inds_dict["other"])
5962+ np.isin(combined_key_arrays [axis], inds_dict["other"])
59375963 )[0]
59385964
5939- sort_axes = []
5940- for axis in axes:
5941- if len(new_inds[axis]) > 0:
5942- sort_axes.append(axis)
5943- _fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5965+ _fill_multi_helper(this, other, t2o_dict, order_dict)
59445966
5945- if len(new_inds ["Nfreqs"]) > 0:
5967+ if len(other_inds_use ["Nfreqs"]) > 0:
59465968 # We want to preserve per-spw information based on first appearance
59475969 # in the concatenated array.
59485970 unique_index = np.sort(
@@ -5996,7 +6018,7 @@ def __add__(
59966018 )
59976019
59986020 # Reset blt_order if blt axis was added to
5999- if "Nblts" in sort_axes :
6021+ if order_dict[ "Nblts"] is not None :
60006022 this.blt_order = ("time", "baseline")
60016023
60026024 this.set_rectangularity(force=True)
@@ -6212,18 +6234,18 @@ def fast_concat(
62126234
62136235 # identify params that are not explicitly included in overlap calc per axis
62146236 axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"}
6215- axis_params_check = {}
6216- axis_dict = {}
6237+ axis_check_params = {}
6238+ axis_parameters = {}
62176239 for axis2, ax_shape in axis_shape.items():
6218- axis_dict [axis2] = this._get_param_axis(ax_shape)
6219- axis_params_check [axis2] = []
6220- for param in axis_dict [axis2]:
6240+ axis_parameters [axis2] = this._get_param_axis(ax_shape)
6241+ axis_check_params [axis2] = []
6242+ for param in axis_parameters [axis2]:
62216243 if param not in this._data_params:
6222- axis_params_check [axis2].append("_" + param)
6244+ axis_check_params [axis2].append("_" + param)
62236245
62246246 for axis2 in axis_shape:
62256247 if axis2 != axis:
6226- compatibility_params.extend(axis_params_check [axis2])
6248+ compatibility_params.extend(axis_check_params [axis2])
62276249
62286250 axis_descriptions = {
62296251 "blt": "baseline-time",
0 commit comments