Skip to content

Commit 23dea6b

Browse files
committed
Improves CCA handling and diagnostics
Handles variable trial lengths when assembling per-trial embeddings and returns variance-explained as a cell so multi-output CCA metadata is preserved. Computes and stores canonical correlations and projection matrix pseudoinverses for downstream use, and exposes a getter to access canonical correlations in the embedding object. Updates 3D trajectory plotting to iterate per-area embeddings and select nearest trajectories correctly. Fixes d' computation to use absolute difference between signal and baseline means. Adds default DPrime parameter indices for baseline and signal windows. These changes make multi-area CCA more robust, preserve useful diagnostic outputs, and correct metric behavior.
1 parent 71a5559 commit 23dea6b

6 files changed

Lines changed: 61 additions & 30 deletions

File tree

+embedding/+CCA/reduce.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
for jj = 1:pars.nArea
3333
index = 0;
3434
for ii = 1:pars.nTrial
35-
E{ii,jj} = E_{jj}(index + (1:pars.TrialL),1:dims)';
36-
index = index + pars.TrialL;
35+
E{ii,jj} = E_{jj}(index + (1:pars.TrialL{ii}),1:dims)';
36+
index = index + pars.TrialL{ii};
3737
end
3838
end
3939

+embedding/CCA.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function [E,ProjMatrix,VarExplained] = CCA(D,pars)
22
D_ = arrayfun(@(aidx)cat(2,D{:,aidx})',1:size(D,2),'UniformOutput',false);
33

4-
[E,ProjMatrix,VarExplained] = embedding.CCA.reduce(D_,pars);
4+
[E,ProjMatrix,VarExplained{1}] = embedding.CCA.reduce(D_,pars);
55
end

+metrics/+compute/dprime.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
% Pooled standard deviation: sqrt( (sigma_b^2 + sigma_s^2) / 2 )
4242
pooled_std = sqrt(0.5 .* (std_b.^2 + std_s.^2));
4343

44-
dp = (mu_s - mu_b) ./ pooled_std;
44+
dp = abs(mu_s - mu_b) ./ pooled_std;
4545

4646
% Return as 1 x nDims row vector
4747
dp = dp';

+metrics/+pars/DPrime.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function pars = DPrime()
22

3-
pars.baseline_idx = [];
4-
pars.signal_idx = [];
3+
pars.baseline_idx = [1:50 351:400];
4+
pars.signal_idx = 175:225;
55

66

77
end

@NeuralEmbedding/NeuralEmbedding.m

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
PostKern
3131
BinWidth
3232

33-
VarExplained % (Double) Variance explained by each embedded dimension
33+
VarExplained %(Double) Variance explained by each embedded dimension
34+
CanonCorr %(Cell) Canonical correlation (where relevant)
3435
NumPC
3536

3637
end
@@ -115,6 +116,7 @@
115116
homogeneous logical = false % flag for trial homogenuity. If all trials all equally long, this is 0
116117

117118
VarExplained_ double % cell storing variance explained values
119+
CanonCorrelation_ cell % cell storing canonical correlation iof applicable
118120
numPC double = 6
119121
end
120122

@@ -606,7 +608,26 @@ function addEvents(obj,evts)
606608
end
607609
value = obj.VarExplained_(amask);
608610
end
609-
611+
function value = get.CanonCorr(obj)
612+
amask = ismember(obj.UArea,obj.aMask_);
613+
if isempty(obj.CanonCorrelation_) || ...
614+
length(obj.CanonCorrelation_) < find(amask,1,'last')
615+
% TODO, not sure if makes sense
616+
end
617+
nMask = numel(obj.aMask_);
618+
value_tmp = cell(1,(nMask-1)*(nMask-2)/2);
619+
str = repmat("",(nMask-1)*(nMask-2)/2);
620+
idx = 1;
621+
for ii = 1:nMask-1
622+
for jj = ii+1:nMask
623+
str(idx) = obj.aMask_(ii)+"_"+obj.aMask_(jj);
624+
value_tmp{idx} = obj.CanonCorrelation_(ii);
625+
idx = idx + 1;
626+
end
627+
end
628+
value = cell2struct(value_tmp,str);
629+
end
630+
610631
function value = get.NumPC(obj)
611632
value = obj.numPC;
612633
end
@@ -861,6 +882,10 @@ function zscoreData(obj)
861882
'condition',obj.cMask_,...
862883
'data',data,...
863884
'Area',obj.aMask_);
885+
886+
if strcmp(type,'DPrime')
887+
str.Area = strjoin(str.Area,"");
888+
end
864889
end
865890

