Skip to content

Commit 0a84f91

Browse files
Eli
authored and committed
Bug fixes having to do with series vs DataFrames
1 parent 73910eb commit 0a84f91

3 files changed

Lines changed: 96 additions & 30 deletions

File tree

vtools/data/timeseries.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -180,13 +180,19 @@ def extrapolate_ts(ts, start=None, end=None, method="ffill", val=None):
180180

181181
raise ValueError("'ffill' not allowed when extending before start of data")
182182
ts_full.loc[ts.index[-1] :] = ts_full.loc[ts.index[-1] :].ffill()
183-
return ts_full.astype(ts.dtype)
183+
if isinstance(ts, pd.Series):
184+
return ts_full.astype(ts.dtype)
185+
else:
186+
return ts_full.astype(ts.dtypes.to_dict())
184187

185188
elif method == "bfill":
186189
if end > ts.index[-1]:
187190
raise ValueError("'bfill' not allowed when extending after end of data")
188191
ts_full.loc[: ts.index[0]] = ts_full.loc[: ts.index[0]].bfill()
189-
return ts_full.astype(ts.dtype)
192+
if isinstance(ts, pd.Series):
193+
return ts_full.astype(ts.dtype)
194+
else:
195+
return ts_full.astype(ts.dtypes.to_dict())
190196

191197
elif method == "linear_slope":
192198
if val is not None:
@@ -255,7 +261,14 @@ def extrapolate_ts(ts, start=None, end=None, method="ffill", val=None):
255261
if end > ts.index[-1]:
256262
result.loc[result.index > ts.index[-1]] = val
257263

258-
return result.astype(ts.dtype) if not result.isna().any().any() else result
264+
if isinstance(result, pd.Series):
265+
return result.astype(ts.dtype) if not result.isna().any() else result
266+
else:
267+
return (
268+
result.astype(ts.dtypes.to_dict())
269+
if not result.isna().any().any()
270+
else result
271+
)
259272

260273
else:
261274
raise ValueError(f"Unknown method: {method}")

vtools/functions/merge.py

