-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path03.ppseq_shuf.py
More file actions
91 lines (80 loc) · 2.54 KB
/
03.ppseq_shuf.py
File metadata and controls
91 lines (80 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# %% import and definition
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import xarray as xr
from ppseq.model import PPSeq
from routine.ppseq import shuf_spks
# Input paths (not referenced by the code below in this script).
F_PATH = "./intermediate/concat/sig_master.nc"
SPK_PATH = "./intermediate/deconv/S.nc"

# Working directories: INT_PATH holds spks_ds.nc (read) and shuf_res.feat
# (written); FIG_PATH receives the output figure.
INT_PATH = "./intermediate/ppseq"
FIG_PATH = "./figs/ppseq"

# Experiment parameters.
NSHUF = 250  # shuffled surrogates per template duration (plus one unshuffled fit)
NITER = 30  # num_iter passed to every PPSeq.fit call
TEMP_DUR = [5, 10, 50, 100]  # template_duration values swept, one model each

# Make sure both output directories exist before anything is written.
os.makedirs(INT_PATH, exist_ok=True)
os.makedirs(FIG_PATH, exist_ok=True)
# %% ppseq
ds_spks = xr.load_dataset(os.path.join(INT_PATH, "spks_ds.nc")).rename(unit_id="cell")
spk = ds_spks["spks_ds"].dropna("frame", how="all")
# Keep only values above 2; everything else (including any remaining NaNs,
# which compare False) is set to 0.
spk = spk.where(spk > 2, other=0)
# NOTE(review): `device` is never handed to the model or to `spk_dat`, so the
# fits below run wherever PPSeq allocates by default — confirm this is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
shuf_res = []
for temp_dur in TEMP_DUR:
    # One single-template model per template duration.
    # NOTE(review): the same instance is reused for every shuffle below; this
    # assumes PPSeq.fit re-initializes its parameters on each call — verify.
    model = PPSeq(
        num_templates=1,
        num_neurons=int(spk.sizes["cell"]),
        template_duration=temp_dur,
        alpha_a0=1.5,
        beta_a0=0.2,
        alpha_b0=1,
        beta_b0=0.1,
        alpha_t0=1.2,
        beta_t0=0.1,
    )
    # ishuf == -1 fits the unshuffled data; 0..NSHUF-1 fit shuffled surrogates.
    for ishuf in range(-1, NSHUF):
        if ishuf >= 0:
            # assumes shuf_spks returns a new array and leaves spk.values
            # untouched, since spk.values is reused for every surrogate
            spk_sh = shuf_spks(spk.values)
        else:
            spk_sh = spk.values
        spk_dat = torch.from_numpy(spk_sh)
        # Reset the torch RNG so every fit starts from the same random state.
        torch.manual_seed(0)
        lps, amplitudes = model.fit(spk_dat, num_iter=NITER)
        # detach()/cpu() make the conversion safe if lps lives on a GPU or in
        # an autograd graph; a bare .numpy() raises in either case.
        lps_arr = lps.detach().cpu().numpy()
        shuf_res.append(
            pd.DataFrame(
                {
                    "temp_dur": temp_dur,
                    "ishuf": ishuf,
                    "lps": lps_arr,
                    # Index the returned trace itself instead of assuming it
                    # has exactly NITER entries.
                    "iiter": np.arange(len(lps_arr)),
                }
            )
        )
shuf_res = pd.concat(shuf_res, ignore_index=True)
shuf_res.to_feather(os.path.join(INT_PATH, "shuf_res.feat"))
# %% plotting
def sel_iter(df):
    """Select the iteration with the highest log-probability from one fit.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows belonging to a single fit; must contain the columns ``"lps"``
        and ``"iiter"``.

    Returns
    -------
    pandas.DataFrame
        The best-scoring row, indexed by ``"iiter"``. If several iterations
        tie exactly for the maximum, only the latest one is kept, so callers
        that call ``.item()`` on the result get a single value. The result is
        empty (zero rows) when ``"lps"`` is empty or all-NaN.
    """
    lps = df["lps"]
    if lps.isna().all():
        # No valid score to pick: return an empty frame with the same layout.
        return df.iloc[0:0].set_index("iiter")
    # idxmax returns the FIRST maximal label (NaNs skipped), so scan in reverse
    # to prefer the latest iteration among exact ties.
    best = lps.iloc[::-1].idxmax()
    return df.loc[[best]].set_index("iiter")
def plot_shuf(data, x="lps", color=None):
    """Histogram the shuffled values of ``x`` and mark the observed value.

    Draws on the current matplotlib axes. ``color`` is accepted but not used
    (presumably to absorb the ``color`` keyword seaborn passes to functions
    mapped with ``FacetGrid.map_dataframe`` — confirm).

    Rows with ``ishuf == -1`` are the observed (unshuffled) fit and must be
    exactly one row; rows with ``ishuf >= 0`` are the shuffled surrogates.
    The annotated p-value is the fraction of surrogates whose ``x`` strictly
    exceeds the observed value.
    NOTE(review): no ``+1`` correction and strict ``>`` are used, so the
    reported p-value can be exactly 0 — confirm this convention is intended.
    """
    axes = plt.gca()
    observed_rows = data[data["ishuf"] == -1]
    null_rows = data[data["ishuf"] >= 0]
    observed = observed_rows[x].item()
    n_exceeding = (null_rows[x] > observed).sum()
    pval = n_exceeding / len(null_rows)
    sns.histplot(null_rows, x=x, ax=axes)
    axes.axvline(observed)
    label = "pval: {:.3f}".format(pval)
    # (1, 1) in axes-fraction coordinates, i.e. the top-right corner.
    axes.text(1, 1, label, transform=axes.transAxes)
shuf_res = pd.read_feather(os.path.join(INT_PATH, "shuf_res.feat"))
# Reduce every (temp_dur, ishuf) fit to its best-scoring iteration.
res_agg = (
    shuf_res.groupby(["temp_dur", "ishuf"])
    .apply(sel_iter, include_groups=False)
    .reset_index()
)
# Sanity check: every fit must reach its best log-probability at the final
# iteration (NITER - 1). Raise explicitly rather than `assert` (stripped under
# `python -O`), and name the offending fits instead of failing with an opaque
# `.item()` error when their best iterations differ.
if res_agg.empty:
    raise RuntimeError("no fits found in shuf_res.feat")
not_last = res_agg[res_agg["iiter"] != NITER - 1]
if not not_last.empty:
    raise RuntimeError(
        f"best iteration is not the final one (iiter={NITER - 1}) for fits:\n"
        f"{not_last[['temp_dur', 'ishuf', 'iiter']].to_string(index=False)}"
    )
g = sns.FacetGrid(res_agg, col="temp_dur", sharex=False)
g.map_dataframe(plot_shuf)
g.figure.savefig(os.path.join(FIG_PATH, "shuf_res.svg"))