866891
end
@@ -886,7 +911,7 @@ function plot3(obj,maxT)
886911
arrayfun(@(o)o.plot3(maxT),obj);
887912
return;
888913
end
889-
reducedE = cellfun(@(x)[x(1:3,:) nan(3,1)], ...
914+
reducedE_ = cellfun(@(x)[x(1:3,:) nan(3,1)], ...
890915
obj.E, ...
891916
'UniformOutput',false);
892917
t = cellfun(@(t)[t(:)' nan], ...
@@ -895,20 +920,22 @@ function plot3(obj,maxT)
895920
nT = sum(obj.cMask);
896921
MaxLines = min(nT,maxT);
897922
% idx = randperm(nT,MaxLines);
898-
idx = findClosestN(reducedE,MaxLines);
899-
reducedE = [reducedE{idx}];
923+
idx = findClosestN(reducedE_,MaxLines);
900924
t = [t{idx}];
901-
figure;
902-
surface([reducedE(1,:);reducedE(1,:)], ...
903-
[reducedE(2,:);reducedE(2,:)], ...
904-
[reducedE(3,:);reducedE(3,:)], ...
905-
[t;t], ...
906-
'facecol','no',...
907-
'edgecol','interp',...
908-
'linew',1)
909-
title(obj.Animal + " " +obj.Session)
910-
xlabel('Dimension 1');ylabel('Dimension 2');zlabel('Dimension 3');
911-
colorbar
925+
for aa = 1:size(reducedE_,2)
926+
reducedE = [reducedE_{idx,aa}];
927+
figure;
928+
surface([reducedE(1,:);reducedE(1,:)], ...
929+
[reducedE(2,:);reducedE(2,:)], ...
930+
[reducedE(3,:);reducedE(3,:)], ...
931+
[t;t], ...
932+
'facecol','no',...
933+
'edgecol','interp',...
934+
'linew',1)
935+
title(obj.Animal + " " +obj.Session + " " + obj.aMask_(aa))
936+
xlabel('Dimension 1');ylabel('Dimension 2');zlabel('Dimension 3');
937+
colorbar
938+
end
912939

913940
function idx = findClosestN(traj,N)
914941
m = median(cat(3,traj{:}),3);

@NeuralEmbedding/findEmbedding.m

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
projectOnly = false;
1717
end
1818

19-
2019
% switch through the different algorithms
2120
switch deblank(type)
2221
case {'SmoothPCA','PCA','pca'}
@@ -66,20 +65,19 @@
6665
try
6766
% Initialize the data for CCA
6867
D = cell(obj.nTrial,obj.nArea);
69-
a = 1;
70-
for aa = obj.UArea(1:end-1)
68+
for aa = 1:obj.nArea
7169
% Set the area mask to the current area
72-
obj.aMask = aa;
70+
obj.aMask = obj.UArea(aa);
7371
% Get the data for the current area
74-
D(:,a) = obj.S;
72+
D(:,aa) = obj.S;
7573
% Increment the area counter
76-
a = a + 1;
7774
end
7875
% Set the area mask to all areas
7976
obj.aMask = obj.UArea(1:end-1);
8077
% Compute the CCA
81-
[E,W,VarExplained] = ...
78+
[E,W,CanonCorrelation] = ...
8279
embedding.CCA(D,pars);
80+
Winv = cellfun(@pinv,W,'UniformOutput',false);
8381
catch er
8482
% If the algorithm fails, set flag to false and rethrow the error
8583
flag = false;
@@ -177,6 +175,12 @@
177175
E_s,'UniformOutput',false);
178176
obj.W_(amask) = W;
179177
obj.Winv_(amask) = Winv;
180-
obj.VarExplained_(amask) = VarExplained{1};
178+
if exist("VarExplained","var")
179+
obj.VarExplained_(amask) = VarExplained{:};
180+
end
181+
if exist("CanonCorrelation","var")
182+
obj.CanonCorrelation_(1:obj.nArea-1) = CanonCorrelation;
183+
end
184+
181185
end
182186

0 commit comments

Comments
 (0)