@@ -5991,9 +5991,20 @@ def fast_concat(
59915991 self and other are not compatible.
59925992
59935993 """
5994- allowed_axes = ["blt", "freq", "polarization"]
5995- if axis not in allowed_axes:
5996- raise ValueError("Axis must be one of: " + ", ".join(allowed_axes))
5994+ # setup a dict to carry all the axis-specific info we need throughout
5995+ # the fast concat process:
5996+ # - description is used in history string
5997+ # - shape: the shape name parameter (e.g. "Nblts", "Nfreqs", "Npols")
5998+ # ---added later----
5999+ # - check_params gives parameters that should be checked if adding
6000+ # along other axes
6001+ axis_info = {
6002+ "blt": {"description": "baseline-time", "shape": "Nblts"},
6003+ "freq": {"description": "frequency", "shape": "Nfreqs"},
6004+ "polarization": {"description": "polarization", "shape": "Npols"},
6005+ }
6006+ if axis not in axis_info:
6007+ raise ValueError("Axis must be one of: " + ", ".join(axis_info))
59976008
59986009 if inplace:
59996010 this = self
@@ -6043,28 +6054,23 @@ def fast_concat(
60436054
60446055 history_update_string = " Combined data along "
60456056
6046- # identify params that are not explicitly included in overlap calc per axis
6047- axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"}
6048- axis_check_params = {}
6049- axis_parameters = {}
6050- for axis2, ax_shape in axis_shape.items():
6051- axis_parameters[axis2] = this._get_param_axis(ax_shape)
6052- axis_check_params[axis2] = []
6053- for param in axis_parameters[axis2]:
6054- if param not in this._data_params:
6055- axis_check_params[axis2].append("_" + param)
6056-
6057- for axis2 in axis_shape:
6058- if axis2 != axis:
6059- compatibility_params.extend(axis_check_params[axis2])
6057+ # figure out what parameters to check for compatibility -- only worry
6058+ # about single axis params
6059+ for _, info in axis_info.items():
6060+ params_this_axis = this._get_param_axis(
6061+ info["shape"], single_named_axis=True
6062+ )
6063+ info["check_params"] = []
6064+ for param in params_this_axis:
6065+ info["check_params"].append("_" + param)
60606066
6061- axis_descriptions = {
6062- "blt": "baseline-time",
6063- "freq": "frequency",
6064- "polarization": "polarization",
6065- }
6067+ for axis2, info in axis_info.items():
6068+ if axis2 != axis:
6069+ compatibility_params.extend(info["check_params"])
60666070
6067- history_update_string += f" {axis_descriptions[axis]} axis using pyuvdata."
6071+ history_update_string += (
6072+ f" {axis_info[axis]['description']} axis using pyuvdata."
6073+ )
60686074
60696075 histories_match = []
60706076 for obj in other:
@@ -6103,13 +6109,15 @@ def fast_concat(
61036109
61046110 this.telescope = tel_obj
61056111
6112+ # actually do the concat
6113+ this._axis_fast_concat_helper(other, axis_info[axis]["shape"])
6114+
61066115 # update the relevant shape parameter
6107- this._axis_fast_concat_helper(other, axis_shape[axis])
61086116 new_shape = sum(
6109- [getattr(this, axis_shape [axis])]
6110- + [getattr(obj, axis_shape [axis]) for obj in other]
6117+ [getattr(this, axis_info [axis]["shape" ])]
6118+ + [getattr(obj, axis_info [axis]["shape" ]) for obj in other]
61116119 )
6112- setattr(this, axis_shape [axis], new_shape)
6120+ setattr(this, axis_info [axis]["shape" ], new_shape)
61136121
61146122 if axis == "freq":
61156123 # We want to preserve per-spw information based on first appearance
0 commit comments