Skip to content

Commit 99fb038

Browse files
committed
Better comments and variable names
1 parent 5c5fa70 commit 99fb038

2 files changed

Lines changed: 99 additions & 77 deletions

File tree

src/pyuvdata/uvbase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
776776
ret_dict = {}
777777
for param in self:
778778
# For each attribute, if the value is None, then bail, otherwise
779-
# attempt to figure out along which axis ind_arr will apply.
779+
# find the axis number(s) with the named shape.
780780

781781
attr = getattr(self, param)
782782
if (
@@ -791,8 +791,8 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
791791
continue
792792

793793
# Only look at where form is a tuple, since that's the only case we
794-
# can have a dynamically defined shape. Note that index doesn't work
795-
# here in the case of a repeated param_name in the form.
794+
# can have a dynamically defined shape. Handle a repeated
795+
# param_name in the form.
796796
ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0]
797797
return ret_dict
798798

src/pyuvdata/uvdata/uvdata.py

Lines changed: 96 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,7 @@ def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
195195
setattr(this, param, new_array)
196196

197197

198-
def _fill_multi_helper(
199-
this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
200-
):
198+
def _fill_multi_helper(this: UVData, other: UVData, t2o_dict: dict, order_dict: dict):
201199
"""
202200
Fill UVParameter objects with multiple dimensions from the right side object.
203201

@@ -210,8 +208,6 @@ def _fill_multi_helper(
210208
t2o_dict : dict
211209
dict giving the indices in the left object to be filled from the right
212210
object for each axis (keys are axes, values are index arrays).
213-
sort_axes : list of str
214-
The axes that need to be sorted along.
215211
order_dict : dict
216212
dict giving the final sort indices for each axis (keys are axes, values
217213
are index arrays for sorting).
@@ -229,7 +225,7 @@ def _fill_multi_helper(
229225

230226
# Fix ordering
231227
for axis_ind, axis in enumerate(form):
232-
if axis in sort_axes:
228+
if order_dict[axis] is not None:
233229
unique_order_diffs = np.unique(np.diff(order_dict[axis]))
234230
if np.array_equal(unique_order_diffs, np.array([1])):
235231
# everything is already in order
@@ -5728,59 +5724,72 @@ def __add__(
57285724
# Define parameters that must be the same to add objects
57295725
compatibility_params = ["_vis_units"]
57305726

5731-
# identify params that are not explicitly included in overlap calc per axis
57325727
axes = ["Nblts", "Nfreqs", "Npols"]
5733-
axis_params_check = {}
5734-
axis_overlap_params = {
5728+
# axis_key_params defines which parameters to use as the defining
5729+
# parameters along each axis. These are used to identify overlapping data.
5730+
axis_key_params = {
57355731
"Nblts": ["time_array", "baseline_array"],
57365732
"Nfreqs": ["freq_array", "flex_spw_id_array"],
57375733
"Npols": ["polarization_array"],
57385734
}
5739-
axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5740-
axis_dict = {}
5741-
for axis in axes:
5742-
axis_dict[axis] = this._get_param_axis(axis)
5743-
axis_params_check[axis] = []
5744-
for param in axis_dict[axis]:
5735+
# specify a function to form a combined string if there are multiple
5736+
# key arrays (e.g. baseline-time, spw-freq)
5737+
axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5738+
multi_axis_params = this._get_multi_axis_params()
5739+
# axis_parameters gives parameters whose form contains each axis
5740+
axis_parameters = {}
5741+
# axis_check_params gives parameters that should be checked if adding
5742+
# along other axes
5743+
axis_check_params = {}
5744+
# axis_key_arrays gives the arrays to use for checking for overlap per axis
5745+
axis_key_arrays = {}
5746+
# axis_overlap_inds has the outcomes of np.intersect1d on the
5747+
# axis_key_arrays per axis. So it has the both/this inds/other inds
5748+
# for any overlaps.
5749+
axis_overlap_inds = {}
5750+
for axis, overlap_params in axis_key_params.items():
5751+
axis_parameters[axis] = this._get_param_axis(axis)
5752+
axis_check_params[axis] = []
5753+
for param in axis_parameters[axis]:
5754+
# get parameters for compatibility checking. Exclude parameters
5755+
# that define overlap and multidimensional parameters which are
5756+
# handled separately later.
57455757
if (
5746-
param not in this._data_params
5747-
and param not in axis_overlap_params[axis]
5758+
param not in multi_axis_params
5759+
and param not in axis_key_params[axis]
57485760
):
5749-
axis_params_check[axis].append("_" + param)
5761+
axis_check_params[axis].append("_" + param)
57505762

5751-
# build this/other arrays for checking for overlap.
5752-
# Use a combined string if there are multiple arrays defining overlap
5753-
# (e.g. baseline-time, spw-freq)
5754-
axis_vals = {}
5755-
for axis, overlap_params in axis_overlap_params.items():
5763+
# build this/other arrays for checking for overlap.
57565764
if len(overlap_params) > 1:
5757-
axis_vals[axis] = {
5758-
"this": getattr(this, axis_combined_func[axis])(),
5759-
"other": getattr(other, axis_combined_func[axis])(),
5765+
axis_key_arrays[axis] = {
5766+
"this": getattr(this, axis_key_func[axis])(),
5767+
"other": getattr(other, axis_key_func[axis])(),
57605768
}
57615769
else:
5762-
axis_vals[axis] = {
5770+
axis_key_arrays[axis] = {
57635771
"this": getattr(this, overlap_params[0]),
57645772
"other": getattr(other, overlap_params[0]),
57655773
}
57665774

5767-
# Check if we have overlapping data
5768-
axis_inds = {}
5769-
for axis, val_arr in axis_vals.items():
5775+
# Check if we have overlapping data
57705776
both_inds, this_inds, other_inds = np.intersect1d(
5771-
val_arr["this"], val_arr["other"], return_indices=True
5777+
axis_key_arrays[axis]["this"],
5778+
axis_key_arrays[axis]["other"],
5779+
return_indices=True,
57725780
)
5773-
axis_inds[axis] = {
5781+
axis_overlap_inds[axis] = {
57745782
"this": this_inds,
57755783
"other": other_inds,
57765784
"both": both_inds,
57775785
}
57785786

57795787
history_update_string = ""
57805788

5781-
if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
5789+
if np.all(
5790+
[len(axis_overlap_inds[axis]["both"]) > 0 for axis in axis_overlap_inds]
5791+
):
57825792
# We have overlaps, check that overlapping data is not valid
5783-
multi_axis_params = this._get_multi_axis_params()
57845793
this_test = []
57855794
other_test = []
57865795
for param in multi_axis_params:
@@ -5795,10 +5804,14 @@ def __add__(
57955804
for ax_ind, axis in enumerate(form):
57965805
expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
57975806
this_index_list.append(
5798-
np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5807+
np.expand_dims(
5808+
axis_overlap_inds[axis]["this"], axis=expand_axes
5809+
)
57995810
)
58005811
other_index_list.append(
5801-
np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5812+
np.expand_dims(
5813+
axis_overlap_inds[axis]["other"], axis=expand_axes
5814+
)
58025815
)
58035816
this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
58045817

@@ -5830,7 +5843,9 @@ def __add__(
58305843
"These objects have overlapping data and cannot be combined."
58315844
)
58325845

5833-
new_inds = {}
5846+
# Now actually find which axes are going to be added along
5847+
# other_inds_use have the indices in other that will be added to this
5848+
other_inds_use = {}
58345849
additions = []
58355850
axis_descriptions = {
58365851
"Nblts": "baseline-time",
@@ -5840,30 +5855,30 @@ def __add__(
58405855
# find the indices in "other" but not in "this"
58415856
for axis in axes:
58425857
temp = np.nonzero(
5843-
~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"])
5858+
~np.isin(axis_key_arrays[axis]["other"], axis_key_arrays[axis]["this"])
58445859
)[0]
58455860
if len(temp) > 0:
5846-
new_inds[axis] = temp
5861+
other_inds_use[axis] = temp
58475862
# add params associated with the other axes to compatibility_params
58485863
for axis2 in axes:
58495864
if axis2 != axis:
5850-
compatibility_params.extend(axis_params_check[axis2])
5865+
compatibility_params.extend(axis_check_params[axis2])
58515866
additions.append(axis_descriptions[axis])
58525867
else:
5853-
new_inds[axis] = []
5868+
other_inds_use[axis] = []
58545869

58555870
# Actually check compatibility parameters
58565871
for cp in compatibility_params:
58575872
params_match = None
5858-
for axis, check_list in axis_params_check.items():
5873+
for axis, check_list in axis_check_params.items():
58595874
if cp in check_list:
58605875
# only check that overlapping indices match
58615876
this_param = getattr(this, cp)
58625877
this_param_overlap = this_param.get_from_form(
5863-
{axis: axis_inds[axis]["this"]}
5878+
{axis: axis_overlap_inds[axis]["this"]}
58645879
)
58655880
other_param_overlap = getattr(other, cp).get_from_form(
5866-
{axis: axis_inds[axis]["other"]}
5881+
{axis: axis_overlap_inds[axis]["other"]}
58675882
)
58685883
params_match = np.allclose(
58695884
this_param_overlap,
@@ -5895,7 +5910,7 @@ def __add__(
58955910
"Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"},
58965911
"Npols": {"method": "reorder_pols", "parameter": "order"},
58975912
}
5898-
for axis, ind_dict in axis_inds.items():
5913+
for axis, ind_dict in axis_overlap_inds.items():
58995914
if len(ind_dict["this"]) != 0:
59005915
# there is some overlap, so check sorting
59015916
this_argsort = np.argsort(ind_dict["this"])
@@ -5910,16 +5925,22 @@ def __add__(
59105925

59115926
getattr(this, reorder_method[axis]["method"])(**kwargs)
59125927

5913-
# start updating parameters
5914-
new_axis_inds = {}
5928+
# checks are all done, start updating parameters
5929+
# combined_key_arrays has the final key arrays after adding.
5930+
combined_key_arrays = {}
5931+
# order_dict has info about how to sort each axis. Initialize to None
5932+
# for axes that are not added along (so do not need sorting)
59155933
order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
59165934
for axis in axes:
5917-
if len(new_inds[axis]) > 0:
5918-
new_axis_inds[axis] = np.concatenate(
5919-
(axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
5935+
if len(other_inds_use[axis]) > 0:
5936+
combined_key_arrays[axis] = np.concatenate(
5937+
(
5938+
axis_key_arrays[axis]["this"],
5939+
axis_key_arrays[axis]["other"][other_inds_use[axis]],
5940+
)
59205941
)
59215942
if axis == "Npols":
5922-
order_dict[axis] = np.argsort(np.abs(new_axis_inds[axis]))
5943+
order_dict[axis] = np.argsort(np.abs(combined_key_arrays[axis]))
59235944
elif axis == "Nfreqs" and (
59245945
np.any(np.diff(this.freq_array) < 0)
59255946
or np.any(np.diff(other.freq_array) < 0)
@@ -5930,38 +5951,39 @@ def __add__(
59305951
np.concatenate(
59315952
(
59325953
this.flex_spw_id_array,
5933-
other.flex_spw_id_array[new_inds[axis]],
5954+
other.flex_spw_id_array[other_inds_use[axis]],
59345955
)
59355956
),
59365957
np.concatenate(
5937-
(this.freq_array, other.freq_array[new_inds[axis]])
5958+
(this.freq_array, other.freq_array[other_inds_use[axis]])
59385959
),
59395960
)
59405961
else:
5941-
order_dict[axis] = np.argsort(new_axis_inds[axis])
5962+
order_dict[axis] = np.argsort(combined_key_arrays[axis])
59425963

5943-
# first handle parameters with a single axis
5944-
_axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
5964+
# first handle parameters with a single named axis
5965+
_axis_add_helper(
5966+
this, other, axis, other_inds_use[axis], order_dict[axis]
5967+
)
59455968

59465969
# then pad out parameters with multiple axes
5947-
_axis_pad_helper(this, axis, len(new_inds[axis]))
5970+
_axis_pad_helper(this, axis, len(other_inds_use[axis]))
59485971
else:
5949-
new_axis_inds[axis] = axis_vals[axis]["this"]
5972+
# no add along this axis, so it's the same as what's already on this
5973+
combined_key_arrays[axis] = axis_key_arrays[axis]["this"]
59505974

59515975
# Now fill in multidimensional parameters
5976+
# t2o_dict has the mapping of where arrays on other get mapped into
5977+
# this after padding
59525978
t2o_dict = {}
5953-
for axis, inds_dict in axis_vals.items():
5979+
for axis, inds_dict in axis_key_arrays.items():
59545980
t2o_dict[axis] = np.nonzero(
5955-
np.isin(new_axis_inds[axis], inds_dict["other"])
5981+
np.isin(combined_key_arrays[axis], inds_dict["other"])
59565982
)[0]
59575983

5958-
sort_axes = []
5959-
for axis in axes:
5960-
if len(new_inds[axis]) > 0:
5961-
sort_axes.append(axis)
5962-
_fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5984+
_fill_multi_helper(this, other, t2o_dict, order_dict)
59635985

5964-
if len(new_inds["Nfreqs"]) > 0:
5986+
if len(other_inds_use["Nfreqs"]) > 0:
59655987
# We want to preserve per-spw information based on first appearance
59665988
# in the concatenated array.
59675989
unique_index = np.sort(
@@ -6015,7 +6037,7 @@ def __add__(
60156037
)
60166038

60176039
# Reset blt_order if blt axis was added to
6018-
if "Nblts" in sort_axes:
6040+
if order_dict["Nblts"] is not None:
60196041
this.blt_order = ("time", "baseline")
60206042

60216043
this.set_rectangularity(force=True)
@@ -6231,18 +6253,18 @@ def fast_concat(
62316253

62326254
# identify params that are not explicitly included in overlap calc per axis
62336255
axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"}
6234-
axis_params_check = {}
6235-
axis_dict = {}
6256+
axis_check_params = {}
6257+
axis_parameters = {}
62366258
for axis2, ax_shape in axis_shape.items():
6237-
axis_dict[axis2] = this._get_param_axis(ax_shape)
6238-
axis_params_check[axis2] = []
6239-
for param in axis_dict[axis2]:
6259+
axis_parameters[axis2] = this._get_param_axis(ax_shape)
6260+
axis_check_params[axis2] = []
6261+
for param in axis_parameters[axis2]:
62406262
if param not in this._data_params:
6241-
axis_params_check[axis2].append("_" + param)
6263+
axis_check_params[axis2].append("_" + param)
62426264

62436265
for axis2 in axis_shape:
62446266
if axis2 != axis:
6245-
compatibility_params.extend(axis_params_check[axis2])
6267+
compatibility_params.extend(axis_check_params[axis2])
62466268

62476269
axis_descriptions = {
62486270
"blt": "baseline-time",

0 commit comments

Comments
 (0)