Skip to content

Commit 1c4af72

Browse files
committed
Merge branch 'MCCA'
2 parents 9bf5260 + 9cae3fe commit 1c4af72

22 files changed

Lines changed: 519 additions & 96 deletions

File tree

+datareader/+is/Struct.m

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
fnames = fieldnames(Din);
8686
fnames = string(fnames);
8787
fields2check = ["data","time","condition","area"];
88-
fieldType = ["numeric","cell","string","string"];
88+
fieldType = ["numeric","double","string","string"];
8989

9090
Din_ = repmat(struct(),size(Din));
9191

@@ -151,11 +151,11 @@
151151
'Provided time and input data dimension mismatch. Please provide,for each trial, a time struct field matching data second dimension.');
152152
end
153153
else
154-
assert(iscell(opts.time),...
155-
'Struct input detected. Please provide a cell array of time vectors, one per trial.');
154+
% assert(iscell(opts.time),...
155+
% 'Struct input detected. Please provide a cell array of time vectors, one per trial.');
156156

157-
assert(length(opts.time{1}) == unique(Time),...
158-
'Provided time and input data dimension mismatch. Please provide a cell array of time vectors matching trial lenghts.');
157+
% assert(length(opts.time{1}) == unique(Time),...
158+
% 'Provided time and input data dimension mismatch. Please provide a cell array of time vectors matching trial lenghts.');
159159

160160
dishomogeneous = false;
161161
end

+embedding/+CCA/mcca.m

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
function [V,rho,A,rhotest]=mcca(X,d,Xtest,k)
% [V,rho,A,rhotest]=mcca(X,d,Xtest,k) Multiset Canonical Correlation
% Analysis. X is the data arranged as samples by dimension, whereby all
% sets are concatenated along the dimensions. d is a vector with the
% dimensions of each set. V are the component vectors and rho the resulting
% inter-set correlations. A are the corresponding forward models, which
% are returned as a list of length N. If Xtest is given, it will also
% compute rho for the test data with the optimal V. If k is given then the
% within-set correlation will be reduced in dimension from d to k prior to
% inversion using PCA. This is useful for rank deficient data or for
% regularization. If k is not given, dimension is reduced to the rank of
% the data prior to inversion.
%
% Inputs:
%   X      samples-by-sum(d) data matrix (all sets concatenated column-wise)
%   d      vector with the number of columns belonging to each set
%   Xtest  (optional) test data laid out like X; may be omitted or empty
%   k      (optional) per-set dimension kept before inversion; a scalar is
%          applied to every set; omitted or empty means k = d
%
% Outputs:
%   V        component vectors, columns sorted by descending rho
%   rho      inter-set correlations, sorted in descending order
%   A        1-by-N cell array of forward models (computed only on request)
%   rhotest  rho evaluated on Xtest with the trained V; [] if no Xtest
%
% See https://arxiv.org/abs/1802.03759, https://arxiv.org/abs/1801.08881

% Apr 30, 2018, Lucas Parra (c)
% Sep 11, 2018, removed hack for forward model computation that broke the code sometimes
% Sep 14, 2018, make forward model robust to ill conditioned data
% Sep 15, 2018, keep simpler code in case that there is no regularization or rank problem

% Resolve optional arguments with nargin instead of exist(): exist('Xtest')
% without the 'var' qualifier also matches files/functions on the path.
if nargin<3, Xtest=[]; end
if nargin<4 || isempty(k), k=d; end
if isscalar(k) && ~isscalar(d), k=repmat(k,size(d)); end % same k for every set

N=length(d);
R=cov(X);
if size(R,1)~=sum(d)
    error('mcca:dimensionMismatch', ...
        'Covariance of X is %d-by-%d but sum(d) is %d.', ...
        size(R,1),size(R,2),sum(d));
end

