-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdsca_bothmerge.m
126 lines (100 loc) · 3.29 KB
/
dsca_bothmerge.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
%%% history
%%% - 2020/10/22 y.takagi - initially created by modifying Dmitry Kobak's dPCA program
%%% see also: https://github.com/machenslab/dPCA
function [W, V, whichMarg] = dsca_bothmerge(Xfull,Yfull, numComps, varargin)
% DSCA_BOTHMERGE  Per-marginalization regularized regression of Y on X,
% reduced to a small number of components.
%
%   [W, V, whichMarg] = dsca_bothmerge(Xfull, Yfull, numComps, ...)
%
% Both inputs are flattened with (:,:) (first dimension kept as rows), each
% row is centred with nanmean, and both are marginalized by
% dpca_marginalize (called with 'ifFlat','yes'). For every marginalization i:
%
%   C = Ymargs{i}*Xmargs{i}' * pinv(Xmargs{i}*Xmargs{i}' + Cnoise
%                                   + (totalVar*lambda)^2 * I)
%   M = C*Xmargs{i}
%   U = the nc eigenvectors of M*M' returned by eigs
%
% U becomes the encoder columns and U'*C the decoder rows.
%
% Inputs
%   Xfull, Yfull - numeric arrays.
%                  NOTE(review): presumably both use the dPCA data layout
%                  and agree in every dimension but the first -- confirm
%                  against the callers.
%   numComps     - scalar: number of components extracted from every
%                  marginalization and also the number kept after ordering;
%                  vector: indexed by the marginalization numbers returned
%                  by dpca_marginalize.
%
% Name/value options
%   'combinedParams', 'timeSplits', 'timeParameter', 'notToSplit'
%                - forwarded unchanged to dpca_marginalize (default []).
%   'lambda'     - scalar, or vector indexed like numComps (default 0).
%                  Enters the regression as (totalVar*lambda)^2 * I.
%   'order'      - 'yes' (default) sorts the components by explained
%                  variance. Sorting is also always performed when numComps
%                  is a scalar, whatever the value of 'order'.
%   'Cnoise'     - matrix added to Xmargs{i}*Xmargs{i}' before the
%                  pseudo-inverse (default: zeros(size(X,1))).
%
% Outputs
%   W            - decoder, one column per component, size(X,1) rows.
%   V            - encoder, one column per component, size(Y,1) rows.
%   whichMarg    - for every component, the index (into the cell arrays
%                  returned by dpca_marginalize) of the marginalization it
%                  was extracted from.
%
% When ordering is performed and fewer components were extracted than
% requested, all extracted components are returned instead of failing.
%
% Errors
%   - odd number of name/value arguments;
%   - unrecognized option name.

    % default input parameters
    options = struct('combinedParams', [], ...
                     'lambda',         0, ...
                     'order',          'yes', ...
                     'timeSplits',     [], ...
                     'timeParameter',  [], ...
                     'notToSplit',     [], ...
                     'Cnoise',         []);

    % read input parameters
    optionNames = fieldnames(options);
    if mod(length(varargin),2) == 1
        error('Please provide propertyName/propertyValue pairs')
    end
    for pair = reshape(varargin,2,[])    % pair is {propName; propValue}
        if any(strcmp(pair{1}, optionNames))
            options.(pair{1}) = pair{2};
        else
            error('%s is not a recognized parameter name', pair{1})
        end
    end

    % centering: subtract the NaN-ignoring mean of every row
    X = Xfull(:,:);
    X = bsxfun(@minus, X, nanmean(X,2));
    XfullCen = reshape(X, size(Xfull));

    Y = Yfull(:,:);
    Y = bsxfun(@minus, Y, nanmean(Y,2));
    YfullCen = reshape(Y, size(Yfull));

    % total variance of the (centred) target data
    totalVar = sum(Y(:).^2);

    % marginalize; only the marginalization numbers of the second call are
    % used below, so the first call discards them (nargout stays 2)
    [Xmargs, ~] = dpca_marginalize(XfullCen, 'combinedParams', options.combinedParams, ...
                        'timeSplits', options.timeSplits, ...
                        'timeParameter', options.timeParameter, ...
                        'notToSplit', options.notToSplit, ...
                        'ifFlat', 'yes');
    [Ymargs, margNums] = dpca_marginalize(YfullCen, 'combinedParams', options.combinedParams, ...
                        'timeSplits', options.timeSplits, ...
                        'timeParameter', options.timeParameter, ...
                        'notToSplit', options.notToSplit, ...
                        'ifFlat', 'yes');

    % initialize
    decoder = [];
    encoder = [];
    whichMarg = [];

    % noise covariance
    if isempty(options.Cnoise)
        options.Cnoise = zeros(size(X,1));
    end

    % loop over marginalizations
    for i=1:length(Ymargs)
        if length(numComps) == 1
            nc = numComps;
        else
            nc = numComps(margNums(i));
        end

        if length(options.lambda) == 1
            thisLambda = options.lambda;
        else
            thisLambda = options.lambda(margNums(i));
        end

        if nc == 0
            continue
        end

        % regularized regression from Xmargs{i} to Ymargs{i}
        C = Ymargs{i}*Xmargs{i}'*pinv(Xmargs{i}*Xmargs{i}' + options.Cnoise + (totalVar*thisLambda)^2*eye(size(Xmargs{i},1)));
        M = C*Xmargs{i};

        % the eigenvalue and flag outputs are requested but discarded; the
        % call form is deliberately left as it was
        [U,~,~] = eigs(M*M', nc);
        P = U;
        D = U'*C;

        decoder = [decoder; D];      %#ok<AGROW>
        encoder = [encoder P];       %#ok<AGROW>
        % label exactly as many components as were actually returned, so
        % that whichMarg always stays aligned with the columns of V and W
        whichMarg = [whichMarg i*ones(1, size(U,2))];   %#ok<AGROW>
    end

    % transposing
    V = encoder;
    W = decoder';

    % flipping axes such that all encoders have more positive values
    toFlip = find(sum(sign(V))<0);
    W(:, toFlip) = -W(:, toFlip);
    V(:, toFlip) = -V(:, toFlip);

    % ordering components by explained variance (or not)
    if length(numComps) == 1 || strcmp(options.order, 'yes')
        numFound = size(W,2);

        % preallocated so that it exists even when no component was found
        explVar = zeros(1, numFound);
        for i=1:numFound
            Z = Y - V(:,i)*(W(:,i)'*X);
            explVar(i) = 1 - sum(Z(:).^2)/totalVar;
        end
        [~ , sortIdx] = sort(explVar, 'descend');

        if length(numComps) == 1
            L = numComps;
        else
            L = sum(numComps);
        end
        % never request more components than were actually extracted
        L = min(L, numFound);

        keep = sortIdx(1:L);
        W = W(:, keep);
        V = V(:, keep);
        whichMarg = whichMarg(keep);
    end
end