@@ -67,6 +67,7 @@ def flt_ind_str_arr(
6767 -------
6868 np.ndarray of str
6969 String array that combines the float and integer values, useful for matching.
70+
7071 """
7172 prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int)
7273 prec_int = 8
@@ -99,6 +100,7 @@ def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray:
99100 -------
100101 f_order : np.ndarray of int
101102 index array giving the sort order.
103+
102104 """
103105 spws = np.unique(spw_id)
104106 f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)])
@@ -126,7 +128,7 @@ def _axis_add_helper(
126128 final_order: IntArray | None = None,
127129):
128130 """
129- Combine UVData objects along an axis.
131+ Combine UVParameter objects with a single axis along an axis.
130132
131133 Parameters
132134 ----------
@@ -140,6 +142,7 @@ def _axis_add_helper(
140142 Indices into the other object along this axis to include.
141143 final_order : np.ndarray of int
142144 Final ordering array giving the sort order after concatenation.
145+
143146 """
144147 update_params = this._get_param_axis(axis_name, single_named_axis=True)
145148 other_form_dict = {axis_name: other_inds}
@@ -158,9 +161,90 @@ def _axis_add_helper(
158161 setattr(this, param, new_array)
159162
160163
164+ def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
165+ """
166+ Pad out UVParameter objects with multiple dimensions along an axis.
167+
168+ Parameters
169+ ----------
170+ this : UVData
171+ The left UVData object in the add.
172+ axis_name : str
173+ The axis name (e.g. "Nblts", "Npols").
174+ add_len : int
175+ The extra length to be padded on for this axis.
176+
177+ """
178+ update_params = this._get_param_axis(axis_name)
179+ multi_axis_params = this._get_multi_axis_params()
180+ for param, axis_list in update_params.items():
181+ if param not in multi_axis_params:
182+ continue
183+ this_param_shape = getattr(this, param).shape
184+ this_param_type = getattr(this, "_" + param).expected_type
185+ bool_type = this_param_type is bool or bool in this_param_type
186+ pad_shape = list(this_param_shape)
187+ for ax in axis_list:
188+ pad_shape[ax] = add_len
189+ if bool_type:
190+ pad_array = np.ones(tuple(pad_shape), dtype=bool)
191+ else:
192+ pad_array = np.zeros(tuple(pad_shape))
193+ new_array = np.concatenate([getattr(this, param), pad_array], axis=ax)
194+ if bool_type:
195+ new_array = new_array.astype(np.bool_)
196+ setattr(this, param, new_array)
197+
198+
199+ def _fill_multi_helper(
200+ this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
201+ ):
202+ """
203+ Fill UVParameter objects with multiple dimensions from the right side object.
204+
205+ Parameters
206+ ----------
207+ this : UVData
208+ The left UVData object in the add.
209+ other : UVData
210+ The right UVData object in the add.
211+ t2o_dict : dict
212+ dict giving the indices in the left object to be filled from the right
213+ object for each axis (keys are axes, values are index arrays).
214+ sort_axes : list of str
215+ The axes that need to be sorted along.
216+ order_dict : dict
217+ dict giving the final sort indices for each axis (keys are axes, values
218+ are index arrays for sorting).
219+
220+ """
221+ multi_axis_params = this._get_multi_axis_params()
222+ for param in multi_axis_params:
223+ form = getattr(this, "_" + param).form
224+ index_list = []
225+ for axis in form:
226+ index_list.append(t2o_dict[axis])
227+ new_arr = getattr(this, param)
228+ new_arr[np.ix_(*index_list)] = getattr(other, param)
229+ setattr(this, param, new_arr)
230+
231+ # Fix ordering
232+ for axis_ind, axis in enumerate(form):
233+ if axis in sort_axes:
234+ unique_order_diffs = np.unique(np.diff(order_dict[axis]))
235+ if np.array_equal(unique_order_diffs, np.array([1])):
236+ # everything is already in order
237+ continue
238+ setattr(
239+ this,
240+ param,
241+ np.take(getattr(this, param), order_dict[axis], axis=axis_ind),
242+ )
243+
244+
161245def _axis_fast_concat_helper(this: UVData, other: UVData, axis_name: str):
162246 """
163- Concatenate UVData objects along an axis assuming no overlap.
247+ Concatenate UVParameter objects along an axis assuming no overlap.
164248
165249 Parameters
166250 ----------
@@ -5816,7 +5900,7 @@ def __add__(
58165900 # Pad out self to accommodate new data
58175901 new_axis_inds = {}
58185902 order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
5819- for axis_ind, axis in enumerate( axes) :
5903+ for axis in axes:
58205904 if len(new_inds[axis]) > 0:
58215905 new_axis_inds[axis] = np.concatenate(
58225906 (axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
@@ -5843,24 +5927,27 @@ def __add__(
58435927 else:
58445928 order_dict[axis] = np.argsort(new_axis_inds[axis])
58455929
5930+ # first handle parameters with a single axis
58465931 _axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
58475932
5848- if not self.metadata_only:
5849- pad_shape = list(this.data_array.shape)
5850- pad_shape[axis_ind] = len(new_inds[axis])
5851- zero_pad = np.zeros(tuple(pad_shape))
5852- this.data_array = np.concatenate(
5853- [this.data_array, zero_pad], axis=axis_ind
5854- )
5855- this.nsample_array = np.concatenate(
5856- [this.nsample_array, zero_pad], axis=axis_ind
5857- )
5858- this.flag_array = np.concatenate(
5859- [this.flag_array, 1 - zero_pad], axis=axis_ind
5860- ).astype(np.bool_)
5933+ # then pad out parameters with multiple axes
5934+ _axis_pad_helper(this, axis, len(new_inds[axis]))
58615935 else:
58625936 new_axis_inds[axis] = axis_vals[axis]["this"]
58635937
5938+ # Now fill in multidimensional arrays
5939+ t2o_dict = {}
5940+ for axis, inds_dict in axis_vals.items():
5941+ t2o_dict[axis] = np.nonzero(
5942+ np.isin(new_axis_inds[axis], inds_dict["other"])
5943+ )[0]
5944+
5945+ sort_axes = []
5946+ for axis in axes:
5947+ if len(new_inds[axis]) > 0:
5948+ sort_axes.append(axis)
5949+ _fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5950+
58645951 if len(new_inds["Nfreqs"]) > 0:
58655952 # We want to preserve per-spw information based on first appearance
58665953 # in the concatenated array.
@@ -5875,38 +5962,6 @@ def __add__(
58755962 [this_flexpol_dict[key] for key in this.spw_array]
58765963 )
58775964
5878- # Now populate the data
5879- t2o_dict = {}
5880- for axis, inds_dict in axis_vals.items():
5881- t2o_dict[axis] = np.nonzero(
5882- np.isin(new_axis_inds[axis], inds_dict["other"])
5883- )[0]
5884-
5885- if not self.metadata_only:
5886- this.data_array[
5887- np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5888- ] = other.data_array
5889- this.nsample_array[
5890- np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5891- ] = other.nsample_array
5892- this.flag_array[
5893- np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5894- ] = other.flag_array
5895-
5896- # Fix ordering
5897- for axis_ind, axis in enumerate(axes):
5898- for name, param in zip(
5899- this._data_params, this.data_like_parameters, strict=True
5900- ):
5901- if len(new_inds[axis]) > 0:
5902- unique_order_diffs = np.unique(np.diff(order_dict[axis]))
5903- if np.array_equal(unique_order_diffs, np.array([1])):
5904- # everything is already in order
5905- continue
5906- setattr(
5907- this, name, np.take(param, order_dict[axis], axis=axis_ind)
5908- )
5909-
59105965 # Update N parameters (e.g. Npols)
59115966 this.Ntimes = len(np.unique(this.time_array))
59125967 this.Nbls = len(np.unique(this.baseline_array))
0 commit comments