% Column range of set i inside the concatenated data
offsets=[0 cumsum(d(:)')];
blk=@(i) offsets(i)+(1:d(i));

% D holds the within-set (block diagonal) part of R
D=zeros(size(R));
for i=1:N
    j=blk(i);
    D(j,j)=R(j,j);
    k(i)=min(k(i),rank(D(j,j))); % check rank for oblivious users
end

if sum(d)==sum(k) % simple case
    [V,lambda]=eig(R,D);
else % if rank deficient, or if regularization requested
    Dinv=zeros(size(R));
    for i=1:N
        j=blk(i);
        Dinv(j,j)=embedding.CCA.regInv(D(j,j),k(i));
    end
    [V,lambda]=eigs(Dinv*R,sum(k));
end
rho = (diag(lambda)-1)/(N-1);
[~,indx]=sort(rho,'descend'); rho=rho(indx); V=V(:,indx);

% compute forward models
if nargout>2
    A=cell(1,N);
    for i=1:N
        j=blk(i);
        W=V(j,1:k(i)); Rw=R(j,j);
        if k(i)==d(i)
            A{i}=Rw*W/(W'*Rw*W); % original formula, but won't work for rank deficient Rw
        else
            A{i}=Rw*W*diag(1./diag(W'*Rw*W)); % ignores correlation of components but robust to ill conditioned Rw
        end
    end
end

% compute rho for test data; rhotest stays empty when no test data is given
% so that a four-output call does not fail with an unassigned output
rhotest=[];
if ~isempty(Xtest)
    Rtest=cov(Xtest);
    Dtest=zeros(size(Rtest));
    for i=1:N
        j=blk(i);
        Dtest(j,j)=Rtest(j,j);
    end
    lambda = diag(V'*Rtest*V)./diag(V'*Dtest*V);
    rhotest = (lambda-1)/(N-1);
end
54+
55+

+embedding/+CCA/project.m

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
function E = project(D,C)
22

3-
E = cell(size(D));
4-
alldata = [D{:}];
3+
[nTrial,nArea] = size(D);
4+
E = cell(nTrial,nArea);
5+
alldata = arrayfun(@(Aidx)cat(2,D{:,Aidx}),1:size(D,2),'UniformOutput',false);
56

6-
sc = (alldata - mean(alldata,2))'*C;
7+
E_ = cellfun(@(d,c) c * d,alldata,C,'UniformOutput',false);
78

89
% For each condition, store the reduced version of each data vector
9-
index = 0;
10-
for ii = 1:length(D)
11-
E{ii} = sc(index + (1:size(D(ii).data,2)),:)';
12-
index = index + size(D(ii).data,2);
13-
end %ii
14-
end
10+
for jj = 1:nArea
11+
index = 0;
12+
for ii = 1:nTrial
13+
E{ii,jj} = E_{jj}(index + (1:pars.TrialL),1:dims)';
14+
index = index + pars.TrialL;
15+
end
16+
end

+embedding/+CCA/reduce.m

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function [E,C,Corr] = reduce(D,pars)
22
% CCAREDUCE Internal function for CCA
33

4-
% Agglomerate all of the conditions, and perform PCA
4+
% Agglomerate all of the conditions, and perform mCCA
55
E_ = cell(1,pars.nArea);
66
C = cell(size(D));
77
Corr = cell(size(D));
@@ -19,18 +19,19 @@
1919
[A,B,Corr,E_{1},E_{2}] = canoncorr(D{:});
2020
C = {A(:,1:dims),B(:,1:dims)};
2121
else
22-
data = cat(1,D{:})';
22+
data = arrayfun(@(Aidx)cat(2,D{:,Aidx}),1:size(D,2),'UniformOutput',false);
2323
d = cellfun(@(x)size(x,1),D);
24-
[~,Corr,C] = embedding.CCA.mcca(data,d);
24+
[~,Corr,C] = embedding.CCA.mcca(cat(1,data{:})',d);
25+
E_ = cellfun(@(d,c) c * d,data,C,'UniformOutput',false);
2526
end
2627

2728
% [U,V] = checkFlip(D{:},C{:},endLeg_range, interest_range);
2829

2930
% For each condition, store the reduced version of each data vector
3031
E = cell(pars.nTrial,pars.nArea);
31-
for ii = 1:pars.nTrial
32+
for jj = 1:pars.nArea
3233
index = 0;
33-
for jj = 1:pars.nArea
34+
for ii = 1:pars.nTrial
3435
E{ii,jj} = E_{jj}(index + (1:pars.TrialL),1:dims)';
3536
index = index + pars.TrialL;
3637
end

+embedding/+CCA/regInv.m

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
function invR = regInv(R, K)
%invR = regInv(R, K)
% Truncated-SVD ("PCA regularized") inverse of a square matrix R.
% Only the first K singular triplets returned by svd are used to build the
% inverse. K defaults to size(R,1), i.e. no truncation.
% NOTE(review): the original header states that R is symmetric positive
% definite; this is not checked here.
if nargin < 2
    K = size(R,1);
end
if ~ismatrix(R)
    error('JD: R must have two dimensions');
end
if size(R,1) ~= size(R,2)
    error('JD: R must be a square matrix');
end

[leftVecs, sigma, rightVecs] = svd(R, 0);
singVals = diag(sigma);
keep = 1:K;                               % leading components to retain
invSigma = diag(1 ./ singVals(keep));
invR = leftVecs(:,keep) * invSigma * rightVecs(:,keep).';

end
13+

+embedding/+GPFA/project.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
alldata = [D{:}];
55
projMatrix = C{:};
66

7-
sc = alldata'*projMatrix;
7+
sc = projMatrix * alldata;
88

99
% For each condition, store the reduced version of each data vector
1010
index = 0;
1111
for ii = 1:length(D)
12-
E{ii} = sc(index + (1:size(D{ii},2)),:)';
12+
E{ii} = sc(:,index + (1:size(D{ii},2)));
1313
index = index + size(D{ii},2);
1414
end %ii
1515
end

+embedding/+GPFA/reduce.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,5 +100,5 @@
100100
[E{:}] = deal(AllSeq.xsm);
101101
[~, C{1}] = embedding.GPFA.util.orthogonalize([AllSeq.xsm], estParams.C);
102102
[~,lat] = pcacov(estParams.C * estParams.C');
103-
VarExp{1} = cumsum(lat(1:xDim))./sum(lat);
103+
VarExp{1} = cumsum(lat(xDim))./sum(lat);
104104
end

+embedding/+MCCA/loadParams.m

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
function pars = loadParams()
%% MCCA specific parameters
% Returns a scalar struct holding the default parameter values defined in
% this file.

% Options currently disabled:
% pars.endLeg_range = @(t)getNormRange(t,fraction);
% pars.interest_range = @(t)getInterestRange(t,fraction,alignment);
% pars.ccaRefSig = [];

% NOTE(review): presumably forwarded as the k argument of mcca -- confirm
% against the caller; 0.9 is fractional, whereas mcca treats k as a
% per-set dimension count.
pars = struct('mcca_k', 0.9);

end

+embedding/+MCCA/mcca.m

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
function [V,rho,A,rhotest]=mcca(X,d,Xtest,k)
% [V,rho,A,rhotest]=mcca(X,d,Xtest,k) Multiset Canonical Correlation
% Analysis. X is the data arranged as samples by dimension, whereby all
% sets are concatenated along the dimensions. d is a vector with the
% dimensions of each set. V are the component vectors and rho the resulting
% inter-set correlations. A are the corresponding forward models, which
% are returned as a list of length N. If Xtest is given, it will also
% compute rho for the test data with the optimal V. If k is given then the
% within-set correlation will be reduced in dimension from d to k prior to
% inversion using PCA. This is useful for rank deficient data or for
% regularization. If k is not given, dimension is reduced to the rank of
% the data prior to inversion.
%
% Inputs:
%   X      samples-by-sum(d) data matrix (all sets concatenated column-wise)
%   d      vector with the number of columns belonging to each set
%   Xtest  (optional) test data laid out like X; may be omitted or empty
%   k      (optional) per-set dimension kept before inversion; a scalar is
%          applied to every set; omitted or empty means k = d
%
% Outputs:
%   V        component vectors, columns sorted by descending rho
%   rho      inter-set correlations, sorted in descending order
%   A        1-by-N cell array of forward models (computed only on request)
%   rhotest  rho evaluated on Xtest with the trained V; [] if no Xtest
%
% See https://arxiv.org/abs/1802.03759, https://arxiv.org/abs/1801.08881

% Apr 30, 2018, Lucas Parra (c)
% Sep 11, 2018, removed hack for forward model computation that broke the code sometimes
% Sep 14, 2018, make forward model robust to ill conditioned data
% Sep 15, 2018, keep simpler code in case that there is no regularization or rank problem

% Resolve optional arguments with nargin instead of exist(): exist('Xtest')
% without the 'var' qualifier also matches files/functions on the path.
if nargin<3, Xtest=[]; end
if nargin<4 || isempty(k), k=d; end
if isscalar(k) && ~isscalar(d), k=repmat(k,size(d)); end % same k for every set

N=length(d);
R=cov(X);
if size(R,1)~=sum(d)
    error('mcca:dimensionMismatch', ...
        'Covariance of X is %d-by-%d but sum(d) is %d.', ...
        size(R,1),size(R,2),sum(d));
end

% Column range of set i inside the concatenated data
offsets=[0 cumsum(d(:)')];
blk=@(i) offsets(i)+(1:d(i));

% D holds the within-set (block diagonal) part of R
D=zeros(size(R));
for i=1:N
    j=blk(i);
    D(j,j)=R(j,j);
    k(i)=min(k(i),rank(D(j,j))); % check rank for oblivious users
end

if sum(d)==sum(k) % simple case
    [V,lambda]=eig(R,D);
else % if rank deficient, or if regularization requested
    Dinv=zeros(size(R));
    for i=1:N
        j=blk(i);
        Dinv(j,j)=embedding.CCA.regInv(D(j,j),k(i));
    end
    [V,lambda]=eigs(Dinv*R,sum(k));
end
rho = (diag(lambda)-1)/(N-1);
[~,indx]=sort(rho,'descend'); rho=rho(indx); V=V(:,indx);

% compute forward models
if nargout>2
    A=cell(1,N);
    for i=1:N
        j=blk(i);
        W=V(j,1:k(i)); Rw=R(j,j);
        if k(i)==d(i)
            A{i}=Rw*W/(W'*Rw*W); % original formula, but won't work for rank deficient Rw
        else
            A{i}=Rw*W*diag(1./diag(W'*Rw*W)); % ignores correlation of components but robust to ill conditioned Rw
        end
    end
end

% compute rho for test data; rhotest stays empty when no test data is given
% so that a four-output call does not fail with an unassigned output
rhotest=[];
if ~isempty(Xtest)
    Rtest=cov(Xtest);
    Dtest=zeros(size(Rtest));
    for i=1:N
        j=blk(i);
        Dtest(j,j)=Rtest(j,j);
    end
    lambda = diag(V'*Rtest*V)./diag(V'*Dtest*V);
    rhotest = (lambda-1)/(N-1);
end
54+
55+

+embedding/+MCCA/project.m

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
function E = project(D,C)
% PROJECT Apply per-area projection matrices to every trial.
%   E = project(D,C) returns a cell array the same size as D with
%   E{ii,jj} = C{jj} * D{ii,jj}.
%
%   D  nTrial-by-nArea cell array of data matrices.
%      (assumes each D{ii,jj} is channels-by-time -- TODO confirm with caller)
%   C  cell array with one projection matrix per area (numel(C) == nArea);
%      C{jj} must have as many columns as D{ii,jj} has rows.
%
%   The previous body sliced the concatenated projection using pars.TrialL
%   and dims, neither of which exists in this function, so every call
%   failed. Projecting trial by trial yields the same columns as projecting
%   the concatenation C{jj}*[D{:,jj}] and then cutting it back into trials,
%   without needing the trial length or any index bookkeeping.

[nTrial,nArea] = size(D);
E = cell(nTrial,nArea);

if numel(C) ~= nArea
    error('project:sizeMismatch', ...
        'Expected one projection matrix per area (%d), got %d.', ...
        nArea,numel(C));
end

% For each area and trial, store the reduced version of the data
for jj = 1:nArea
    for ii = 1:nTrial
        E{ii,jj} = C{jj} * D{ii,jj};
    end
end

0 commit comments

Comments
 (0)