Lines changed: 76 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -51,70 +51,96 @@ def ts_merge(series, names=None):
5151
--------
5252
ts_splice : Alternative merging method for irregular time series.
5353
"""
54-
# Make defensive copies of the input series.
54+
# Make defensive copies of the input series.
5555
series = [s.copy() for s in series]
5656

5757
if not isinstance(series, (tuple, list)) or len(series) == 0:
58-
raise ValueError("`series` must be a non-empty tuple or list of pandas.Series or pandas.DataFrame.")
58+
raise ValueError(
59+
"`series` must be a non-empty tuple or list of pandas.Series or pandas.DataFrame."
60+
)
5961

6062
# Ensure all input series have an index of the same type.
6163
index_type = type(series[0].index)
6264
if not all(isinstance(s.index, index_type) for s in series):
63-
raise ValueError(f"All input series must have indexes of type {index_type.__name__}.")
65+
raise ValueError(
66+
f"All input series must have indexes of type {index_type.__name__}."
67+
)
6468

6569
if not all(isinstance(s.index, (pd.DatetimeIndex, pd.PeriodIndex)) for s in series):
6670
raise ValueError("All input series must have a DatetimeIndex or PeriodIndex.")
6771

6872
# Determine if frequency can be preserved.
6973
first_freq = series[0].index.freq
70-
same_freq = all(s.index.freq == first_freq for s in series if s.index.freq is not None)
74+
same_freq = all(
75+
s.index.freq == first_freq for s in series if s.index.freq is not None
76+
)
7177

7278
# For mixed types, convert Series to DataFrame.
7379
has_series = any(isinstance(s, pd.Series) for s in series)
7480
has_dataframe = any(isinstance(s, pd.DataFrame) for s in series)
7581
if has_series and has_dataframe:
7682
if isinstance(names, str):
77-
series = [s.to_frame(name=names) if isinstance(s, pd.Series) else s for s in series]
83+
series = [
84+
s.to_frame(name=names) if isinstance(s, pd.Series) else s
85+
for s in series
86+
]
7887
elif names is None:
79-
df_cols = {col for s in series if isinstance(s, pd.DataFrame) for col in s.columns}
88+
df_cols = {
89+
col for s in series if isinstance(s, pd.DataFrame) for col in s.columns
90+
}
8091
for s in series:
8192
if isinstance(s, pd.Series) and s.name not in df_cols:
82-
raise ValueError("Mixed Series and DataFrames require Series names to match DataFrame columns.")
83-
series = [s.to_frame(name=s.name) if isinstance(s, pd.Series) else s for s in series]
93+
raise ValueError(
94+
"Mixed Series and DataFrames require Series names to match DataFrame columns."
95+
)
96+
series = [
97+
s.to_frame(name=s.name) if isinstance(s, pd.Series) else s
98+
for s in series
99+
]
84100

85101
# For DataFrame inputs, validate column consistency.
86102
if isinstance(series[0], pd.DataFrame):
87103
if names is None:
88104
common = set(series[0].columns)
89105
for df in series:
90106
if set(df.columns) != common:
91-
raise ValueError("All input DataFrames must have the same columns when `names` is None.")
107+
raise ValueError(
108+
"All input DataFrames must have the same columns when `names` is None."
109+
)
92110
elif hasattr(names, "__iter__") and not isinstance(names, str):
93111
for df in series:
94112
if not all(col in df.columns for col in names):
95-
raise ValueError(f"An input DataFrame does not contain all specified columns: {names}.")
113+
raise ValueError(
114+
f"An input DataFrame does not contain all specified columns: {names}."
115+
)
96116

97117
# If names is a string, pre-rename each series for consistency.
98118
if names and isinstance(names, str):
99119
series = [
100-
(s.rename(columns={s.columns[0]: names})
101-
if isinstance(s, pd.DataFrame) else s.rename(names))
120+
(
121+
s.rename(columns={s.columns[0]: names})
122+
if isinstance(s, pd.DataFrame)
123+
else s.rename(names)
124+
)
102125
for s in series
103126
]
104127

105128
# Compute the union of all time indices.
106129
full_index = series[0].index
107130
for s in series[1:]:
108131
full_index = full_index.union(s.index, sort=False)
109-
full_index = full_index.sort_values()
132+
full_index = full_index.sort_values()
110133

111134
# Merge series by reindexing and using combine_first.
112135
merged = series[0].reindex(full_index)
113136
for s in series[1:]:
114137
merged = merged.combine_first(s.reindex(full_index))
115138

116139
# If all inputs were univariate, ensure output remains univariate.
117-
univariate = all((s.name is not None if isinstance(s, pd.Series) else s.shape[1] == 1) for s in series)
140+
univariate = all(
141+
(s.name is not None if isinstance(s, pd.Series) else s.shape[1] == 1)
142+
for s in series
143+
)
118144
if univariate and isinstance(merged, pd.DataFrame):
119145
merged = merged.iloc[:, 0]
120146

@@ -132,6 +158,7 @@ def ts_merge(series, names=None):
132158

133159
return merged
134160

161+
135162
def ts_splice(series, names=None, transition="prefer_last", floor_dates=False):
136163
"""
137164
Splice multiple time series together, prioritizing series in patches of time.
@@ -189,27 +216,43 @@ def ts_splice(series, names=None, transition="prefer_last", floor_dates=False):
189216
series = [s.copy() for s in series]
190217

191218
if not isinstance(series, (tuple, list)) or len(series) == 0:
192-
raise ValueError("`series` must be a non-empty tuple or list of pandas.Series or pandas.DataFrame.")
193-
if not all(isinstance(s.index, pd.DatetimeIndex) or isinstance(s.index, pd.PeriodIndex) for s in series):
219+
raise ValueError(
220+
"`series` must be a non-empty tuple or list of pandas.Series or pandas.DataFrame."
221+
)
222+
if not all(
223+
isinstance(s.index, pd.DatetimeIndex) or isinstance(s.index, pd.PeriodIndex)
224+
for s in series
225+
):
194226
raise ValueError("All input series must have a DatetimeIndex or PeriodIndex.")
195227

196228
# Ensure all input series have the same type of index.
197229
index_type = type(series[0].index)
198230
if not all(isinstance(s.index, index_type) for s in series):
199-
raise ValueError(f"All input series must have indexes of type {index_type.__name__}.")
231+
raise ValueError(
232+
f"All input series must have indexes of type {index_type.__name__}."
233+
)
200234

