Skip to content

Commit ab19f67

Browse files
committed
handle multidimensional arrays programmatically
1 parent 3cc2172 commit ab19f67

2 files changed

Lines changed: 120 additions & 48 deletions

File tree

src/pyuvdata/uvbase.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,7 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
771771
The keys are UVParameter names that have an axis with axis_name
772772
(axis_name appears in their form). The values are a list of the axis
773773
indices where axis_name appears in their form.
774+
774775
"""
775776
ret_dict = {}
776777
for param in self:
@@ -795,6 +796,22 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
795796
ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0]
796797
return ret_dict
797798

799+
def _get_multi_axis_params(self) -> list[str]:
800+
"""Get a list of all multidimensional parameters."""
801+
ret_list = []
802+
for param in self:
803+
# For each attribute, if the value is None, then bail, otherwise
804+
# attempt to figure out along which axis ind_arr will apply.
805+
806+
attr = getattr(self, param)
807+
if (
808+
attr.value is not None
809+
and isinstance(attr.form, tuple)
810+
and sum([isinstance(entry, str) for entry in attr.form]) > 1
811+
):
812+
ret_list.append(attr.name)
813+
return ret_list
814+
798815
def _select_along_param_axis(self, param_dict: dict):
799816
"""
800817
Downselect values along a parameterized axis.

src/pyuvdata/uvdata/uvdata.py

