Skip to content

Commit 9cae3fe

Browse files
committed
Refactors embedding methods and adds MCCA
Refactors several embedding methods (PCA, GPFA, CCA, MCCA) for improved modularity and flexibility. Adds a new MCCA embedding method, implementing multi-area canonical correlation analysis, improving multi-region analysis capabilities. Updates argument validation for struct inputs in datareader.
1 parent 3c76d6d commit 9cae3fe

18 files changed

Lines changed: 362 additions & 180 deletions

File tree

+datareader/+is/Struct.m

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
fnames = fieldnames(Din);
8686
fnames = string(fnames);
8787
fields2check = ["data","time","condition","area"];
88-
fieldType = ["numeric","cell","string","string"];
88+
fieldType = ["numeric","double","string","string"];
8989

9090
Din_ = repmat(struct(),size(Din));
9191

@@ -151,11 +151,11 @@
151151
'Provided time and input data dimension mismatch. Please provide,for each trial, a time struct field matching data second dimension.');
152152
end
153153
else
154-
assert(iscell(opts.time),...
155-
'Struct input detected. Please provide a cell array of time vectors, one per trial.');
154+
% assert(iscell(opts.time),...
155+
% 'Struct input detected. Please provide a cell array of time vectors, one per trial.');
156156

157-
assert(length(opts.time{1}) == unique(Time),...
158-
'Provided time and input data dimension mismatch. Please provide a cell array of time vectors matching trial lenghts.');
157+
% assert(length(opts.time{1}) == unique(Time),...
158+
% 'Provided time and input data dimension mismatch. Please provide a cell array of time vectors matching trial lenghts.');
159159

160160
dishomogeneous = false;
161161
end

+embedding/+CCA/project.m

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
function E = project(D,C)
22

3-
E = cell(size(D));
4-
alldata = [D{:}];
3+
[nTrial,nArea] = size(D);
4+
E = cell(nTrial,nArea);
5+
alldata = arrayfun(@(Aidx)cat(2,D{:,Aidx}),1:size(D,2),'UniformOutput',false);
56

6-
sc = (alldata - mean(alldata,2))'*C;
7+
E_ = cellfun(@(d,c) c * d,alldata,C,'UniformOutput',false);
78

89
% For each condition, store the reduced version of each data vector
9-
index = 0;
10-
for ii = 1:length(D)
11-
E{ii} = sc(index + (1:size(D(ii).data,2)),:)';
12-
index = index + size(D(ii).data,2);
13-
end %ii
14-
end
10+
for jj = 1:nArea
11+
index = 0;
12+
for ii = 1:nTrial
13+
E{ii,jj} = E_{jj}(index + (1:pars.TrialL),1:dims)';
14+
index = index + pars.TrialL;
15+
end
16+
end

+embedding/+CCA/reduce.m

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function [E,C,Corr] = reduce(D,pars)
22
% CCAREDUCE Internal function for CCA
33

