Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ To compare multiple channels to the same event type:
pf.plot_fip_psth(nwb, 'goCue_start_time')
```

### Hierarchical Bootstrapping and significance testing
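To use hierarchical bootstrapping for the error bars, select `hb_sem` as the error type. The plotting function then also returns `etrs`, which holds the PSTH traces together with the bootstraps and a statistics dataframe: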
```
fig, ax, etrs = plot_fip_psth_compare_alignments(
nwbs,
alignments,
channel_name,
error_type='hb_sem',
hierarchical_params={'nboots':10000}
)
```
Then add significance with:
```
fip_psth_stats_plot(ax, etrs['stats'], threshold=0.05)
```

## Contributing

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ dependencies = [
'numpy',
'pydantic',
'hdmf_zarr',
'aind-dynamic-foraging-data-utils >= 0.1.25'
'aind-dynamic-foraging-data-utils >= 0.1.25',
'aind_hierarchical_bootstrap'
]

[project.optional-dependencies]
Expand Down
208 changes: 197 additions & 11 deletions src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
Tools for plotting FIP data
"""

from functools import partial
from multiprocessing import Pool

import aind_hierarchical_bootstrap.bootstrap as hb
import aind_hierarchical_bootstrap.stats as hb_stats
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from aind_dynamic_foraging_data_utils import alignment as an
from aind_dynamic_foraging_data_utils import nwb_utils as nu

import aind_dynamic_foraging_basic_analysis.plot.style as style
from aind_dynamic_foraging_basic_analysis.plot.style import FIP_COLORS, STYLE


Expand All @@ -22,9 +28,12 @@ def plot_fip_psth_compare_alignments( # NOQA C901
extra_colors={},
data_column="data",
error_type="sem",
hierarchical_params={},
):
"""
Compare the same FIP channel aligned to multiple event types

ARGS
nwb, nwb object for the session, or a list of nwbs
alignments, with one session alignments can be either a list of
event types in df_events, or a dictionary whose keys are
Expand All @@ -38,15 +47,22 @@ def plot_fip_psth_compare_alignments( # NOQA C901
extra_colors (dict), a dictionary of extra colors.
keys should be alignments, or colors are random
data_column (string), name of data column in nwb.df_fip
error_type, (string), either "sem" or "sem_over_sessions" to define
error_type, (string), either "sem", "hb_sem", or "sem_over_sessions" to define
the error bar for the PSTH
hierarchical_params, (dict), parameters to be passed to compute_hierarchical_error()

RETURNS
fig - matplotlib figure
ax - matplotlib axis
etrs - dictionary containing PSTH traces for each alignment. If hierarchical
bootstrapping is performed, then also contains the bootstraps and statistics dataframe

EXAMPLE
*******************
plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time',
'right_reward_delivery_time'],'G_1_preprocessed')
"""
if error_type not in ["sem", "sem_over_sessions"]:
if error_type not in ["sem", "sem_over_sessions", "hb_sem"]:
raise Exception("unknown error type")

nwb_list = nwb if isinstance(nwb, list) else [nwb]
Expand Down Expand Up @@ -132,12 +148,37 @@ def plot_fip_psth_compare_alignments( # NOQA C901
colors = {**FIP_COLORS, **extra_colors}

align_label = "Time (s)"
etrs = {}
bootstraps = {}
for alignment in align_list[0]:
this_align = [x[alignment] for x in align_list]
etr = fip_psth_multiple_inner_compute(
nwb_list, this_align, channel, True, tw, censor, censor_times_list, data_column
etr, bootstrap = fip_psth_multiple_inner_compute(
nwb_list,
this_align,
channel,
True,
tw,
censor,
censor_times_list,
data_column,
compute_hierarchical=error_type == "hb_sem",
hierarchical_params=hierarchical_params,
)
fip_psth_inner_plot(ax, etr, colors.get(alignment, ""), alignment, data_column, error_type)
etrs[alignment] = etr
if error_type == "hb_sem":
for b in bootstrap:
b[alignment] = b[data_column]
b["{}_sem".format(alignment)] = b["{}_sem".format(data_column)]
b["groups"] = [alignment]
del b[data_column]
del b["{}_sem".format(data_column)]
bootstraps[alignment] = bootstrap

if error_type == "hb_sem":
bootstraps, stats_df = aggregate_bootstrap_statistics(bootstraps)
etrs["stats"] = stats_df
etrs["bootstraps"] = bootstraps

plt.legend()
ax.set_xlabel(align_label, fontsize=STYLE["axis_fontsize"])
Expand All @@ -159,7 +200,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901
else:
ax.set_title("{} sessions".format(len(nwb_list)), fontsize=STYLE["axis_fontsize"])
plt.tight_layout()
return fig, ax
return fig, ax, etrs


def plot_fip_psth_compare_channels( # NOQA C901
Expand All @@ -179,17 +220,26 @@ def plot_fip_psth_compare_channels( # NOQA C901
censor=True,
data_column="data",
error_type="sem",
hierarchical_params={},
):
"""
nwb, the nwb object for the session of interest, or a list of nwb objects
ARGS
nwb, the nwb object, etrs for the session of interest, or a list of nwb objects
align should either be a string of the name of an event type in nwb.df_events,
or a list of timepoints. if nwb is a list, then align should be a list containing
lists of timepoints for each session.
channels should be a list of channel names (strings)
censor, censor important timepoints before and after aligned timepoints
data_column (string), name of data column in nwb.df_fip
error_type, (string), either "sem" or "sem_over_sessions" to define
error_type, (string), either "sem", "hb_sem", or "sem_over_sessions" to define
the error bar for the PSTH
hierarchical_params, (dict), parameters to be passed to compute_hierarchical_error()

RETURNS
fig - matplotlib figure
ax - matplotlib axis
etrs - dictionary containing PSTH traces for each alignment. If hierarchical
bootstrapping is performed, then also contains the bootstraps and statistics dataframe

EXAMPLE
********************
Expand All @@ -198,7 +248,7 @@ def plot_fip_psth_compare_channels( # NOQA C901
plot_fip_psth(nwb_list, [session_1_timepoints, session_2_timepoints, ... ])
"""

if error_type not in ["sem", "sem_over_sessions"]:
if error_type not in ["sem", "sem_over_sessions", "hb_sem"]:
raise Exception("Unknown error type")

# Check if nwb is a list, otherwise put it in a list to check
Expand Down Expand Up @@ -258,18 +308,36 @@ def plot_fip_psth_compare_channels( # NOQA C901

# Iterate through channels and plot
colors = [FIP_COLORS.get(c, "") for c in channels]
etrs = {}
bootstraps = {}
for dex, c in enumerate(channels):
include = [c in nwb.df_fip["event"].values for nwb in nwb_list]
etr = fip_psth_multiple_inner_compute(
etr, bootstrap = fip_psth_multiple_inner_compute(
[x for dex, x in enumerate(nwb_list) if include[dex]],
[x for dex, x in enumerate(align_timepoints_list) if include[dex]],
c,
True,
tw,
censor,
data_column=data_column,
compute_hierarchical=error_type == "hb_sem",
hierarchical_params=hierarchical_params,
)
fip_psth_inner_plot(ax, etr, colors[dex], c, data_column, error_type)
etrs[c] = etr
if error_type == "hb_sem":
for b in bootstrap:
b[c] = b[data_column]
b["{}_sem".format(c)] = b["{}_sem".format(data_column)]
b["groups"] = [c]
del b[data_column]
del b["{}_sem".format(data_column)]
bootstraps[c] = bootstrap

if error_type == "hb_sem":
bootstraps, stats_df = aggregate_bootstrap_statistics(bootstraps)
etrs["stats"] = stats_df
etrs["bootstraps"] = bootstraps

plt.legend()
ax.set_xlabel(align_label, fontsize=STYLE["axis_fontsize"])
Expand All @@ -290,7 +358,37 @@ def plot_fip_psth_compare_channels( # NOQA C901
else:
ax.set_title("{} sessions".format(len(nwb_list)))
plt.tight_layout()
return fig, ax
return fig, ax, etrs


def fip_psth_stats_plot(ax, stats_df, threshold=0.05, color=None):
    """
    Plots markers where a significant threshold is reached

    Does not perform multiple comparisons testing.

    If multiple unique tests are in the stats_df, each test
    is plotted separately.

    ARGS
    ax - axis to plot on
    stats_df - dataframe of stats results. Must contain a "name" column
        (identifier of the test) and a "p" column (p-value). The index of
        the dataframe is used as the x position of each marker. This is the
        dataframe stored under etrs["stats"] by the plot_fip_psth_compare_*
        functions when error_type="hb_sem", where the index is the PSTH time.
    threshold - significance level, rows with p < threshold get a marker
    color - optional matplotlib color used for the markers of every test.
        If None (default), each test gets its own color from
        style.get_colors()

    RETURNS
    None, markers and the legend are added to ax in place
    """
    unique_tests = stats_df["name"].unique()
    if color is None:
        colors = style.get_colors(list(unique_tests), offset=0.25, cmap_name="plasma")
    else:
        # A single user-supplied color is shared by all tests
        colors = {test: color for test in unique_tests}

    for test in unique_tests:
        is_significant = (stats_df["name"] == test) & (stats_df["p"] < threshold)
        significant = stats_df[is_significant]
        if len(significant) > 0:
            # Markers sit at the lower y-limit of the axis at the time of plotting
            ax.plot(
                significant.index,
                [ax.get_ylim()[0]] * len(significant),
                "o",
                color=colors[test],
                label="{}: p < {}".format(test, threshold),
            )

    # Draw the legend on the axis that was passed in. plt.legend() would act
    # on pyplot's current axis, which is not necessarily `ax`.
    ax.legend()


def fip_psth_inner_plot(ax, etr, color, label, data_column, error_type="sem"):
Expand Down Expand Up @@ -325,6 +423,8 @@ def fip_psth_multiple_inner_compute(
censor=True,
censor_times=None,
data_column="data",
compute_hierarchical=False,
hierarchical_params={},
):
"""
Wrapper function for fip_psth_inner_compute that takes a list of NWB files
Expand Down Expand Up @@ -373,7 +473,14 @@ def fip_psth_multiple_inner_compute(
# Compute SEM collapsing over sessions
result["sem"] = etr_all.groupby("time")[data_column].sem()

return result
if compute_hierarchical:
result, bootstraps = compute_hierarchical_error(
result, etr_all, data_column=data_column, **hierarchical_params
)
else:
bootstraps = None

return result, bootstraps
else:
return etr_all

Expand Down Expand Up @@ -422,6 +529,85 @@ def fip_psth_inner_compute(
return etr


def compute_hierarchical_error(
    result, etr_all, levels=["ses_idx"], nboots=10000, data_column="data"
):
    """
    Computes hierarchical bootstraps at each timepoint.

    ARGS
    result, summary dataframe, its index values are the timepoints to bootstrap.
        A new column "hb_sem" is added to it in place
    etr_all, dataframe with all datapoints, queried on "time" for each
        timepoint in result.index
    levels, list of hierarchy to bootstrap. Slows considerably with each level
    nboots, number of bootstraps
    data_column, column to use in etr_all

    RETURNS
    result, the same summary dataframe with the added "hb_sem" column
    bootstraps, list with one bootstrap output per timepoint (same order as
        result.index), each with an added "time" entry

    Computing hierarchical bootstraps is slow. Consider using
    error_type='sem_over_sessions' until analyses are finalized.
    You may also speed up this computation by adding more CPUs,
    as this function is optimized for multiprocessing. Additionally,
    you can reduce 'nboots' for faster processing.

    See aind_hierarchical_bootstrap for more information
    """

    # Set up partial function that wraps other parameters.
    # levels is copied so the shared default list is never handed out directly.
    temp_func = partial(hb.bootstrap, metric=data_column, levels=list(levels), nboots=nboots)

    # Split dataframe by timepoint
    timepoints = result.index.values
    dfs = []
    for num in timepoints:
        dfs.append(etr_all.query("time == @num"))

    # Run multiprocess pool, one task per timepoint
    with Pool() as pool:
        bootstraps = pool.map(temp_func, dfs)

    # Organize results
    for val, boot in zip(timepoints, bootstraps):
        boot["time"] = val

    # The SEM key follows the bootstrapped metric ("<data_column>_sem"), which is
    # also how the plotting functions in this module read the bootstraps.
    # Hard-coding "data_sem" broke any data_column other than "data".
    sem_key = "{}_sem".format(data_column)
    result["hb_sem"] = [boot[sem_key] for boot in bootstraps]

    return result, bootstraps


def aggregate_bootstrap_statistics(bootstraps):
    """
    Computes statistics on bootstrap results across groups (alignments or channels)

    ARGS
    bootstraps - a dictionary of lists. The keys are the groups to compare,
        either alignments or channels. The lists are the timepoints of the PSTH.
        Each list element is a dictionary with the entries <group>,
        <group>_sem, and "time".
        The lengths of the lists must be the same for all groups.

    RETURNS
    combined_dicts - a list of dictionaries, one for each timepoint of the PSTH.
        The dictionary for each timepoint is the merged dictionaries from each group
    stats_df - a dataframe with statistics results, from aind_hierarchical_bootstrap,
        indexed by "time"

    RAISES
    ValueError if bootstraps has no groups, or the groups have no timepoints
    Exception if the groups have different numbers of timepoints
    """
    # An empty input would otherwise fail below with an opaque
    # "pop from an empty set" error
    if len(bootstraps) == 0:
        raise ValueError("No bootstraps to aggregate")

    # check that the lists are always the same length
    lens = {len(group_boots) for group_boots in bootstraps.values()}
    if len(lens) > 1:
        raise Exception("Event triggered responses for each alignment are different lengths")
    num_timepoints = lens.pop()

    # Without timepoints there is nothing to concatenate into a stats dataframe
    if num_timepoints == 0:
        raise ValueError("Bootstraps contain no timepoints to aggregate")

    # Merge the per-group dictionaries at each timepoint
    combined_dicts = []
    for i in range(num_timepoints):
        temp = {}
        for key, group_boots in bootstraps.items():
            this_boot = group_boots[i]
            sem_key = "{}_sem".format(key)
            temp[key] = this_boot[key]
            temp[sem_key] = this_boot[sem_key]
            temp["groups"] = list(bootstraps.keys())
            temp["time"] = this_boot["time"]
        combined_dicts.append(temp)

    # now compute statistics for each timepoint
    stats_dfs = []
    for d in combined_dicts:
        stats_df = hb_stats.compute_stats(d)
        stats_df["time"] = d["time"]
        stats_dfs.append(stats_df)
    stats_df = pd.concat(stats_dfs).set_index("time", drop=True)

    return combined_dicts, stats_df


def plot_histogram(nwb, preprocessed=True, edge_percentile=2, data_column="data"):
"""
Generates a histogram of values of each FIP channel
Expand Down