Skip to content

Commit 329be58

Browse files
committed
handle multidimensional arrays programmatically
1 parent ac4c230 commit 329be58

2 files changed

Lines changed: 120 additions & 48 deletions

File tree

src/pyuvdata/uvbase.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,7 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
771771
The keys are UVParameter names that have an axis with axis_name
772772
(axis_name appears in their form). The values are a list of the axis
773773
indices where axis_name appears in their form.
774+
774775
"""
775776
ret_dict = {}
776777
for param in self:
@@ -795,6 +796,22 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
795796
ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0]
796797
return ret_dict
797798

799+
def _get_multi_axis_params(self) -> list[str]:
800+
"""Get a list of all multidimensional parameters."""
801+
ret_list = []
802+
for param in self:
803+
# For each attribute, if the value is None, then bail, otherwise
804+
# attempt to figure out along which axis ind_arr will apply.
805+
806+
attr = getattr(self, param)
807+
if (
808+
attr.value is not None
809+
and isinstance(attr.form, tuple)
810+
and sum([isinstance(entry, str) for entry in attr.form]) > 1
811+
):
812+
ret_list.append(attr.name)
813+
return ret_list
814+
798815
def _select_along_param_axis(self, param_dict: dict):
799816
"""
800817
Downselect values along a parameterized axis.

src/pyuvdata/uvdata/uvdata.py