4-
% Agglomerate all of the conditions, and perform PCA
4+
% Agglomerate all of the conditions, and perform mCCA
55
E_ = cell(1,pars.nArea);
66
C = cell(size(D));
77
Corr = cell(size(D));
@@ -19,18 +19,19 @@
1919
[A,B,Corr,E_{1},E_{2}] = canoncorr(D{:});
2020
C = {A(:,1:dims),B(:,1:dims)};
2121
else
22-
data = cat(1,D{:})';
22+
data = arrayfun(@(Aidx)cat(2,D{:,Aidx}),1:size(D,2),'UniformOutput',false);
2323
d = cellfun(@(x)size(x,1),D);
24-
[~,Corr,C] = embedding.CCA.mcca(data,d);
24+
[~,Corr,C] = embedding.CCA.mcca(cat(1,data{:})',d);
25+
E_ = cellfun(@(d,c) c * d,data,C,'UniformOutput',false);
2526
end
2627

2728
% [U,V] = checkFlip(D{:},C{:},endLeg_range, interest_range);
2829

2930
% For each condition, store the reduced version of each data vector
3031
E = cell(pars.nTrial,pars.nArea);
31-
for ii = 1:pars.nTrial
32+
for jj = 1:pars.nArea
3233
index = 0;
33-
for jj = 1:pars.nArea
34+
for ii = 1:pars.nTrial
3435
E{ii,jj} = E_{jj}(index + (1:pars.TrialL),1:dims)';
3536
index = index + pars.TrialL;
3637
end

+embedding/+GPFA/project.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
alldata = [D{:}];
55
projMatrix = C{:};
66

7-
sc = alldata'*projMatrix;
7+
sc = projMatrix * alldata;
88

99
% For each condition, store the reduced version of each data vector
1010
index = 0;
1111
for ii = 1:length(D)
12-
E{ii} = sc(index + (1:size(D{ii},2)),:)';
12+
E{ii} = sc(:,index + (1:size(D{ii},2)));
1313
index = index + size(D{ii},2);
1414
end %ii
1515
end

+embedding/+GPFA/reduce.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,5 +100,5 @@
100100
[E{:}] = deal(AllSeq.xsm);
101101
[~, C{1}] = embedding.GPFA.util.orthogonalize([AllSeq.xsm], estParams.C);
102102
[~,lat] = pcacov(estParams.C * estParams.C');
103-
VarExp{1} = cumsum(lat(1:xDim))./sum(lat);
103+
VarExp{1} = cumsum(lat(xDim))./sum(lat);
104104
end

+embedding/+MCCA/loadParams.m

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,6 @@
55
% pars.endLeg_range = @(t)getNormRange(t,fraction);
66
% pars.interest_range = @(t)getInterestRange(t,fraction,alignment);
77
% pars.ccaRefSig = [];
8-
pars.mcca_k = 5;
8+
pars.mcca_k = 0.9;
99

1010
end
11-
12-
function T = getNormRange(t,fraction)
13-
14-
Tmax = length(t);
15-
T = false(size(t));
16-
T(1:Tmax/fraction/2) = true;
17-
T(end-Tmax/fraction/2:end) = true;
18-
T = t(T);
19-
end
20-
21-
function T = getInterestRange(t,fraction,alignment)
22-
23-
Tmax = length(t);
24-
T = false(size(t));
25-
Pre = alignment-Tmax/fraction/2;
26-
Post = alignment+Tmax/fraction/2;
27-
T(Pre:Post) = true;
28-
T = t(T);
29-
end

+embedding/+MCCA/project.m

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
function E = project(D,C)
% PROJECT Project multi-area trial data using per-area projection matrices.
%
%   E = project(D,C)
%
%   D : (nTrial x nArea) cell array. The trials of one area (one column of
%       D) are concatenated along dimension 2, so every cell in a column
%       must have the same number of rows.
%       NOTE(review): presumably each cell is (channels x time) - confirm.
%   C : cell array with one projection matrix per area. C{jj} is applied
%       as C{jj} * data, so it must have as many columns as the data of
%       area jj has rows.
%
%   E : (nTrial x nArea) cell array; E{ii,jj} holds the projection of
%       trial ii of area jj, with one column per sample of that trial.

[nTrial,nArea] = size(D);
E = cell(nTrial,nArea);

for jj = 1:nArea
    % Concatenate all trials of this area along dimension 2 and project
    % them in a single matrix product.
    areaData = cat(2,D{:,jj});
    projected = C{jj} * areaData;

    % Split the projection back into trials. Trial lengths are read from D
    % itself: this function receives no parameter struct, and reading them
    % per trial also supports trials of different lengths.
    index = 0;
    for ii = 1:nTrial
        nSamples = size(D{ii,jj},2);
        E{ii,jj} = projected(:,index + (1:nSamples));
        index = index + nSamples;
    end
end

+embedding/+MCCA/reduce.m

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
function [C,Corr] = reduce(D_all,pars) %#ok<INUSD>
% REDUCE Internal function for MCCA (multi-set canonical correlation analysis)
%
%   [C,Corr] = reduce(D_all,pars)
%
%   D_all : cell array with one element per set. Each element is itself a
%           cell array of matrices, which are concatenated along
%           dimension 2. The concatenated sets are then stacked along
%           dimension 1, so all sets must end up with the same number of
%           columns.
%           NOTE(review): presumably each matrix is (channels x time) for
%           one trial - confirm against the caller.
%   pars  : parameter struct. Not used by this function; kept so the
%           signature stays compatible with callers.
%
%   C     : third output of embedding.CCA.mcca.
%   Corr  : second output of embedding.CCA.mcca.

% Agglomerate all the sets, and perform mCCA
data = cellfun(@(eall)cat(2,eall{:}),D_all,'UniformOutput',false);

% Number of rows contributed by each set, passed to mcca together with the
% stacked data.
d = cellfun(@(x)size(x,1),data);
[~,Corr,C] = embedding.CCA.mcca(cat(1,data{:})',d);

end

+embedding/+PCA/project.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
alldata = [D{:}];
55
projMatrix = C{:};
66

7-
sc = (alldata - mean(alldata,2))'*projMatrix;
7+
sc = projMatrix * (alldata - mean(alldata,2));
88

99
% For each condition, store the reduced version of each data vector
1010
index = 0;
1111
for ii = 1:length(D)
12-
E{ii} = sc(index + (1:size(D{ii},2)),:)';
12+
E{ii} = sc(:,index + (1:size(D{ii},2)));
1313
index = index + size(D{ii},2);
1414
end %ii
1515
end

+embedding/+PCA/reduce.m

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
% For each condition, store the reduced version of each data vector
1717
index = 0;
1818
for i=1:length(D)
19-
D(i).data = sc(index + (1:size(D(i).data,2)),1:dims)';
19+
D(i).data = sc(index + (1:size(D(i).data,2)),:)';
2020
index = index + size(D(i).data,2);
2121
end
2222
[E{:}] = deal(D.data);
23-
C{1} = u(:,1:dims);
24-
VarExplained{1} = cumsum(lat) ./ sum(lat); % eigenvalues
23+
C{1} = u;
24+
VarExplained{1} = cumsum(lat(dims)) ./ sum(lat); % eigenvalues
2525

2626
end

0 commit comments

Comments
 (0)