Skip to content

Commit 7600498

Browse files
committed
start work on making add and fast_concat use forms
1 parent 25ae582 commit 7600498

3 files changed

Lines changed: 101 additions & 127 deletions

File tree

src/pyuvdata/uvbase.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,47 @@ def copy(self):
754754
"""
755755
return copy.deepcopy(self)
756756

757+
def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
758+
"""
759+
Get a mapping of parameters that have a given axis to the axis number.
760+
761+
Parameters
762+
----------
763+
axis_name : str
764+
A named parameter within the object (e.g., "Nblts", "Ntimes", "Nants").
765+
single_named_axis : bool
766+
Option to only include parameters with a single named axis.
767+
768+
Returns
769+
-------
770+
dict
771+
The keys are UVParameter names that have an axis with axis_name
772+
(axis_name appears in their form). The values are a list of the axis
773+
indices where axis_name appears in their form.
774+
"""
775+
ret_dict = {}
776+
for param in self:
777+
# For each attribute, if the value is None, then bail, otherwise
778+
# attempt to figure out along which axis ind_arr will apply.
779+
780+
attr = getattr(self, param)
781+
if (
782+
attr.value is not None
783+
and isinstance(attr.form, tuple)
784+
and axis_name in attr.form
785+
):
786+
if (
787+
single_named_axis
788+
and sum([isinstance(entry, str) for entry in attr.form]) > 1
789+
):
790+
continue
791+
792+
# Only look at where form is a tuple, since that's the only case we
793+
# can have a dynamically defined shape. Note that index doesn't work
794+
# here in the case of a repeated param_name in the form.
795+
ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0]
796+
return ret_dict
797+
757798
def _select_along_param_axis(self, param_dict: dict):
758799
"""
759800
Downselect values along a parameterized axis.

src/pyuvdata/uvdata/uvdata.py

Lines changed: 54 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,43 @@
4040
)
4141

4242

