Skip to content

Commit 301be59

Browse files
added savings the dataframes, so it's easier to do ad-hoc analysis if i need
1 parent 3936cc6 commit 301be59

1 file changed

Lines changed: 128 additions & 1 deletion

File tree

src/rachel_analysis_utils/nwb_utils.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import glob
55
import pandas as pd
66
import numpy as np
7+
from pathlib import Path
78

89
from aind_dynamic_foraging_data_utils import nwb_utils, enrich_dfs
910
from aind_dynamic_foraging_data_utils import code_ocean_utils as co_utils
@@ -116,7 +117,133 @@ def __str__(self):
116117
def __repr__(self):
117118
return f"{self.session_id}"
118119

119-
120+
def save(self, plot_loc, df_sess = None):
121+
"""
122+
Save dataframe attributes into:
123+
plot_loc / session_id / <attr>.parquet
124+
"""
125+
126+
session_folder = Path(plot_loc) / str(self.session_id)
127+
session_folder.mkdir(parents=True, exist_ok=True)
128+
129+
for attr, val in self.__dict__.items():
130+
if isinstance(val, pd.DataFrame):
131+
# print(f"now saving {attr}")
132+
133+
if attr == "df_events":
134+
val["data"] = val["data"].astype(str)
135+
136+
# df = self.convert_df_to_saveable_format(val)
137+
if attr == "df_trials":
138+
val["side_bias_confidence_interval_low"] = val["side_bias_confidence_interval"].apply(lambda x: x[0])
139+
val["side_bias_confidence_interval_high"] = val["side_bias_confidence_interval"].apply(lambda x: x[1])
140+
val = val.drop(columns=["side_bias_confidence_interval"])
141+
val.to_parquet(session_folder / f"{attr}.parquet", index=False, engine="fastparquet")
142+
143+
return session_folder
144+
145+
@classmethod
146+
def load(cls, session_folder):
147+
"""
148+
Load object from a saved session folder.
149+
"""
150+
151+
session_folder = Path(session_folder)
152+
153+
obj = cls.__new__(cls)
154+
obj.session_id = session_folder.name
155+
obj.nwb_file_loc = None
156+
157+
for file in session_folder.glob("*.parquet"):
158+
setattr(obj, file.stem, pd.read_parquet(file, engine="fastparquet"))
159+
160+
return obj
161+
162+
def save_nwb_list(nwb_list, plot_loc, df_sess=None):
163+
"""
164+
Save a list or list-of-lists of dummy_nwb objects.
165+
166+
Folder structure:
167+
plot_loc/
168+
<subject_id>/
169+
df_sess.parquet (optional)
170+
<session_id>/
171+
<attr>.parquet
172+
"""
173+
174+
# flatten list or list-of-lists
175+
flat_dummy_nwbs = [
176+
nwb
177+
for item in nwb_list
178+
for nwb in (item if isinstance(item, list) else [item])
179+
]
180+
181+
subject_ids = set()
182+
183+
for nwb in flat_dummy_nwbs:
184+
185+
subject_id = str(nwb.session_id).split("_")[0]
186+
subject_ids.add(subject_id)
187+
188+
subject_folder = Path(plot_loc) / subject_id
189+
subject_folder.mkdir(parents=True, exist_ok=True)
190+
191+
# call the class save method
192+
print(f'now saving {nwb.session_id}')
193+
nwb.save(subject_folder)
194+
195+
# save optional df_sess once per subject
196+
if df_sess is not None:
197+
print(f"now saving df_sess")
198+
199+
df_sess.to_csv(
200+
Path(plot_loc) / "df_sess.csv"
201+
)
202+
203+
204+
def load_nwb_list(plot_loc):
205+
"""
206+
Load dummy_nwb objects from:
207+
208+
plot_loc/
209+
df_sess.csv (optional)
210+
<subject_id>/
211+
<session_id>/
212+
df_events.parquet
213+
df_fip.parquet
214+
df_trials.parquet
215+
"""
216+
217+
plot_loc = Path(plot_loc)
218+
219+
nwbs = []
220+
df_sess = None
221+
222+
# load df_sess if present
223+
df_sess_file = plot_loc / "df_sess.csv"
224+
if df_sess_file.exists():
225+
print("loading df_sess")
226+
df_sess = pd.read_csv(df_sess_file)
227+
else:
228+
df_sess = None
229+
230+
# load sessions
231+
for subject_folder in sorted(plot_loc.iterdir()):
232+
233+
if not subject_folder.is_dir():
234+
continue
235+
236+
for session_folder in sorted(subject_folder.iterdir()):
237+
238+
if not session_folder.is_dir():
239+
continue
240+
241+
print(f"loading {session_folder.name}")
242+
243+
nwb = dummy_nwb.load(session_folder)
244+
nwbs.append(nwb)
245+
246+
return nwbs, df_sess
120247
def get_dummy_nwbs(df_trials, df_events, df_fip):
121248
ses_idx_list = df_trials.ses_idx.unique()
122249
dummy_nwbs_list = []

0 commit comments

Comments
 (0)