@@ -5502,61 +5502,85 @@ def __add__(
55025502 # Define parameters that must be the same to add objects
55035503 compatibility_params = ["_vis_units"]
55045504
5505- axes = ["Nblts", "Nfreqs", "Npols"]
5506- # axis_key_params defines which parameters to use as the defining
5507- # parameters along each axis. These are used to identify overlapping data.
5508- axis_key_params = {
5509- "Nblts": ["time_array", "baseline_array"],
5510- "Nfreqs": ["freq_array", "flex_spw_id_array"],
5511- "Npols": ["polarization_array"],
5505+ # setup a dict to carry all the axis-specific info we need throughout
5506+ # the add process:
5507+ # - description is used in history string
5508+ # - key_params defines which parameters to use as the defining
5509+ # parameters along each axis. These are used to identify overlapping data.
5510+ # - key_func specifies a function to form a combined string if there are
5511+ # multiple key arrays (e.g. baseline-time, spw-freq)
5512+ # - reorder gives method & parameters for reording along each axis
5513+ # - order has info about how to sort each axis. Initialize to None
5514+ # for axes that are not added along (so do not need sorting),
5515+ # updated later.
5516+ # ---added later---
5517+ # - check_params gives parameters that should be checked if adding
5518+ # along other axes
5519+ # - key_arrays gives the arrays to use for checking for overlap per axis
5520+ # - overlap_inds has the outcomes of np.intersect1d on the key_arrays
5521+ # between this and other. So it has the both/this/other inds
5522+ # for any overlaps.
5523+ # - other_inds_use has the indices in other that will be added to this
5524+ # - combined_key_arrays has the final key arrays after adding.
5525+ # - t2o has the mapping of where arrays on other get mapped into this
5526+ # along each axis after padding
5527+
5528+ axis_info = {
5529+ "Nblts": {
5530+ "description": "baseline-time",
5531+ "key_params": ["time_array", "baseline_array"],
5532+ "key_func": "blt_str_arr",
5533+ "reorder": {"method": "reorder_blts", "parameter": "order"},
5534+ "order": None,
5535+ },
5536+ "Nfreqs": {
5537+ "description": "frequency",
5538+ "key_params": ["freq_array", "flex_spw_id_array"],
5539+ "key_func": "spw_freq_str_arr",
5540+ "reorder": {"method": "reorder_freqs", "parameter": "channel_order"},
5541+ "order": None,
5542+ },
5543+ "Npols": {
5544+ "description": "polarization",
5545+ "key_params": ["polarization_array"],
5546+ "reorder": {"method": "reorder_pols", "parameter": "order"},
5547+ "order": None,
5548+ },
55125549 }
5513- # specify a function to form a combined string if there are multiple
5514- # key arrays (e.g. baseline-time, spw-freq)
5515- axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5516- multi_axis_params = this._get_multi_axis_params()
5517- # axis_parameters gives parameters whose form contains each axis
5518- axis_parameters = {}
5519- # axis_check_params gives parameters that should be checked if adding
5520- # along other axes
5521- axis_check_params = {}
5522- # axis_key_arrays gives the arrays to use for checking for overlap per axis
5523- axis_key_arrays = {}
5524- # axis_overlap_inds has the outcomes of np.intersect1d on the
5525- # axis_key_arrays per axis. So it has the both/this inds/other inds
5526- # for any overlaps.
5527- axis_overlap_inds = {}
5528- for axis, overlap_params in axis_key_params.items():
5529- axis_parameters[axis] = this._get_param_axis(axis)
5530- axis_check_params[axis] = []
5531- for param in axis_parameters[axis]:
5532- # get parameters for compatibility checking. Exclude parameters
5533- # that define overlap and multidimensional parameters which are
5534- # handled separately later.
5535- if (
5536- param not in multi_axis_params
5537- and param not in axis_key_params[axis]
5538- ):
5539- axis_check_params[axis].append("_" + param)
5550+
5551+ for axis, info in axis_info.items():
5552+ # get parameters for compatibility checking. Exclude multidimensional
5553+ # parameters which are handled separately later.
5554+ params_this_axis = this._get_param_axis(axis, single_named_axis=True)
5555+ info["check_params"] = []
5556+ for param in params_this_axis:
5557+ # Also exclude parameters that define overlap
5558+ if param not in info["key_params"]:
5559+ info["check_params"].append("_" + param)
55405560
55415561 # build this/other arrays for checking for overlap.
5542- if len(overlap_params) > 1:
5543- axis_key_arrays[axis] = {
5544- "this": getattr(this, axis_key_func[axis])(),
5545- "other": getattr(other, axis_key_func[axis])(),
5562+ # key_arrays gives the arrays to use for checking for overlap per axis
5563+ if len(info["key_params"]) > 1:
5564+ info["key_arrays"] = {
5565+ "this": getattr(this, info["key_func"])(),
5566+ "other": getattr(other, info["key_func"])(),
55465567 }
55475568 else:
5548- axis_key_arrays[axis ] = {
5549- "this": getattr(this, overlap_params [0]),
5550- "other": getattr(other, overlap_params [0]),
5569+ info["key_arrays" ] = {
5570+ "this": getattr(this, info["key_params"] [0]),
5571+ "other": getattr(other, info["key_params"] [0]),
55515572 }
55525573
55535574 # Check if we have overlapping data
55545575 both_inds, this_inds, other_inds = np.intersect1d(
5555- axis_key_arrays[axis ]["this"],
5556- axis_key_arrays[axis ]["other"],
5576+ info["key_arrays" ]["this"],
5577+ info["key_arrays" ]["other"],
55575578 return_indices=True,
55585579 )
5559- axis_overlap_inds[axis] = {
5580+ # overlap_inds has the outcomes of np.intersect1d on the
5581+ # key_arrays per axis. So it has the both/this inds/other inds
5582+ # for any overlaps.
5583+ info["overlap_inds"] = {
55605584 "this": this_inds,
55615585 "other": other_inds,
55625586 "both": both_inds,
@@ -5565,11 +5589,12 @@ def __add__(
55655589 history_update_string = ""
55665590
55675591 if np.all(
5568- [len(axis_overlap_inds [axis]["both"]) > 0 for axis in axis_overlap_inds ]
5592+ [len(axis_info [axis]["overlap_inds"][" both"]) > 0 for axis in axis_info ]
55695593 ):
55705594 # We have overlaps, check that overlapping data is not valid
55715595 this_test = []
55725596 other_test = []
5597+ multi_axis_params = this._get_multi_axis_params()
55735598 for param in multi_axis_params:
55745599 form = getattr(this, "_" + param).form
55755600 this_shape = getattr(this, param).shape
@@ -5583,12 +5608,12 @@ def __add__(
55835608 expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
55845609 this_index_list.append(
55855610 np.expand_dims(
5586- axis_overlap_inds [axis]["this"], axis=expand_axes
5611+ axis_info [axis]["overlap_inds" ]["this"], axis=expand_axes
55875612 )
55885613 )
55895614 other_index_list.append(
55905615 np.expand_dims(
5591- axis_overlap_inds [axis]["other"], axis=expand_axes
5616+ axis_info [axis]["overlap_inds" ]["other"], axis=expand_axes
55925617 )
55935618 )
55945619 this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
@@ -5622,41 +5647,35 @@ def __add__(
56225647 )
56235648
56245649 # Now actually find which axes are going to be added along
5625- # other_inds_use have the indices in other that will be added to this
5626- other_inds_use = {}
56275650 additions = []
5628- axis_descriptions = {
5629- "Nblts": "baseline-time",
5630- "Nfreqs": "frequency",
5631- "Npols": "polarization",
5632- }
56335651 # find the indices in "other" but not in "this"
5634- for axis in axes :
5652+ for axis, info in axis_info.items() :
56355653 temp = np.nonzero(
5636- ~np.isin(axis_key_arrays[axis ]["other"], axis_key_arrays[axis ]["this"])
5654+ ~np.isin(info["key_arrays" ]["other"], info["key_arrays" ]["this"])
56375655 )[0]
56385656 if len(temp) > 0:
5639- other_inds_use[axis] = temp
5657+ # other_inds_use has the indices in other that will be added to this
5658+ info["other_inds_use"] = temp
56405659 # add params associated with the other axes to compatibility_params
5641- for axis2 in axes :
5660+ for axis2 in axis_info :
56425661 if axis2 != axis:
5643- compatibility_params.extend(axis_check_params [axis2])
5644- additions.append(axis_descriptions[axis ])
5662+ compatibility_params.extend(axis_info [axis2]["check_params" ])
5663+ additions.append(info["description" ])
56455664 else:
5646- other_inds_use[axis ] = []
5665+ info["other_inds_use" ] = []
56475666
56485667 # Actually check compatibility parameters
56495668 for cp in compatibility_params:
56505669 params_match = None
5651- for axis, check_list in axis_check_params .items():
5652- if cp in check_list :
5670+ for axis, info in axis_info .items():
5671+ if cp in info["check_params"] :
56535672 # only check that overlapping indices match
56545673 this_param = getattr(this, cp)
56555674 this_param_overlap = this_param.get_from_form(
5656- {axis: axis_overlap_inds[axis ]["this"]}
5675+ {axis: info["overlap_inds" ]["this"]}
56575676 )
56585677 other_param_overlap = getattr(other, cp).get_from_form(
5659- {axis: axis_overlap_inds[axis ]["other"]}
5678+ {axis: info["overlap_inds" ]["other"]}
56605679 )
56615680 params_match = np.allclose(
56625681 this_param_overlap,
@@ -5683,85 +5702,80 @@ def __add__(
56835702
56845703 # Next, we want to make sure that the ordering of the _overlapping_ data is
56855704 # the same, so that things can get plugged together in a sensible way.
5686- reorder_method = {
5687- "Nblts": {"method": "reorder_blts", "parameter": "order"},
5688- "Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"},
5689- "Npols": {"method": "reorder_pols", "parameter": "order"},
5690- }
5691- for axis, ind_dict in axis_overlap_inds.items():
5692- if len(ind_dict["this"]) != 0:
5705+ for axis, info in axis_info.items():
5706+ if len(info["overlap_inds"]["this"]) != 0:
56935707 # there is some overlap, so check sorting
5694- this_argsort = np.argsort(ind_dict ["this"])
5695- other_argsort = np.argsort(ind_dict ["other"])
5708+ this_argsort = np.argsort(info["overlap_inds"] ["this"])
5709+ other_argsort = np.argsort(info["overlap_inds"] ["other"])
56965710
56975711 if np.any(this_argsort != other_argsort):
56985712 temp_ind = np.arange(getattr(this, axis))
5699- temp_ind[ind_dict ["this"][this_argsort]] = temp_ind[
5700- ind_dict ["this"][other_argsort]
5713+ temp_ind[info["overlap_inds"] ["this"][this_argsort]] = temp_ind[
5714+ info["overlap_inds"] ["this"][other_argsort]
57015715 ]
5702- kwargs = {reorder_method[axis ]["parameter"]: temp_ind}
5716+ kwargs = {info["reorder" ]["parameter"]: temp_ind}
57035717
5704- getattr(this, reorder_method[axis ]["method"])(**kwargs)
5718+ getattr(this, info["reorder" ]["method"])(**kwargs)
57055719
57065720 # checks are all done, start updating parameters
5707- # combined_key_arrays has the final key arrays after adding.
5708- combined_key_arrays = {}
5709- # order_dict has info about how to sort each axis. Initialize to None
5710- # for axes that are not added along (so do not need sorting)
5711- order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
5712- for axis in axes:
5713- if len(other_inds_use[axis]) > 0:
5714- combined_key_arrays[axis] = np.concatenate(
5721+ for axis, info in axis_info.items():
5722+ if len(info["other_inds_use"]) > 0:
5723+ # combined_key_arrays has the final key arrays after adding.
5724+ info["combined_key_arrays"] = np.concatenate(
57155725 (
5716- axis_key_arrays[axis ]["this"],
5717- axis_key_arrays[axis ]["other"][other_inds_use[axis ]],
5726+ info["key_arrays" ]["this"],
5727+ info["key_arrays" ]["other"][info["other_inds_use" ]],
57185728 )
57195729 )
57205730 if axis == "Npols":
5721- order_dict[axis] = np.argsort(np.abs(combined_key_arrays[axis]))
5731+ # order has info about how to sort each axis.
5732+ info["order"] = np.argsort(np.abs(info["combined_key_arrays"]))
57225733 elif axis == "Nfreqs" and (
57235734 np.any(np.diff(this.freq_array) < 0)
57245735 or np.any(np.diff(other.freq_array) < 0)
57255736 ):
57265737 # deal with the possibility of spws with channels in
57275738 # descending order.
5728- order_dict[axis ] = utils.frequency._add_freq_order(
5739+ info["order" ] = utils.frequency._add_freq_order(
57295740 np.concatenate(
57305741 (
57315742 this.flex_spw_id_array,
5732- other.flex_spw_id_array[other_inds_use[axis ]],
5743+ other.flex_spw_id_array[info["other_inds_use" ]],
57335744 )
57345745 ),
57355746 np.concatenate(
5736- (this.freq_array, other.freq_array[other_inds_use[axis ]])
5747+ (this.freq_array, other.freq_array[info["other_inds_use" ]])
57375748 ),
57385749 )
57395750 else:
5740- order_dict[axis ] = np.argsort(combined_key_arrays[axis ])
5751+ info["order" ] = np.argsort(info["combined_key_arrays" ])
57415752
57425753 # first handle parameters with a single named axis
57435754 this._axis_add_helper(
5744- other, axis, other_inds_use[axis ], order_dict[axis ]
5755+ other, axis, info["other_inds_use" ], info["order" ]
57455756 )
57465757
57475758 # then pad out parameters with multiple axes
5748- this._axis_pad_helper(axis, len(other_inds_use[axis ]))
5759+ this._axis_pad_helper(axis, len(info["other_inds_use" ]))
57495760 else:
57505761 # no add along this axis, so it's the same as what's already on this
5751- combined_key_arrays[axis ] = axis_key_arrays[axis ]["this"]
5762+ info["combined_key_arrays" ] = info["key_arrays" ]["this"]
57525763
57535764 # Now fill in multidimensional parameters
5754- # t2o_dict has the mapping of where arrays on other get mapped into
5765+ # t2o has the mapping of where arrays on other get mapped into
57555766 # this after padding
5756- t2o_dict = {}
5757- for axis, inds_dict in axis_key_arrays.items():
5758- t2o_dict[axis] = np.nonzero(
5759- np.isin(combined_key_arrays[axis], inds_dict["other"])
5767+ for _, info in axis_info.items():
5768+ info["t2o"] = np.nonzero(
5769+ np.isin(info["combined_key_arrays"], info["key_arrays"]["other"])
57605770 )[0]
57615771
5762- this._fill_multi_helper(other, t2o_dict, order_dict)
5772+ this._fill_multi_helper(
5773+ other,
5774+ {axis: info["t2o"] for axis, info in axis_info.items()},
5775+ {axis: info["order"] for axis, info in axis_info.items()},
5776+ )
57635777
5764- if len(other_inds_use ["Nfreqs"]) > 0:
5778+ if len(axis_info ["Nfreqs"]["other_inds_use "]) > 0:
57655779 # We want to preserve per-spw information based on first appearance
57665780 # in the concatenated array.
57675781 unique_index = np.sort(
@@ -5815,7 +5829,7 @@ def __add__(
58155829 )
58165830
58175831 # Reset blt_order if blt axis was added to
5818- if order_dict ["Nblts"] is not None:
5832+ if axis_info ["Nblts"]["order "] is not None:
58195833 this.blt_order = ("time", "baseline")
58205834
58215835 this.set_rectangularity(force=True)
0 commit comments