Skip to content

Commit 3c76d6d

Browse files
committed
MCCA Implementation
The function "findEmbedding" has been modified by adding a specific case for the MCCA algorithm. The MCCA case has been implemented in the +embedding subfolder with all the necessary functions inside. The parameter mcca_k was added in the "loadParams" function.
1 parent 9bf5260 commit 3c76d6d

7 files changed

Lines changed: 241 additions & 0 deletions

File tree

+embedding/+CCA/mcca.m

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
function [V,rho,A,rhotest]=mcca(X,d,Xtest,k)
2+
% [V,rho,A,rhotest]=mcca(X,d,Xtest,k) Multiset Canonical Correlation
3+
% Analysis. X is the data arranged as samples by dimension, whereby all
4+
% sets are concatenated along the dimensions. d is a vector with the
5+
% dimensions of each set. V are the component vectors and rho the resulting
6+
% inter-set correlations. A are the corresponding forward models, which
7+
% are returned as a list of length N. If Xtest is given, it will also
8+
% compute rho for the test data with the optimal V. If k is given then the
9+
% within-set correlation will be reduced in dimension from d to k prior to
10+
% inversion using PCA. This is useful for rank deficient data or for
11+
% regularization. If k is not given, dimension is reduced to the rank of
12+
% the data prior to inversion.
13+
%
14+
% See https://arxiv.org/abs/1802.03759, https://arxiv.org/abs/1801.08881
15+
16+
% Apr 30, 2018, Lucas Parra (c)
17+
% Sep 11, 2018, removed hack for forward model computation that broke the code sometimes
18+
% Sep 14, 2018, make forward model robust to ill conditioned data
19+
% Sep 15, 2018, keep simpler code in case that there is no regularization or rank problem
20+
21+
if ~exist('k','var') || isempty(k), k=d; end
22+
23+
N=length(d);
24+
R=cov(X);
25+
for i=N:-1:1, j=(1:d(i))+sum(d(1:i-1));
26+
D(j,j)=R(j,j);
27+
k(i)=min(k(i),rank(D(j,j))); % check rank for oblivious users
28+
end
29+
if sum(d)==sum(k) % simple case
30+
[V,lambda]=eig(R,D);
31+
else % if rank deficient, or if regularization requested
32+
for i=N:-1:1, j=(1:d(i))+sum(d(1:i-1)); Dinv(j,j)=embedding.CCA.regInv(D(j,j),k(i)); end
33+
[V,lambda]=eigs(Dinv*R,sum(k));
34+
end
35+
rho = (diag(lambda)-1)/(N-1);
36+
[~,indx]=sort(rho,'descend'); rho=rho(indx); V=V(:,indx);
37+
38+
% compute forward models
39+
if nargout>2
40+
for i=N:-1:1, j=(1:d(i))+sum(d(1:i-1));
41+
W=V(j,1:k(i)); Rw=R(j,j);
42+
if k(i)==d(i), A{i}=Rw*W/(W'*Rw*W); % original formula, but wont work for rank deficient Rw
43+
else A{i}=Rw*W*diag(1./diag(W'*Rw*W)); end % ignores correlation of components but robust to ill conditioned Rw
44+
end
45+
end
46+
47+
% compute rho for test data
48+
if exist('Xtest') && ~isempty(Xtest)
49+
R=cov(Xtest);
50+
for i=N:-1:1, j=(1:d(i))+sum(d(1:i-1)); D(j,j)=R(j,j); end
51+
lambda = diag(V'*R*V)./diag(V'*D*V);
52+
rhotest = (lambda-1)/(N-1);
53+
end
54+
55+

+embedding/+CCA/regInv.m

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
function invR = regInv(R, K)
2+
%invR = regInv(R, K)
3+
% PCA regularized inverse of square symmetric positive definite matrix R
4+
if nargin<2, K=size(R,1); end;
5+
if ~ismatrix(R), error('JD: R must have two dimensions'); end;
6+
if size(R,1)~=size(R,2), error('JD: R must be a square matrix'); end;
7+
8+
[U,S,V]=svd(R,0);
9+
diagS=diag(S);
10+
invR=U(:,1:K)*diag(1./diagS(1:K))*V(:,1:K).';
11+
12+
end
13+

+embedding/+MCCA/loadParams.m

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
function pars = loadParams()
2+
%% GPFA specific parameters
3+
pars = struct();
4+
5+
% pars.endLeg_range = @(t)getNormRange(t,fraction);
6+
% pars.interest_range = @(t)getInterestRange(t,fraction,alignment);
7+
% pars.ccaRefSig = [];
8+
pars.mcca_k = 5;
9+
10+
end
11+
12+
function T = getNormRange(t,fraction)
13+
14+
Tmax = length(t);
15+
T = false(size(t));
16+
T(1:Tmax/fraction/2) = true;
17+
T(end-Tmax/fraction/2:end) = true;
18+
T = t(T);
19+
end
20+
21+
function T = getInterestRange(t,fraction,alignment)
22+
23+
Tmax = length(t);
24+
T = false(size(t));
25+
Pre = alignment-Tmax/fraction/2;
26+
Post = alignment+Tmax/fraction/2;
27+
T(Pre:Post) = true;
28+
T = t(T);
29+
end

+embedding/+MCCA/mcca.m

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
function [V,rho,A,rhotest]=mcca(X,d,Xtest,k)
2+
% [V,rho,A,rhotest]=mcca(X,d,Xtest,k) Multiset Canonical Correlation
3+
% Analysis. X is the data arranged as samples by dimension, whereby all
4+
% sets are concatenated along the dimensions. d is a vector with the
5+
% dimensions of each set. V are the component vectors and rho the resulting
6+
% inter-set correlations. A are the corresponding forward models, which
7+
% are returned as a list of length N. If Xtest is given, it will also
8+
% compute rho for the test data with the optimal V. If k is given then the
9+
% within-set correlation will be reduced in dimension from d to k prior to
10+
% inversion using PCA. This is useful for rank deficient data or for
11+
% regularization. If k is not given, dimension is reduced to the rank of
12+
% the data prior to inversion.
13+
%
14+
% See https://arxiv.org/abs/1802.03759, https://arxiv.org/abs/1801.08881
15+
16+
% Apr 30, 2018, Lucas Parra (c)
17+
% Sep 11, 2018, removed hack for forward model computation that broke the code sometimes
18+
% Sep 14, 2018, make forward model robust to ill conditioned data
19+
% Sep 15, 2018, keep simpler code in case that there is no regularization or rank problem
20+
21+
if ~exist('k','var') || isempty(k), k=d; end
22+
23+
N=length(d);
24+
R=cov(X);
25+
for i=N:-1:1, j=(1:d(i))+sum(d(1:i-1));
26+
D(j,j)=R(j,j);
27+
k(i)=min(k(i),rank(D(j,j))); % check rank for oblivious users
28+
end
29+
if sum(d)==sum(k) % simple case
30+
[V,lambda]=eig(R,D);
31+
else % if rank deficient, or if regularization requested
32+
for i=N:-1:1, j=(1:d(i))+sum(d(1:i-1)); Dinv(j,j)=embedding.CCA.regInv(D(j,j),k(i)); end
33+
[V,lambda]=eigs(Dinv*R,sum(k));
34+
end
35+
rho = (diag(lambda)-1)/(N-1);
36+
[~,indx]=sort(rho,'descend'); rho=rho(indx); V=V(:,indx);
37+
38+
% compute forward models
39+
if nargout>2
40+
for i=N:-1:1, j=(1:d(i))+sum(d(1:i-1));
41+
W=V(j,1:k(i)); Rw=R(j,j);
42+
if k(i)==d(i), A{i}=Rw*W/(W'*Rw*W); % original formula, but wont work for rank deficient Rw
43+
else A{i}=Rw*W*diag(1./diag(W'*Rw*W)); end % ignores correlation of components but robust to ill conditioned Rw
44+
end
45+
end
46+
47+
% compute rho for test data
48+
if exist('Xtest') && ~isempty(Xtest)
49+
R=cov(Xtest);
50+
for i=N:-1:1, j=(1:d(i))+sum(d(1:i-1)); D(j,j)=R(j,j); end
51+
lambda = diag(V'*R*V)./diag(V'*D*V);
52+
rhotest = (lambda-1)/(N-1);
53+
end
54+
55+

+embedding/+MCCA/regInv.m

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
function invR = regInv(R, K)
2+
%invR = regInv(R, K)
3+
% PCA regularized inverse of square symmetric positive definite matrix R
4+
if nargin<2, K=size(R,1); end;
5+
if ~ismatrix(R), error('JD: R must have two dimensions'); end;
6+
if size(R,1)~=size(R,2), error('JD: R must be a square matrix'); end;
7+
8+
[U,S,V]=svd(R,0);
9+
diagS=diag(S);
10+
invR=U(:,1:K)*diag(1./diagS(1:K))*V(:,1:K).';
11+
12+
end
13+

@NeuralEmbedding/NeuralEmbedding.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151
numPC = 3;
5252
VarExp = .8;
5353

54+
% MCCA regularization parameter
55+
mcca_k = 0.9; % Add this line
56+
5457
% General
5558
useGpu = false
5659
useParallel = false

@NeuralEmbedding/findEmbedding.m

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,79 @@
7575
rethrow(er);
7676
return;
7777
end
78+
79+
case {'MCCA','mcca'}
80+
type = "MCCA";
81+
% Get the names of the parameters for this algorithm
82+
parNames = ["numPC","nArea","nTrial","TrialL","mcca_k"];
83+
% Get the parameters for this algorithm
84+
pars = obj.assignEPars(parNames,type);
85+
86+
% Set default value for mcca_k if not provided
87+
if ~isfield(pars, 'mcca_k') || isempty(pars.mcca_k)
88+
pars.mcca_k = 0.9; % Default regularization
89+
end
90+
91+
try
92+
% Prepare data for MCCA - each area as a separate dataset
93+
nTrials = obj.nTrial;
94+
nAreas = obj.nArea;
95+
96+
% Get the unique areas
97+
areas = unique(obj.Area);
98+
99+
% Create D matrix with trials as rows and areas as columns
100+
D = cell(1, nAreas);
101+
thisAMask = obj.aMask_;
102+
for a = 1:nAreas
103+
% Set the area mask to the current area
104+
obj.aMask = areas(a);
105+
106+
% Get the data for the current area
107+
108+
areaData = obj.S;
109+
D{:, a} = cat(2,areaData{:});
110+
end
111+
112+
% Reset the area mask to all areas
113+
obj.aMask = thisAMask;
114+
115+
% Check if all trials have the same number of time points
116+
nTimePoints = cellfun(@(x) size(x, 1), D);
117+
if any(nTimePoints ~= nTimePoints(1))
118+
error('All trials must have the same number of time points for MCCA');
119+
end
120+
121+
% Get number of neurons in each area
122+
DimensionsPerArea = cellfun(@(x)size(x,1),D);
123+
X = cat(1,D{:})';
124+
125+
% Run MCCA
126+
[V, rho, A] = embedding.MCCA.mcca(X, DimensionsPerArea, [], []);
127+
128+
% Store the full projection matrix
129+
obj.ProjMatrix = A;
130+
131+
% Extract the requested number of components
132+
numPC = min(pars.numPC, size(V, 2));
133+
VarExplained = rho(1:numPC);
134+
obj.VarExplained = VarExplained;
135+
136+
% Create embeddings for each trial
137+
E = cell(nTrials, nAreas);
138+
for t = 1:nTrials
139+
for a = 1:nAreas
140+
% Create the embedding for this trial and area
141+
E{t, a} = A{a} * D{t, a};
142+
end
143+
end
144+
145+
catch er
146+
flag = false;
147+
rethrow(er);
148+
return;
149+
end
150+
78151
case {'umap','UMAP'}
79152
type = "UMAP";
80153
% UMAP is not yet supported. Print a message and return

0 commit comments

Comments
 (0)