@@ -67,6 +67,7 @@ def flt_ind_str_arr(
6767 -------
6868 np.ndarray of str
6969 String array that combines the float and integer values, useful for matching.
70+
7071 """
7172 prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int)
7273 prec_int = 8
@@ -99,6 +100,7 @@ def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray:
99100 -------
100101 f_order : np.ndarray of int
101102 index array giving the sort order.
103+
102104 """
103105 spws = np.unique(spw_id)
104106 f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)])
@@ -126,7 +128,7 @@ def _axis_add_helper(
126128 final_order: IntArray | None = None,
127129):
128130 """
129- Combine UVData objects along an axis.
131+ Combine UVParameter objects with a single axis along an axis.
130132
131133 Parameters
132134 ----------
@@ -140,6 +142,7 @@ def _axis_add_helper(
140142 Indices into the other object along this axis to include.
141143 final_order : np.ndarray of int
142144 Final ordering array giving the sort order after concatenation.
145+
143146 """
144147 update_params = this._get_param_axis(axis_name, single_named_axis=True)
145148 other_form_dict = {axis_name: other_inds}
@@ -158,9 +161,90 @@ def _axis_add_helper(
158161 setattr(this, param, new_array)
159162
160163
164+ def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
165+ """
166+ Pad out UVParameter objects with multiple dimensions along an axis.
167+
168+ Parameters
169+ ----------
170+ this : UVData
171+ The left UVData object in the add.
172+ axis_name : str
173+ The axis name (e.g. "Nblts", "Npols").
174+ add_len : int
175+ The extra length to be padded on for this axis.
176+
177+ """
178+ update_params = this._get_param_axis(axis_name)
179+ multi_axis_params = this._get_multi_axis_params()
180+ for param, axis_list in update_params.items():
181+ if param not in multi_axis_params:
182+ continue
183+ this_param_shape = getattr(this, param).shape
184+ this_param_type = getattr(this, "_" + param).expected_type
185+ bool_type = this_param_type is bool or bool in this_param_type
186+ pad_shape = list(this_param_shape)
187+ for ax in axis_list:
188+ pad_shape[ax] = add_len
189+ if bool_type:
190+ pad_array = np.ones(tuple(pad_shape), dtype=bool)
191+ else:
192+ pad_array = np.zeros(tuple(pad_shape))
193+ new_array = np.concatenate([getattr(this, param), pad_array], axis=ax)
194+ if bool_type:
195+ new_array = new_array.astype(np.bool_)
196+ setattr(this, param, new_array)
197+
198+
199+ def _fill_multi_helper(
200+ this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
201+ ):
202+ """
203+ Fill UVParameter objects with multiple dimensions from the right side object.
204+
205+ Parameters
206+ ----------
207+ this : UVData
208+ The left UVData object in the add.
209+ other : UVData
210+ The right UVData object in the add.
211+ t2o_dict : dict
212+ dict giving the indices in the left object to be filled from the right
213+ object for each axis (keys are axes, values are index arrays).
214+ sort_axes : list of str
215+ The axes that need to be sorted along.
216+ order_dict : dict
217+ dict giving the final sort indices for each axis (keys are axes, values
218+ are index arrays for sorting).
219+
220+ """
221+ multi_axis_params = this._get_multi_axis_params()
222+ for param in multi_axis_params:
223+ form = getattr(this, "_" + param).form
224+ index_list = []
225+ for axis in form:
226+ index_list.append(t2o_dict[axis])
227+ new_arr = getattr(this, param)
228+ new_arr[np.ix_(*index_list)] = getattr(other, param)
229+ setattr(this, param, new_arr)
230+
231+ # Fix ordering
232+ for axis_ind, axis in enumerate(form):
233+ if axis in sort_axes:
234+ unique_order_diffs = np.unique(np.diff(order_dict[axis]))
235+ if np.array_equal(unique_order_diffs, np.array([1])):
236+ # everything is already in order
237+ continue
238+ setattr(
239+ this,
240+ param,
241+ np.take(getattr(this, param), order_dict[axis], axis=axis_ind),
242+ )
243+
244+
161245def _axis_fast_concat_helper(this: UVData, other: UVData, axis_name: str):
162246 """
163- Concatenate UVData objects along an axis assuming no overlap.
247+ Concatenate UVParameter objects along an axis assuming no overlap.
164248
165249 Parameters
166250 ----------
@@ -5797,7 +5881,7 @@ def __add__(
57975881 # Pad out self to accommodate new data
57985882 new_axis_inds = {}
57995883 order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
5800- for axis_ind, axis in enumerate( axes) :
5884+ for axis in axes:
58015885 if len(new_inds[axis]) > 0:
58025886 new_axis_inds[axis] = np.concatenate(
58035887 (axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
@@ -5824,24 +5908,27 @@ def __add__(
58245908 else:
58255909 order_dict[axis] = np.argsort(new_axis_inds[axis])
58265910
5911+ # first handle parameters with a single axis
58275912 _axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
58285913
5829- if not self.metadata_only:
5830- pad_shape = list(this.data_array.shape)
5831- pad_shape[axis_ind] = len(new_inds[axis])
5832- zero_pad = np.zeros(tuple(pad_shape))
5833- this.data_array = np.concatenate(
5834- [this.data_array, zero_pad], axis=axis_ind
5835- )
5836- this.nsample_array = np.concatenate(
5837- [this.nsample_array, zero_pad], axis=axis_ind
5838- )
5839- this.flag_array = np.concatenate(
5840- [this.flag_array, 1 - zero_pad], axis=axis_ind
5841- ).astype(np.bool_)
5914+ # then pad out parameters with multiple axes
5915+ _axis_pad_helper(this, axis, len(new_inds[axis]))
58425916 else:
58435917 new_axis_inds[axis] = axis_vals[axis]["this"]
58445918
5919+ # Now fill in multidimensional arrays
5920+ t2o_dict = {}
5921+ for axis, inds_dict in axis_vals.items():
5922+ t2o_dict[axis] = np.nonzero(
5923+ np.isin(new_axis_inds[axis], inds_dict["other"])
5924+ )[0]
5925+
5926+ sort_axes = []
5927+ for axis in axes:
5928+ if len(new_inds[axis]) > 0:
5929+ sort_axes.append(axis)
5930+ _fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5931+
58455932 if len(new_inds["Nfreqs"]) > 0:
58465933 # We want to preserve per-spw information based on first appearance
58475934 # in the concatenated array.
@@ -5856,38 +5943,6 @@ def __add__(
58565943 [this_flexpol_dict[key] for key in this.spw_array]
58575944 )
58585945
5859- # Now populate the data
5860- t2o_dict = {}
5861- for axis, inds_dict in axis_vals.items():
5862- t2o_dict[axis] = np.nonzero(
5863- np.isin(new_axis_inds[axis], inds_dict["other"])
5864- )[0]
5865-
5866- if not self.metadata_only:
5867- this.data_array[
5868- np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5869- ] = other.data_array
5870- this.nsample_array[
5871- np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5872- ] = other.nsample_array
5873- this.flag_array[
5874- np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5875- ] = other.flag_array
5876-
5877- # Fix ordering
5878- for axis_ind, axis in enumerate(axes):
5879- for name, param in zip(
5880- this._data_params, this.data_like_parameters, strict=True
5881- ):
5882- if len(new_inds[axis]) > 0:
5883- unique_order_diffs = np.unique(np.diff(order_dict[axis]))
5884- if np.array_equal(unique_order_diffs, np.array([1])):
5885- # everything is already in order
5886- continue
5887- setattr(
5888- this, name, np.take(param, order_dict[axis], axis=axis_ind)
5889- )
5890-
58915946 # Update N parameters (e.g. Npols)
58925947 this.Ntimes = len(np.unique(this.time_array))
58935948 this.Nbls = len(np.unique(this.baseline_array))
0 commit comments