Lines changed: 103 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def flt_ind_str_arr(
6767
-------
6868
np.ndarray of str
6969
String array that combines the float and integer values, useful for matching.
70+
7071
"""
7172
prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int)
7273
prec_int = 8
@@ -99,6 +100,7 @@ def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray:
99100
-------
100101
f_order : np.ndarray of int
101102
index array giving the sort order.
103+
102104
"""
103105
spws = np.unique(spw_id)
104106
f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)])
@@ -126,7 +128,7 @@ def _axis_add_helper(
126128
final_order: IntArray | None = None,
127129
):
128130
"""
129-
Combine UVData objects along an axis.
131+
Combine UVParameter objects with a single axis along an axis.
130132

131133
Parameters
132134
----------
@@ -140,6 +142,7 @@ def _axis_add_helper(
140142
Indices into the other object along this axis to include.
141143
final_order : np.ndarray of int
142144
Final ordering array giving the sort order after concatenation.
145+
143146
"""
144147
update_params = this._get_param_axis(axis_name, single_named_axis=True)
145148
other_form_dict = {axis_name: other_inds}
@@ -158,9 +161,90 @@ def _axis_add_helper(
158161
setattr(this, param, new_array)
159162

160163

164+
def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
165+
"""
166+
Pad out UVParameter objects with multiple dimensions along an axis.
167+
168+
Parameters
169+
----------
170+
this : UVData
171+
The left UVData object in the add.
172+
axis_name : str
173+
The axis name (e.g. "Nblts", "Npols").
174+
add_len : int
175+
The extra length to be padded on for this axis.
176+
177+
"""
178+
update_params = this._get_param_axis(axis_name)
179+
multi_axis_params = this._get_multi_axis_params()
180+
for param, axis_list in update_params.items():
181+
if param not in multi_axis_params:
182+
continue
183+
this_param_shape = getattr(this, param).shape
184+
this_param_type = getattr(this, "_" + param).expected_type
185+
bool_type = this_param_type is bool or bool in this_param_type
186+
pad_shape = list(this_param_shape)
187+
for ax in axis_list:
188+
pad_shape[ax] = add_len
189+
if bool_type:
190+
pad_array = np.ones(tuple(pad_shape), dtype=bool)
191+
else:
192+
pad_array = np.zeros(tuple(pad_shape))
193+
new_array = np.concatenate([getattr(this, param), pad_array], axis=ax)
194+
if bool_type:
195+
new_array = new_array.astype(np.bool_)
196+
setattr(this, param, new_array)
197+
198+
199+
def _fill_multi_helper(
200+
this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
201+
):
202+
"""
203+
Fill UVParameter objects with multiple dimensions from the right side object.
204+
205+
Parameters
206+
----------
207+
this : UVData
208+
The left UVData object in the add.
209+
other : UVData
210+
The right UVData object in the add.
211+
t2o_dict : dict
212+
dict giving the indices in the left object to be filled from the right
213+
object for each axis (keys are axes, values are index arrays).
214+
sort_axes : list of str
215+
The axes that need to be sorted along.
216+
order_dict : dict
217+
dict giving the final sort indices for each axis (keys are axes, values
218+
are index arrays for sorting).
219+
220+
"""
221+
multi_axis_params = this._get_multi_axis_params()
222+
for param in multi_axis_params:
223+
form = getattr(this, "_" + param).form
224+
index_list = []
225+
for axis in form:
226+
index_list.append(t2o_dict[axis])
227+
new_arr = getattr(this, param)
228+
new_arr[np.ix_(*index_list)] = getattr(other, param)
229+
setattr(this, param, new_arr)
230+
231+
# Fix ordering
232+
for axis_ind, axis in enumerate(form):
233+
if axis in sort_axes:
234+
unique_order_diffs = np.unique(np.diff(order_dict[axis]))
235+
if np.array_equal(unique_order_diffs, np.array([1])):
236+
# everything is already in order
237+
continue
238+
setattr(
239+
this,
240+
param,
241+
np.take(getattr(this, param), order_dict[axis], axis=axis_ind),
242+
)
243+
244+
161245
def _axis_fast_concat_helper(this: UVData, other: UVData, axis_name: str):
162246
"""
163-
Concatenate UVData objects along an axis assuming no overlap.
247+
Concatenate UVParameter objects along an axis assuming no overlap.
164248

165249
Parameters
166250
----------
@@ -5816,7 +5900,7 @@ def __add__(
58165900
# Pad out self to accommodate new data
58175901
new_axis_inds = {}
58185902
order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
5819-
for axis_ind, axis in enumerate(axes):
5903+
for axis in axes:
58205904
if len(new_inds[axis]) > 0:
58215905
new_axis_inds[axis] = np.concatenate(
58225906
(axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
@@ -5843,24 +5927,27 @@ def __add__(
58435927
else:
58445928
order_dict[axis] = np.argsort(new_axis_inds[axis])
58455929

5930+
# first handle parameters with a single axis
58465931
_axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
58475932

5848-
if not self.metadata_only:
5849-
pad_shape = list(this.data_array.shape)
5850-
pad_shape[axis_ind] = len(new_inds[axis])
5851-
zero_pad = np.zeros(tuple(pad_shape))
5852-
this.data_array = np.concatenate(
5853-
[this.data_array, zero_pad], axis=axis_ind
5854-
)
5855-
this.nsample_array = np.concatenate(
5856-
[this.nsample_array, zero_pad], axis=axis_ind
5857-
)
5858-
this.flag_array = np.concatenate(
5859-
[this.flag_array, 1 - zero_pad], axis=axis_ind
5860-
).astype(np.bool_)
5933+
# then pad out parameters with multiple axes
5934+
_axis_pad_helper(this, axis, len(new_inds[axis]))
58615935
else:
58625936
new_axis_inds[axis] = axis_vals[axis]["this"]
58635937

5938+
# Now fill in multidimensional arrays
5939+
t2o_dict = {}
5940+
for axis, inds_dict in axis_vals.items():
5941+
t2o_dict[axis] = np.nonzero(
5942+
np.isin(new_axis_inds[axis], inds_dict["other"])
5943+
)[0]
5944+
5945+
sort_axes = []
5946+
for axis in axes:
5947+
if len(new_inds[axis]) > 0:
5948+
sort_axes.append(axis)
5949+
_fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5950+
58645951
if len(new_inds["Nfreqs"]) > 0:
58655952
# We want to preserve per-spw information based on first appearance
58665953
# in the concatenated array.
@@ -5875,38 +5962,6 @@ def __add__(
58755962
[this_flexpol_dict[key] for key in this.spw_array]
58765963
)
58775964

5878-
# Now populate the data
5879-
t2o_dict = {}
5880-
for axis, inds_dict in axis_vals.items():
5881-
t2o_dict[axis] = np.nonzero(
5882-
np.isin(new_axis_inds[axis], inds_dict["other"])
5883-
)[0]
5884-
5885-
if not self.metadata_only:
5886-
this.data_array[
5887-
np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5888-
] = other.data_array
5889-
this.nsample_array[
5890-
np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5891-
] = other.nsample_array
5892-
this.flag_array[
5893-
np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"])
5894-
] = other.flag_array
5895-
5896-
# Fix ordering
5897-
for axis_ind, axis in enumerate(axes):
5898-
for name, param in zip(
5899-
this._data_params, this.data_like_parameters, strict=True
5900-
):
5901-
if len(new_inds[axis]) > 0:
5902-
unique_order_diffs = np.unique(np.diff(order_dict[axis]))
5903-
if np.array_equal(unique_order_diffs, np.array([1])):
5904-
# everything is already in order
5905-
continue
5906-
setattr(
5907-
this, name, np.take(param, order_dict[axis], axis=axis_ind)
5908-
)
5909-
59105965
# Update N parameters (e.g. Npols)
59115966
this.Ntimes = len(np.unique(this.time_array))
59125967
this.Nbls = len(np.unique(this.baseline_array))

0 commit comments

Comments
 (0)