201-
if transition not in ["prefer_last", "prefer_first"] and not isinstance(transition, list):
202-
raise ValueError("`transition` must be 'prefer_last', 'prefer_first', or a list of timestamps.")
235+
if transition not in ["prefer_last", "prefer_first"] and not isinstance(
236+
transition, list
237+
):
238+
raise ValueError(
239+
"`transition` must be 'prefer_last', 'prefer_first', or a list of timestamps."
240+
)
203241

204242
# Determine if frequency can be preserved.
205243
first_freq = series[0].index.freq
206-
same_freq = all(s.index.freq == first_freq for s in series if s.index.freq is not None)
244+
same_freq = all(
245+
s.index.freq == first_freq for s in series if s.index.freq is not None
246+
)
207247

208248
# If names is a string, pre-rename each series for consistency.
209249
if names and isinstance(names, str):
210250
series = [
211-
(s.rename(columns={s.columns[0]: names})
212-
if isinstance(s, pd.DataFrame) else s.rename(names))
251+
(
252+
s.rename(columns={s.columns[0]: names})
253+
if isinstance(s, pd.DataFrame)
254+
else s.rename(names)
255+
)
213256
for s in series
214257
]
215258

@@ -225,19 +268,24 @@ def ts_splice(series, names=None, transition="prefer_last", floor_dates=False):
225268
transition_points = [s.first_valid_index() for s in series[1:]]
226269
duplicate_keep = "last"
227270
if floor_dates:
228-
transition_points = [dt.floor("D") for dt in transition_points]
271+
transition_points = [dt.floor("d") for dt in transition_points]
229272
transition_points = [None] + transition_points + [None]
230273

231274
# Extract sections based on transition points.
232275
sections = []
233-
for ts_obj, start, end in zip(series, transition_points[:-1], transition_points[1:]):
276+
for ts_obj, start, end in zip(
277+
series, transition_points[:-1], transition_points[1:]
278+
):
234279
section = ts_obj.loc[start:end]
235280
if not section.empty:
236281
sections.append(section)
237282
spliced = pd.concat(sections, axis=0).sort_index() if sections else pd.DataFrame()
238283

239284
# If all inputs were univariate, ensure output remains univariate.
240-
univariate = all((s.name is not None if isinstance(s, pd.Series) else s.shape[1] == 1) for s in series)
285+
univariate = all(
286+
(s.name is not None if isinstance(s, pd.Series) else s.shape[1] == 1)
287+
for s in series
288+
)
241289
if univariate and isinstance(spliced, pd.DataFrame):
242290
spliced = spliced.iloc[:, 0]
243291

@@ -272,6 +320,7 @@ def _apply_names(result, names):
272320
result = result[names]
273321
return result
274322

323+
275324
def _reindex_to_continuous(result, first_freq):
276325
"""
277326
Reindex the given result (DataFrame or Series) to a continuous index spanning
@@ -304,7 +353,7 @@ def _reindex_to_continuous(result, first_freq):
304353
continuous_index = result.index # For other types, leave unchanged
305354

306355
result = result.reindex(continuous_index)
307-
356+
308357
if isinstance(result.index, pd.PeriodIndex):
309358
# For PeriodIndex, rebuild the index because .freq is read-only.
310359
result.index = pd.PeriodIndex(result.index, freq=first_freq)
@@ -313,4 +362,4 @@ def _reindex_to_continuous(result, first_freq):
313362
result.index.freq = first_freq
314363
except ValueError:
315364
result.index.freq = None
316-
return result
365+
return result

vtools/functions/transition.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,11 @@ def transition_ts(
124124
if n_after > 0
125125
else ts1.loc[[ts1.index[0]]]
126126
)
127+
print(seg0)
128+
print(seg1)
127129
all_data = pd.concat([seg0, seg1])
130+
all_data = all_data[~all_data.index.duplicated()]
131+
all_data = all_data.sort_index()
128132

129133
if isinstance(ts0, pd.Series):
130134
interp = PchipInterpolator(all_data.index.astype(np.int64), all_data.values)

0 commit comments

Comments (0)