43+
def _axis_add_helper(this, other, axis_name: str, other_inds, final_order=None):
44+
update_params = this._get_param_axis(axis_name, single_named_axis=True)
45+
other_form_dict = {axis_name: other_inds}
46+
for param, axis_list in update_params.items():
47+
axis = axis_list[0]
48+
new_array = np.concatenate(
49+
[
50+
getattr(this, param),
51+
getattr(other, "_" + param).get_from_form(other_form_dict),
52+
],
53+
axis=axis,
54+
)
55+
if param == "scan_number_array":
56+
print()
57+
print(new_array)
58+
if final_order is not None:
59+
new_array = np.take(new_array, final_order, axis=axis)
60+
if param == "scan_number_array":
61+
print(new_array)
62+
63+
setattr(this, param, new_array)
64+
65+
66+
def _axis_fast_concat_helper(this, other, axis_name: str):
67+
update_params = this._get_param_axis(axis_name)
68+
for param, axis_list in update_params.items():
69+
axis = axis_list[0]
70+
setattr(
71+
this,
72+
param,
73+
np.concatenate(
74+
[getattr(this, param)] + [getattr(obj, param) for obj in other],
75+
axis=axis,
76+
),
77+
)
78+
79+
4380
class UVData(UVBase):
4481
"""
4582
A class for defining a radio interferometer dataset.
@@ -5549,6 +5586,12 @@ def __add__(
55495586
"_uvw_array",
55505587
]
55515588
compatibility_params.extend(extra_params)
5589+
# TODO: make this list programmatically if possible?
5590+
if (
5591+
this.scan_number_array is not None
5592+
or other.scan_number_array is not None
5593+
):
5594+
compatibility_params.append("_scan_number_array")
55525595

55535596
# find the freq indices in "other" but not in "this"
55545597
if (this.flex_spw_polarization_array is None) != (
@@ -5626,6 +5669,7 @@ def __add__(
56265669
"_phase_center_app_dec",
56275670
"_phase_center_frame_pa",
56285671
"_phase_center_id_array",
5672+
"_scan_number_array",
56295673
]
56305674
for cp in compatibility_params:
56315675
if cp in blt_inds_params:
@@ -5712,6 +5756,9 @@ def __add__(
57125756
if len(bnew_inds) > 0:
57135757
this_blts = np.concatenate((this_blts, new_blts))
57145758
blt_order = np.argsort(this_blts)
5759+
5760+
_axis_add_helper(this, other, "Nblts", bnew_inds, blt_order)
5761+
57155762
if not self.metadata_only:
57165763
zero_pad = np.zeros((len(bnew_inds), this.Nfreqs, this.Npols))
57175764
this.data_array = np.concatenate([this.data_array, zero_pad], axis=0)
@@ -5721,53 +5768,11 @@ def __add__(
57215768
this.flag_array = np.concatenate(
57225769
[this.flag_array, 1 - zero_pad], axis=0
57235770
).astype(np.bool_)
5724-
this.uvw_array = np.concatenate(
5725-
[this.uvw_array, other.uvw_array[bnew_inds, :]], axis=0
5726-
)[blt_order, :]
5727-
this.time_array = np.concatenate(
5728-
[this.time_array, other.time_array[bnew_inds]]
5729-
)[blt_order]
5730-
this.integration_time = np.concatenate(
5731-
[this.integration_time, other.integration_time[bnew_inds]]
5732-
)[blt_order]
5733-
this.lst_array = np.concatenate(
5734-
[this.lst_array, other.lst_array[bnew_inds]]
5735-
)[blt_order]
5736-
this.ant_1_array = np.concatenate(
5737-
[this.ant_1_array, other.ant_1_array[bnew_inds]]
5738-
)[blt_order]
5739-
this.ant_2_array = np.concatenate(
5740-
[this.ant_2_array, other.ant_2_array[bnew_inds]]
5741-
)[blt_order]
5742-
this.baseline_array = np.concatenate(
5743-
[this.baseline_array, other.baseline_array[bnew_inds]]
5744-
)[blt_order]
5745-
this.phase_center_app_ra = np.concatenate(
5746-
[this.phase_center_app_ra, other.phase_center_app_ra[bnew_inds]]
5747-
)[blt_order]
5748-
this.phase_center_app_dec = np.concatenate(
5749-
[this.phase_center_app_dec, other.phase_center_app_dec[bnew_inds]]
5750-
)[blt_order]
5751-
this.phase_center_frame_pa = np.concatenate(
5752-
[this.phase_center_frame_pa, other.phase_center_frame_pa[bnew_inds]]
5753-
)[blt_order]
5754-
this.phase_center_id_array = np.concatenate(
5755-
[this.phase_center_id_array, other.phase_center_id_array[bnew_inds]]
5756-
)[blt_order]
57575771

57585772
f_order = None
57595773
if len(fnew_inds) > 0:
5760-
this.freq_array = np.concatenate(
5761-
[this.freq_array, other.freq_array[fnew_inds]]
5762-
)
5763-
this.channel_width = np.concatenate(
5764-
[this.channel_width, other.channel_width[fnew_inds]]
5765-
)
5774+
_axis_add_helper(this, other, "Nfreqs", fnew_inds)
57665775

5767-
this.flex_spw_id_array = np.concatenate(
5768-
[this.flex_spw_id_array, other.flex_spw_id_array[fnew_inds]]
5769-
)
5770-
this.spw_array = np.concatenate([this.spw_array, other.spw_array])
57715776
# We want to preserve per-spw information based on first appearance
57725777
# in the concatenated array.
57735778
unique_index = np.sort(
@@ -5814,9 +5819,8 @@ def __add__(
58145819

58155820
p_order = None
58165821
if len(pnew_inds) > 0:
5817-
this.polarization_array = np.concatenate(
5818-
[this.polarization_array, other.polarization_array[pnew_inds]]
5819-
)
5822+
_axis_add_helper(this, other, "Npols", pnew_inds)
5823+
58205824
p_order = np.argsort(np.abs(this.polarization_array))
58215825
if not self.metadata_only:
58225826
zero_pad = np.zeros(
@@ -6198,18 +6202,8 @@ def fast_concat(
61986202

61996203
if axis == "freq":
62006204
this.Nfreqs = sum([this.Nfreqs] + [obj.Nfreqs for obj in other])
6201-
this.freq_array = np.concatenate(
6202-
[this.freq_array] + [obj.freq_array for obj in other]
6203-
)
6204-
this.channel_width = np.concatenate(
6205-
[this.channel_width] + [obj.channel_width for obj in other]
6206-
)
6207-
this.flex_spw_id_array = np.concatenate(
6208-
[this.flex_spw_id_array] + [obj.flex_spw_id_array for obj in other]
6209-
)
6210-
this.spw_array = np.concatenate(
6211-
[this.spw_array] + [obj.spw_array for obj in other]
6212-
)
6205+
_axis_fast_concat_helper(this, other, "Nfreqs")
6206+
_axis_fast_concat_helper(this, other, "Nspws")
62136207
# We want to preserve per-spw information based on first appearance
62146208
# in the concatenated array.
62156209
unique_index = np.sort(
@@ -6219,83 +6213,16 @@ def fast_concat(
62196213

62206214
this.Nspws = len(this.spw_array)
62216215

6222-
if not self.metadata_only:
6223-
this.data_array = np.concatenate(
6224-
[this.data_array] + [obj.data_array for obj in other], axis=1
6225-
)
6226-
this.nsample_array = np.concatenate(
6227-
[this.nsample_array] + [obj.nsample_array for obj in other], axis=1
6228-
)
6229-
this.flag_array = np.concatenate(
6230-
[this.flag_array] + [obj.flag_array for obj in other], axis=1
6231-
)
62326216
elif axis == "polarization":
6233-
this.polarization_array = np.concatenate(
6234-
[this.polarization_array] + [obj.polarization_array for obj in other]
6235-
)
6217+
_axis_fast_concat_helper(this, other, "Npols")
62366218
this.Npols = sum([this.Npols] + [obj.Npols for obj in other])
62376219

6238-
if not self.metadata_only:
6239-
this.data_array = np.concatenate(
6240-
[this.data_array] + [obj.data_array for obj in other], axis=2
6241-
)
6242-
this.nsample_array = np.concatenate(
6243-
[this.nsample_array] + [obj.nsample_array for obj in other], axis=2
6244-
)
6245-
this.flag_array = np.concatenate(
6246-
[this.flag_array] + [obj.flag_array for obj in other], axis=2
6247-
)
62486220
elif axis == "blt":
62496221
this.Nblts = sum([this.Nblts] + [obj.Nblts for obj in other])
6250-
this.ant_1_array = np.concatenate(
6251-
[this.ant_1_array] + [obj.ant_1_array for obj in other]
6252-
)
6253-
this.ant_2_array = np.concatenate(
6254-
[this.ant_2_array] + [obj.ant_2_array for obj in other]
6255-
)
6222+
_axis_fast_concat_helper(this, other, "Nblts")
62566223
this.Nants_data = this._calc_nants_data()
6257-
this.uvw_array = np.concatenate(
6258-
[this.uvw_array] + [obj.uvw_array for obj in other], axis=0
6259-
)
6260-
this.time_array = np.concatenate(
6261-
[this.time_array] + [obj.time_array for obj in other]
6262-
)
62636224
this.Ntimes = len(np.unique(this.time_array))
6264-
this.lst_array = np.concatenate(
6265-
[this.lst_array] + [obj.lst_array for obj in other]
6266-
)
6267-
this.baseline_array = np.concatenate(
6268-
[this.baseline_array] + [obj.baseline_array for obj in other]
6269-
)
62706225
this.Nbls = len(np.unique(this.baseline_array))
6271-
this.integration_time = np.concatenate(
6272-
[this.integration_time] + [obj.integration_time for obj in other]
6273-
)
6274-
this.phase_center_app_ra = np.concatenate(
6275-
[this.phase_center_app_ra] + [obj.phase_center_app_ra for obj in other]
6276-
)
6277-
this.phase_center_app_dec = np.concatenate(
6278-
[this.phase_center_app_dec]
6279-
+ [obj.phase_center_app_dec for obj in other]
6280-
)
6281-
this.phase_center_frame_pa = np.concatenate(
6282-
[this.phase_center_frame_pa]
6283-
+ [obj.phase_center_frame_pa for obj in other]
6284-
)
6285-
this.phase_center_id_array = np.concatenate(
6286-
[this.phase_center_id_array]
6287-
+ [obj.phase_center_id_array for obj in other]
6288-
)
6289-
if not self.metadata_only:
6290-
this.data_array = np.concatenate(
6291-
[this.data_array] + [obj.data_array for obj in other], axis=0
6292-
)
6293-
this.nsample_array = np.concatenate(
6294-
[this.nsample_array] + [obj.nsample_array for obj in other], axis=0
6295-
)
6296-
this.flag_array = np.concatenate(
6297-
[this.flag_array] + [obj.flag_array for obj in other], axis=0
6298-
)
62996226

63006227
# update filename attribute
63016228
for obj in other:

tests/uvdata/test_uvdata.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3385,6 +3385,7 @@ def test_sum_vis_errors(hera_uvh5, attr_to_get, attr_to_set, arg_dict, msg):
33853385
@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values")
33863386
def test_add_freq(casa_uvfits):
33873387
uv_full = casa_uvfits
3388+
uv_full.scan_number_array = np.arange(uv_full.Nblts)
33883389

33893390
uv1 = uv_full.select(freq_chans=np.arange(0, 32), inplace=False)
33903391
uv2 = uv_full.select(freq_chans=np.arange(32, 64), inplace=False)
@@ -3415,6 +3416,7 @@ def test_add_freq(casa_uvfits):
34153416
@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values")
34163417
def test_add_pols(casa_uvfits):
34173418
uv_full = casa_uvfits
3419+
uv_full.scan_number_array = np.arange(uv_full.Nblts)
34183420

34193421
uv1 = uv_full.select(polarizations=uv_full.polarization_array[0:2], inplace=False)
34203422
uv2 = uv_full.select(polarizations=uv_full.polarization_array[2:4], inplace=False)
@@ -3453,6 +3455,7 @@ def test_add_pols(casa_uvfits):
34533455
@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values")
34543456
def test_add_times(casa_uvfits):
34553457
uv_full = casa_uvfits
3458+
uv_full.scan_number_array = np.arange(uv_full.Nblts)
34563459

34573460
times = np.unique(uv_full.time_array)
34583461
uv1 = uv_full.select(times=times[0 : len(times) // 2], inplace=False)
@@ -3474,6 +3477,7 @@ def test_add_times(casa_uvfits):
34743477
@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values")
34753478
def test_add_bls(casa_uvfits):
34763479
uv_full = casa_uvfits
3480+
uv_full.reorder_blts()
34773481
uv_full.scan_number_array = np.arange(uv_full.Nblts)
34783482

34793483
ant_list = list(range(15)) # Roughly half the antennas in the data
@@ -3515,6 +3519,7 @@ def test_add_bls(casa_uvfits):
35153519
uv3.ant_1_array = uv3.ant_1_array[-1::-1]
35163520
uv3.ant_2_array = uv3.ant_2_array[-1::-1]
35173521
uv3.baseline_array = uv3.baseline_array[-1::-1]
3522+
uv3.scan_number_array = uv3.scan_number_array[-1::-1]
35183523
uv1 += uv3
35193524
uv1 += uv2
35203525
assert utils.history._check_histories(
@@ -3533,6 +3538,7 @@ def test_add_bls(casa_uvfits):
35333538
@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values")
35343539
def test_add_multi_axis(casa_uvfits):
35353540
uv_full = casa_uvfits
3541+
uv_full.scan_number_array = np.arange(uv_full.Nblts)
35363542

35373543
uv_ref = uv_full.copy()
35383544
times = np.unique(uv_full.time_array)

0 commit comments

Comments
 (0)