Lines changed: 103 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def flt_ind_str_arr(
6767
-------
6868
np.ndarray of str
6969
String array that combines the float and integer values, useful for matching.
70+
7071
"""
7172
prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int)
7273
prec_int = 8
@@ -99,6 +100,7 @@ def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray:
99100
-------
100101
f_order : np.ndarray of int
101102
index array giving the sort order.
103+
102104
"""
103105
spws = np.unique(spw_id)
104106
f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)])
@@ -126,7 +128,7 @@ def _axis_add_helper(
126128
final_order: IntArray | None = None,
127129
):
128130
"""
129-
Combine UVData objects along an axis.
131+
Combine UVParameter objects with a single axis along an axis.
130132

131133
Parameters
132134
----------
@@ -140,6 +142,7 @@ def _axis_add_helper(
140142
Indices into the other object along this axis to include.
141143
final_order : np.ndarray of int
142144
Final ordering array giving the sort order after concatenation.
145+
143146
"""
144147
update_params = this._get_param_axis(axis_name, single_named_axis=True)
145148
other_form_dict = {axis_name: other_inds}
@@ -158,9 +161,90 @@ def _axis_add_helper(
158161
setattr(this, param, new_array)
159162

160163

164+
def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
165+
"""
166+
Pad out UVParameter objects with multiple dimensions along an axis.
167+
168+
Parameters
169+
----------
170+
this : UVData
171+
The left UVData object in the add.
172+
axis_name : str
173+
The axis name (e.g. "Nblts", "Npols").
174+
add_len : int
175+
The extra length to be padded on for this axis.
176+
177+
"""
178+
update_params = this._get_param_axis(axis_name)
179+
multi_axis_params = this._get_multi_axis_params()
180+
for param, axis_list in update_params.items():
181+
if param not in multi_axis_params:
182+
continue
183+
this_param_shape = getattr(this, param).shape
184+
this_param_type = getattr(this, "_" + param).expected_type
185+
bool_type = this_param_type is bool or bool in this_param_type
186+
pad_shape = list(this_param_shape)
187+
for ax in axis_list:
188+
pad_shape[ax] = add_len
189+
if bool_type:
190+
pad_array = np.ones(tuple(pad_shape), dtype=bool)
191+
else:
192+
pad_array = np.zeros(tuple(pad_shape))
193+
new_array = np.concatenate([getattr(this, param), pad_array], axis=ax)
194+
if bool_type:
195+
new_array = new_array.astype(np.bool_)
196+
setattr(this, param, new_array)
197+
198+
199+
def _fill_multi_helper(
200+
this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
201+
):
202+
"""
203+
Fill UVParameter objects with multiple dimensions from the right side object.
204+
205+
Parameters
206+
----------
207+
this : UVData
208+
The left UVData object in the add.
209+
other : UVData
210+
The right UVData object in the add.
211+
t2o_dict : dict
212+
dict giving the indices in the left object to be filled from the right
213+
object for each axis (keys are axes, values are index arrays).
214+
sort_axes : list of str
215+
The axes that need to be sorted along.
216+
order_dict : dict
217+
dict giving the final sort indices for each axis (keys are axes, values
218+
are index arrays for sorting).
219+
220+
"""
221+
multi_axis_params = this._get_multi_axis_params()
222+
for param in multi_axis_params:
223+
form = getattr(this, "_" + param).form
224+
index_list = []
225+
for axis in form:
226+
index_list.append(t2o_dict[axis])
227+
new_arr = getattr(this, param)
228+
new_arr[np.ix_(*index_list)] = getattr(other, param)
229+
setattr(this, param, new_arr)
230+
231+
# Fix ordering
232+
for axis_ind, axis in enumerate(form):
233+
if axis in sort_axes:
234+
unique_order_diffs = np.unique(np.diff(order_dict[axis]))
235+
if np.array_equal(unique_order_diffs, np.array([1])):
236+
# everything is already in order
237+
continue
238+
setattr(
239+
this,
240+
param,
241+
np.take(getattr(this, param), order_dict[axis], axis=axis_ind),
242+
)
243+
244+
161245
def _axis_fast_concat_helper(this: UVData, other: UVData, axis_name: str):
162246
"""
163-
Concatenate UVData objects along an axis assuming no overlap.
247+
Concatenate UVParameter objects along an axis assuming no overlap.
164248

165249
Parameters
166250
----------
@@ -5797,7 +5881,7 @@ def __add__(
57975881
# Pad out self to accommodate new data
57985882
new_axis_inds = {}
57995883
order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
5800-
for axis_ind, axis in enumerate(axes):
5884+
for axis in axes:
58015885
if len(new_inds[axis]) > 0:
58025886
new_axis_inds[axis] = np.concatenate(
58035887
(axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
@@ -5824,24 +5908,27 @@ def __add__(
58245908
else:
58255909
order_dict[axis] = np.argsort(new_axis_inds[axis])
58265910

5911+
# first handle parameters with a single axis
58275912
_axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
58285913

5829-
if not self.metadata_only:
5830-
pad_shape = list(this.data_array.shape)
5831-
pad_shape[axis_ind] = len(new_inds[axis])
5832-
zero_pad = np.zeros(tuple(pad_shape))
5833-
this.data_array = np.concatenate(
5834-
[this.data_array, zero_pad], axis=axis_ind
5835-
)
5836-
this.nsample_array = np.concatenate(
5837-
[this.nsample_array, zero_pad], axis=axis_ind
5838-
)
5839-
this.flag_array = np.concatenate(
5840-
[this.flag_array, 1 - zero_pad], axis=axis_ind
5841-
).astype(np.bool_)
5914+
# then pad out parameters with multiple axes
5915+
_axis_pad_helper(this, axis, len(new_inds[axis]))
58425916
else:
58435917
new_axis_inds[axis] = axis_vals[axis]["this"]
58445918

5919+
# Now fill in multidimensional arrays
5920+
t2o_dict = {}
5921+
for axis, inds_dict in axis_vals.items():
5922+
t2o_dict[axis] = np.nonzero(
5923+
np.isin(new_axis_inds[axis], inds_dict["other"])
5924+
)[0]
5925+
5926+
sort_axes = []
5927+
for axis in axes:
5928+
if len(new_inds[axis]) > 0:
5929+
sort_axes.append(axis)
5930+
_fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5931+
58455932
if len(new_inds["Nfreqs"]) > 0:
58465933
# We want to preserve per-spw information based on first appearance
58475934
# in the concatenated array.
@@ -5856,38 +5943,6 @@ def __add__(
58565943
[this_flexpol_dict[key] for key in this.spw_array]
58575944
)
58585945

5859-
# Now populate the data
5860-
t2o_dict = {}
5861-
for axis, inds_dict in axis_vals.items():
5862-
t2o_dict[axis] = np.nonzero(
5863-
np.isin(new_axis_inds[axis], inds_dict["other"])
5864-
)[0]
5865-
5866-
if not self.metadata_only:
5867-
this.data_array[
5868-
np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5869-
] = other.data_array
5870-
this.nsample_array[
5871-
np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5872-
] = other.nsample_array
5873-
this.flag_array[
5874-
np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5875-
] = other.flag_array
5876-
5877-
# Fix ordering
5878-
for axis_ind, axis in enumerate(axes):
5879-
for name, param in zip(
5880-
this._data_params, this.data_like_parameters, strict=True
5881-
):
5882-
if len(new_inds[axis]) > 0:
5883-
unique_order_diffs = np.unique(np.diff(order_dict[axis]))
5884-
if np.array_equal(unique_order_diffs, np.array([1])):
5885-
# everything is already in order
5886-
continue
5887-
setattr(
5888-
this, name, np.take(param, order_dict[axis], axis=axis_ind)
5889-
)
5890-
58915946
# Update N parameters (e.g. Npols)
58925947
this.Ntimes = len(np.unique(this.time_array))
58935948
this.Nbls = len(np.unique(this.baseline_array))

0 commit comments

Comments
 (0)