4040)
4141
4242
43+ def flt_ind_str_arr(*, fltarr, intarr, flt_tols, flt_first=True):
44+ """Create a string array built from float and integer arrays for matching."""
45+ prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int)
46+ prec_int = 8
47+ flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr]
48+ int_str_list = [str(intv).zfill(prec_int) for intv in intarr]
49+ if flt_first:
50+ zipped = zip(flt_str_list, int_str_list, strict=True)
51+ else:
52+ zipped = zip(int_str_list, flt_str_list, strict=True)
53+ return np.array(["_".join(zpval) for zpval in zipped])
54+
55+
4356def _axis_add_helper(this, other, axis_name: str, other_inds, final_order=None):
4457 update_params = this._get_param_axis(axis_name, single_named_axis=True)
4558 other_form_dict = {axis_name: other_inds}
@@ -5378,15 +5391,20 @@ def fix_phase(self, *, use_ant_pos=True):
53785391
53795392 def blt_str_arr(self):
53805393 """Create a string array with baseline and time info for matching purposes."""
5381- prec_t = -2 * np.floor(np.log10(self._time_array.tols[-1])).astype(int)
5382- prec_b = 8
5383- return np.array(
5384- [
5385- "_".join(
5386- ["{1:.{0}f}".format(prec_t, blt[0]), str(blt[1]).zfill(prec_b)]
5387- )
5388- for blt in zip(self.time_array, self.baseline_array, strict=True)
5389- ]
5394+ return flt_ind_str_arr(
5395+ fltarr=self.time_array,
5396+ intarr=self.baseline_array,
5397+ flt_tols=self._time_array.tols,
5398+ flt_first=True,
5399+ )
5400+
5401+ def spw_freq_str_arr(self):
5402+ """Create a string array with spw and freq info for matching purposes."""
5403+ return flt_ind_str_arr(
5404+ fltarr=self.freq_array,
5405+ intarr=self.flex_spw_id_array,
5406+ flt_tols=self._freq_array.tols,
5407+ flt_first=False,
53905408 )
53915409
53925410 def flexpol_dict(self):
@@ -5485,9 +5503,10 @@ def __add__(
54855503 axis_params_check = {}
54865504 axis_overlap_params = {
54875505 "Nblts": ["time_array", "baseline_array"],
5488- "Nfreqs": ["freq_array"],
5506+ "Nfreqs": ["freq_array", "flex_spw_id_array" ],
54895507 "Npols": ["polarization_array"],
54905508 }
5509+ axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
54915510 axis_dict = {}
54925511 for axis in axes:
54935512 axis_dict[axis] = this._get_param_axis(axis)
@@ -5503,56 +5522,28 @@ def __add__(
55035522 # the blt axis because we need a combo of time and baseline.
55045523 axis_vals = {}
55055524 for axis, overlap_params in axis_overlap_params.items():
5506- if axis == "Nblts" :
5507- # Create combined arrays for convenience
5508- this_blts = this.blt_str_arr()
5509- other_blts = other.blt_str_arr()
5510- axis_vals[axis] = {"this": this_blts, "other": other_blts }
5525+ if len(overlap_params) > 1 :
5526+ axis_vals[axis] = {
5527+ " this": getattr(this, axis_combined_func[axis])(),
5528+ " other": getattr(other, axis_combined_func[axis])(),
5529+ }
55115530 else:
55125531 axis_vals[axis] = {
55135532 "this": getattr(this, overlap_params[0]),
55145533 "other": getattr(other, overlap_params[0]),
55155534 }
5535+
55165536 # Check we don't have overlapping data
55175537 axis_inds = {}
55185538 for axis, val_arr in axis_vals.items():
5519- if axis == "Nfreqs":
5520- # This is more complicated because we are allowed to have channels
5521- # with the same frequency *if* they belong to different spectral
5522- # windows (one real-life example: you might want to preserve
5523- # guard bands in the correlator, which can have overlaping RF
5524- # frequency channels)
5525- this_freq_ind = np.array([], dtype=np.int64)
5526- other_freq_ind = np.array([], dtype=np.int64)
5527- both_freq_ind = np.array([], dtype=float)
5528- both_spw = np.intersect1d(this.spw_array, other.spw_array)
5529- for idx in both_spw:
5530- this_mask = np.where(this.flex_spw_id_array == idx)[0]
5531- other_mask = np.where(other.flex_spw_id_array == idx)[0]
5532- both_spw_freq, this_spw_ind, other_spw_ind = np.intersect1d(
5533- this.freq_array[this_mask],
5534- other.freq_array[other_mask],
5535- return_indices=True,
5536- )
5537- this_freq_ind = np.append(this_freq_ind, this_mask[this_spw_ind])
5538- other_freq_ind = np.append(
5539- other_freq_ind, other_mask[other_spw_ind]
5540- )
5541- both_freq_ind = np.append(both_freq_ind, both_spw_freq)
5542- axis_inds["Nfreqs"] = {
5543- "this": this_freq_ind,
5544- "other": other_freq_ind,
5545- "both": both_freq_ind,
5546- }
5547- else:
5548- both_inds, this_inds, other_inds = np.intersect1d(
5549- val_arr["this"], val_arr["other"], return_indices=True
5550- )
5551- axis_inds[axis] = {
5552- "this": this_inds,
5553- "other": other_inds,
5554- "both": both_inds,
5555- }
5539+ both_inds, this_inds, other_inds = np.intersect1d(
5540+ val_arr["this"], val_arr["other"], return_indices=True
5541+ )
5542+ axis_inds[axis] = {
5543+ "this": this_inds,
5544+ "other": other_inds,
5545+ "both": both_inds,
5546+ }
55565547
55575548 history_update_string = ""
55585549 if not self.metadata_only and np.all(
@@ -5603,12 +5594,11 @@ def __add__(
56035594 }
56045595 # find the indices in "other" but not in "this"
56055596 for axis in axes:
5606- if axis != "Nfreqs":
5607- temp = np.nonzero(
5608- ~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"])
5609- )[0]
5610- else:
5611- # more complicated because of spws
5597+ if axis == "Nfreqs" and (
5598+ this.flex_spw_polarization_array is not None
5599+ or other.flex_spw_polarization_array is not None
5600+ ):
5601+ # special checking for flex_spw
56125602 if (this.flex_spw_polarization_array is None) != (
56135603 other.flex_spw_polarization_array is None
56145604 ):
@@ -5633,27 +5623,18 @@ def __add__(
56335623 except KeyError:
56345624 this_flexpol_dict[key] = other_flexpol_dict[key]
56355625
5636- other_mask = np.ones_like(other.flex_spw_id_array, dtype=bool)
5637- for idx in np.intersect1d(this.spw_array, other.spw_array):
5638- other_mask[other.flex_spw_id_array == idx] = np.isin(
5639- other.freq_array[other.flex_spw_id_array == idx],
5640- this.freq_array[this.flex_spw_id_array == idx],
5641- invert=True,
5642- )
5643- temp = np.where(other_mask)[0]
5626+ temp = np.nonzero(
5627+ ~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"])
5628+ )[0]
56445629 if len(temp) > 0:
56455630 new_inds[axis] = temp
56465631 # add params associated with the other axes to compatibility_params
56475632 for axis2 in axes:
56485633 if axis2 != axis:
56495634 compatibility_params.extend(axis_params_check[axis2])
5650- if axis == "Nblts":
5651- new_blts = other_blts[temp]
56525635 additions.append(axis_descriptions[axis])
56535636 else:
56545637 new_inds[axis] = []
5655- if axis == "Nblts":
5656- new_blts = ([], [])
56575638
56585639 # Actually check compatibility parameters
56595640 for cp in compatibility_params:
@@ -5701,6 +5682,7 @@ def __add__(
57015682 order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
57025683 for axis, ind_dict in axis_inds.items():
57035684 if len(ind_dict["this"]) != 0:
5685+ # there is some overlap, so sorting matters
57045686 this_argsort = np.argsort(ind_dict["this"])
57055687 other_argsort = np.argsort(ind_dict["other"])
57065688
@@ -5714,16 +5696,18 @@ def __add__(
57145696 getattr(this, reorder_method[axis]["method"])(**kwargs)
57155697
57165698 # Pad out self to accommodate new data
5699+ new_axis_inds = {}
57175700 for axis_ind, axis in enumerate(axes):
57185701 if len(new_inds[axis]) > 0:
5719- if axis == "Nblts":
5720- this_blts = np.concatenate((this_blts, new_blts))
5721- order_dict["Nblts"] = np.argsort(this_blts)
5722- order_use = order_dict["Nblts"]
5702+ new_axis_inds[axis] = np.concatenate(
5703+ (axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
5704+ )
5705+ if axis == "Npols":
5706+ order_dict[axis] = np.argsort(np.abs(new_axis_inds[axis]))
57235707 else:
5724- order_use = None
5708+ order_dict[axis] = np.argsort(new_axis_inds[axis])
57255709
5726- _axis_add_helper(this, other, axis, new_inds[axis], order_use )
5710+ _axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis] )
57275711
57285712 if not self.metadata_only:
57295713 pad_shape = list(this.data_array.shape)
@@ -5738,6 +5722,8 @@ def __add__(
57385722 this.flag_array = np.concatenate(
57395723 [this.flag_array, 1 - zero_pad], axis=axis_ind
57405724 ).astype(np.bool_)
5725+ else:
5726+ new_axis_inds[axis] = axis_vals[axis]["this"]
57415727
57425728 if len(new_inds["Nfreqs"]) > 0:
57435729 # We want to preserve per-spw information based on first appearance
@@ -5752,51 +5738,24 @@ def __add__(
57525738 this.flex_spw_polarization_array = np.array(
57535739 [this_flexpol_dict[key] for key in this.spw_array]
57545740 )
5755- # Need to sort out the order of the individual windows first.
5756- order_dict["Nfreqs"] = np.concatenate(
5757- [
5758- np.where(this.flex_spw_id_array == idx)[0]
5759- for idx in sorted(this.spw_array)
5760- ]
5761- )
5762-
5763- # With spectral windows sorted, check and see if channels within
5764- # windows need sorting. If they are ordered in ascending or descending
5765- # fashion, leave them be. If not, sort in ascending order
5766- for idx in this.spw_array:
5767- select_mask = this.flex_spw_id_array[order_dict["Nfreqs"]] == idx
5768- check_freqs = this.freq_array[order_dict["Nfreqs"][select_mask]]
5769- if (not np.all(check_freqs[1:] > check_freqs[:-1])) and (
5770- not np.all(check_freqs[1:] < check_freqs[:-1])
5771- ):
5772- subsort_order = order_dict["Nfreqs"][select_mask]
5773- order_dict["Nfreqs"][select_mask] = subsort_order[
5774- np.argsort(check_freqs)
5775- ]
5776-
5777- if len(new_inds["Npols"]) > 0:
5778- order_dict["Npols"] = np.argsort(np.abs(this.polarization_array))
57795741
57805742 # Now populate the data
5781- pol_t2o = np.nonzero(
5782- np.isin(this.polarization_array, other.polarization_array)
5783- )[0]
5784- this_freqs = this.freq_array
5785- other_freqs = other.freq_array
5786-
5787- freq_t2o = np.zeros(this_freqs.shape, dtype=bool)
5788- for spw_id in set(this.spw_array).intersection(other.spw_array):
5789- mask = this.flex_spw_id_array == spw_id
5790- freq_t2o[mask] |= np.isin(
5791- this_freqs[mask], other_freqs[other.flex_spw_id_array == spw_id]
5792- )
5793- freq_t2o = np.nonzero(freq_t2o)[0]
5794- blt_t2o = np.nonzero(np.isin(this_blts, other_blts))[0]
5743+ t2o_dict = {}
5744+ for axis, inds_dict in axis_vals.items():
5745+ t2o_dict[axis] = np.nonzero(
5746+ np.isin(new_axis_inds[axis], inds_dict["other"])
5747+ )[0]
57955748
57965749 if not self.metadata_only:
5797- this.data_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.data_array
5798- this.nsample_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.nsample_array
5799- this.flag_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.flag_array
5750+ this.data_array[
5751+ np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5752+ ] = other.data_array
5753+ this.nsample_array[
5754+ np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5755+ ] = other.nsample_array
5756+ this.flag_array[
5757+ np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5758+ ] = other.flag_array
58005759
58015760 # Fix ordering
58025761 for axis_ind, axis in enumerate(axes):
@@ -5812,16 +5771,6 @@ def __add__(
58125771 this, name, np.take(param, order_dict[axis], axis=axis_ind)
58135772 )
58145773
5815- # reorder freq, pol axes but not blt axis because that was already done.
5816- for axis in axes[1:]:
5817- params_to_update = axis_params_check[axis] + [
5818- "_" + param for param in axis_overlap_params[axis]
5819- ]
5820- if len(new_inds[axis]) > 0:
5821- for param in params_to_update:
5822- this_param = getattr(this, param)
5823- this_param.value = this_param.value[order_dict[axis]]
5824-
58255774 # Update N parameters (e.g. Npols)
58265775 this.Ntimes = len(np.unique(this.time_array))
58275776 this.Nbls = len(np.unique(this.baseline_array))
@@ -5862,8 +5811,8 @@ def __add__(
58625811 )
58635812
58645813 # Reset blt_order if blt axis was added to and it is set
5865- if len(blt_t2o ) > 0:
5866- this.blt_order = None
5814+ if len(t2o_dict["Nblts"] ) > 0:
5815+ this.blt_order = ("time", "baseline")
58675816
58685817 this.set_rectangularity(force=True)
58695818
0 commit comments