Skip to content

Commit 48c3ed8

Browse files
Eli
authored and committed
Fixed read_vtide and got rid of stray matplotlib import for a defunct __main__
1 parent 1332fa0 commit 48c3ed8

File tree

4 files changed

+110
-41
lines changed

4 files changed

+110
-41
lines changed

dms_datastore/read_multi.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import os
55
import pandas as pd
6-
import matplotlib.pyplot as plt
76
from dms_datastore.read_ts import read_ts, read_yaml_header
87
from dms_datastore import dstore_config
98
from dms_datastore.filename import *
@@ -371,25 +370,4 @@ def ts_multifile_read(
371370
return full
372371

373372

374-
if __name__ == "__main__":
375-
# NCRO example
376-
377-
dirname = "//cnrastore-bdo/Modeling_Data/continuous_station_repo_beta/formatted_1yr"
378-
rpats = ["ncro_gle_b9532000_temp*.csv", "cdec_gle*temp*.csv"]
379-
pats = [os.path.join(dirname, p) for p in rpats]
380-
ts = ts_multifile(pats)
381-
print(ts)
382-
ts.plot()
383-
plt.show()
384-
385-
# Example for USGS
386-
# usgs_list = ['lib','ucs','srv','dsj','dws','sdi','fpt','lps','mld','sjj','sjg']
387-
# for nseries in usgs_list:
388-
# print(nseries)
389-
#
390-
# dirname = "//cnrastore-bdo/Modeling_Data/continuous_station_repo/raw/"
391-
# pat = os.path.join(dirname,f"usgs_{nseries}_*turbidity_*.rdb")
392-
# ts = ts_multifile_read(pat,column_name=nseries)
393-
# print(ts)
394-
# ts.plot()
395-
# plt.show()
373+

dms_datastore/read_ts.py

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,22 +1248,76 @@ def vtide_date_parser(*args):
12481248
return dtm.datetime.strptime(x, "%Y%m%dT%H%M")
12491249

12501250

1251-
def read_vtide(fpath_pattern, start=None, end=None, selector=None, force_regular=False, freq=None, **kwargs):
1252-
ts = csv_retrieve_ts(
1253-
fpath_pattern,
1254-
start,
1255-
end,
1256-
force_regular,
1257-
selector=selector,
1258-
format_compatible_fn=lambda x: True,
1259-
qaqc_selector=None,
1260-
parsedates=[0, 1],
1261-
indexcol=0,
1262-
header=None,
1263-
sep=r"\\s+",
1264-
comment="#",
1251+
def read_vtide_file(path, comment="#", sep=r"\s+", header=None, **kwargs):
1252+
"""Read a single no-header vtide text file with fixed date/time columns."""
1253+
dset = pd.read_csv(
1254+
path,
1255+
sep=sep,
1256+
header=header,
1257+
comment=comment,
1258+
dtype={0: str, 1: str},
1259+
**kwargs,
12651260
)
12661261

1262+
if dset.shape[1] < 2:
1263+
raise ValueError(f"Vtide file {path} must contain at least date/time columns")
1264+
1265+
date_part = dset[0].astype(str).str.strip()
1266+
time_part = dset[1].astype(str).str.strip()
1267+
1268+
if not time_part.str.contains(":").all():
1269+
time_part = time_part.str.zfill(4).str.replace(r"^(\d{2})(\d{2})$", r"\1:\2", regex=True)
1270+
1271+
dt_str = date_part + "T" + time_part
1272+
1273+
try:
1274+
idx = pd.to_datetime(dt_str, format="%Y%m%dT%H%M", errors="raise")
1275+
except ValueError:
1276+
try:
1277+
idx = pd.to_datetime(dt_str, format="%Y-%m-%dT%H:%M", errors="raise")
1278+
except ValueError:
1279+
idx = pd.to_datetime(dt_str, errors="raise")
1280+
1281+
dset.index = idx
1282+
dset.index.name = "datetime"
1283+
dset = dset.drop(columns=[0, 1])
1284+
1285+
return dset
1286+
1287+
1288+
def read_vtide(fpath_pattern, start=None, end=None, selector=None, force_regular=False, freq=None, **kwargs):
1289+
contains_glob = any(ch in fpath_pattern for ch in "*?[]")
1290+
1291+
if contains_glob:
1292+
ts = csv_retrieve_ts(
1293+
fpath_pattern,
1294+
start,
1295+
end,
1296+
force_regular,
1297+
selector=selector,
1298+
format_compatible_fn=lambda x: True,
1299+
qaqc_selector=None,
1300+
parsedates=None,
1301+
indexcol=None,
1302+
header=None,
1303+
sep=r"\s+",
1304+
comment="#",
1305+
**kwargs,
1306+
)
1307+
1308+
if ts is not None and isinstance(ts, pd.DataFrame) and 0 in ts.columns and 1 in ts.columns:
1309+
dt_str = ts[0].astype(str).str.strip() + "T" + ts[1].astype(str).str.zfill(4).str.strip()
1310+
ts.index = pd.to_datetime(dt_str, format="%Y%m%dT%H%M", errors="raise")
1311+
ts.index.name = "datetime"
1312+
ts = ts.drop(columns=[0, 1])
1313+
1314+
return ts
1315+
1316+
ts = read_vtide_file(fpath_pattern, **kwargs)
1317+
1318+
if selector is not None:
1319+
ts = ts[selector] if isinstance(selector, (list, tuple)) else ts[[selector]]
1320+
12671321
return ts
12681322

12691323

dms_datastore/write_ts.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ def write_ts_csv(
118118
**kwargs : other
119119
Other items that will be passed to write_csv
120120
"""
121+
# Series support: convert to single-column DataFrame while preserving the index and column name.
122+
if isinstance(ts, pd.Series):
123+
col_name = ts.name if ts.name is not None else "value"
124+
ts = ts.to_frame(name=col_name)
125+
121126
former_index = ts.index.name
122127
if former_index != "datetime" and not overwrite_conventions:
123128
# warnings.warn("Index will be renamed datetime in file according to specification. Copy made")
@@ -139,10 +144,17 @@ def write_ts_csv(
139144
s = max(pd.Timestamp(bnd[0], 1, 1), ts.first_valid_index())
140145
e = min(pd.Timestamp(bnd[1], 12, 31, 23, 59, 59), ts.last_valid_index())
141146
tssub = ts.loc[s:e]
142-
if (
143-
tssub.count() < 16
144-
).all(): # require 15 values per column. all() is for multiple columns
145-
continue
147+
148+
count = tssub.count()
149+
if hasattr(count, "all"):
150+
# DataFrame path: all columns should have at least 16 values
151+
if not (count >= 16).all():
152+
continue
153+
else:
154+
# Series path: count is scalar
155+
if count < 16:
156+
continue
157+
146158
new_date_range_str = f"{bnd[0]}_{bnd[1]}"
147159

148160
if single_year_label:

tests/test_write_ts.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,28 @@ def test_write_ts_csv_stringio_with_metadata(sample_ts):
5959
assert "station_id: ABC" in contents
6060
assert "units: ft" in contents
6161
assert "value" in contents
62+
63+
64+
def test_write_ts_csv_stringio_series():
65+
s = pd.Series(
66+
[1.1, 2.2, 3.3, 4.4, 5.5],
67+
index=pd.date_range("2020-01-01", periods=5, freq="h", name="datetime"),
68+
name="value",
69+
)
70+
buf = io.StringIO()
71+
write_ts_csv(s, buf)
72+
73+
contents = buf.getvalue()
74+
assert "# format: dwr-dms-1.0" in contents
75+
assert "datetime" in contents
76+
assert "value" in contents
77+
78+
buf.seek(0)
79+
lines = [line for line in buf if not line.startswith("#")]
80+
roundtrip = pd.read_csv(
81+
io.StringIO("".join(lines)),
82+
index_col="datetime",
83+
parse_dates=True,
84+
)
85+
assert list(roundtrip.index) == list(s.index)
86+
assert list(roundtrip["value"]) == list(s.values)

0 commit comments

Comments (0)