Skip to content

Commit a24d639

Browse files
committed
add multi_NNLRS
1 parent d8ec346 commit a24d639

File tree

1 file changed

+178
-0
lines changed

1 file changed

+178
-0
lines changed

multi_NNLRS.m

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
function [] = multi_NNLRS(X,lambda,beta,alpha)
2+
% init vars
3+
k=length(X);
4+
[m,n]=size(X{1});
5+
6+
Z=cell(k,1);
7+
Z{1:k}=zeros(n);
8+
E=cell(k,1);
9+
E{1:k}=zeros(m,n)
10+
S=cell(k,1);
11+
S{1:k}=zeros(n);
12+
J=cell(k,1);
13+
J{1:k}=zeros(n);
14+
Y1=cell(k,1);
15+
Y1{1:k}=zeros(m,n);
16+
Y2=cell(k,1);
17+
Y2{1:k}=zeros(n);
18+
Y3=cell(k,1);
19+
Y3{1:k}=zeros(n);
20+
Zk=Z;
21+
Ek=E;
22+
Sk=S;
23+
Jk=J;
24+
svp=cell(k,1);
25+
svp{1:k}=0;
26+
F=Z;
27+
ZZ=zeros(k,n*n);
28+
29+
% precomputed values
30+
xtx=cell(k,1);
31+
for i=1:k
32+
xtx{i}=X{i}'*X{i};
33+
end
34+
invx=cell(k,1);
35+
for i=1:k
36+
invx{i}=inv(xtx{i}+eye(n));
37+
end
38+
Xf=cell(k,1);
39+
for i=1:k
40+
Xf{i}=norm(X{i},'fro');
41+
end
42+
% the residual error and the error between Z,J,S
43+
Xc=cell(k,1);
44+
ZJc=cell(k,1);
45+
ZSc=cell(k,1);
46+
47+
% parameters
48+
norm2X=cell(k,1);
49+
for i=1:k
50+
norm2X{i}=norm(X{i},2);
51+
end
52+
eta1=cell(k,1);
53+
for i=1:k
54+
eta1{i}=norm2X{i}*norm2X{i}*1.02;%eta needs to be larger than ||X||_2^2, but need not be too large.
55+
end
56+
mu=1e-6;
57+
max_mu=10^10;
58+
rho=1.9;
59+
% epsilon=1e-4;
60+
% epsilon2=1e-5; % must be small!
61+
epsilon=1e-6;
62+
epsilon2=1e-5; % must be small!
63+
MAX_ITER=1000;
64+
iter=0;
65+
convergenced=false;
66+
clambda=cell(k,1);
67+
clambda{1:k}=lambda;
68+
69+
while ~convergenced
70+
if iter>MAX_ITER
71+
fprintf(1,'max iter num reached!\n');
72+
break;
73+
end
74+
cmu=cell(k,1);
75+
cmu(1:k)={mu};
76+
% update S_i
77+
Sk=S;
78+
[S, svp]=cellfun(@updateS,xtx,X,E,Y1,Z,S,Sk,Y3,eta1,cmu,'UniformOutput',false);
79+
% update J_i
80+
Jk=J;
81+
[J]=cellfun(@updateJ,Z,J,Y2,cmu,'UniformOutput',false);
82+
% update Z
83+
[F]=cellfun(@updateF,J,Y2,S,Y3,cmu,'UniformOutput',false);
84+
[M]=cellfun(@updateM,F,'UniformOutput',false);
85+
for i=1:k
86+
ZZ(i,:)=M{i};
87+
end
88+
ZZ=l21(ZZ,alpha/mu);
89+
% update Z_i
90+
Zk=Z;
91+
for i=1:k
92+
Z{i}=reshape(ZZ(i,:),n,n)';
93+
end
94+
% update E_i
95+
[E]=cellfun(@updateE,X,S,E,Y1,cmu,clambda,'UniformOutput',false);
96+
97+
% parameter update rule
98+
99+
% check convergence
100+
[Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec] = cellfun(@caculateTempVars,X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf,'UniformOutput',false);
101+
changeX=max([Xv{:}]);
102+
changeZJ=max([ZJv{:}]);
103+
changeZS=max([ZSv{:}]);
104+
changeZ=max([Zc{:}]);
105+
changeJ=max([Jc{:}]);
106+
changeS=max([Sc{:}]);
107+
changeE=max([Ec{:}]);
108+
tmp=[changeZ changeJ changeS changeE ];
109+
gap=mu*max(tmp);
110+
if mod(iter,50)==0
111+
fprintf(1,'===========================================================================================================\n');
112+
fprintf(1,'gap between two iteration is %f,mu is %f\n',gap,mu);
113+
fprintf(1,'iter %d,mu is %f,ResidualX is %f,changeZJ is %f,changeZS is %f\n',iter,mu,changeX,changeZJ,changeZS);
114+
for i=1:k
115+
fprintf(1,'svp%d %d,',i,svp{i});
116+
end
117+
fprintf(1,'\n');
118+
end
119+
% if changeX <= epsilon && changeZJ <= epsilon && changeZS <= epsilon
120+
if changeX <= epsilon && gap <=epsilon2 && changeZJ <= epsilon && changeZS <= epsilon
121+
convergenced=true;
122+
fprintf(2,'convergenced, iter is %d\n',iter);
123+
fprintf(2,'iter %d,mu is %f,ResidualX is %f,changeZJ is %f,changeZS is %f\n',iter,mu,changeX,changeZJ,changeZS);
124+
for i=1:k
125+
fprintf(1,'svp%d %d,',i,svp{i});
126+
end
127+
fprintf(1,'\n');
128+
end
129+
% update multipliers
130+
[Y1]=cellfun(@updateY1,Y1,cmu,Xc,'UniformOutput',false);
131+
[Y2]=cellfun(@updateY2,Y2,cmu,ZJc,'UniformOutput',false);
132+
[Y3]=cellfun(@updateY3,Y3,cmu,ZSc,'UniformOutput',false);
133+
% update parameters
134+
if gap < epsilon2
135+
mu=min(rho*mu,max_mu);
136+
end
137+
iter=iter+1;
138+
end
139+
140+
function [S,svp] = updateS(xtx,X,E,Y1,Z,S,Sk,Y3,eta1,mu)
141+
T=-mu*(xtx-xtx*S+X'*E+X'*Y1/mu+Z-S+Y3/mu);
142+
% argmin_{S} 1/(mu*eta1)||S||_*+1/2*||S-S_k+T/(mu*eta1)||_F^2
143+
[S,svp]=singular_value_shrinkage(Sk-T/(mu*eta1),1/(mu*eta1)); % TODO: sometimes PROPACK is slower than full svd, and sometimes it will throw the following error
144+
145+
function [J] = updateJ(Z,J,Y2,mu)
146+
J=wthresh(Z+Y2/mu,'s',2*beta);
147+
148+
function [RET] = updateF(J,Y2,S,Y3,mu)
149+
RET=1/2*(J-Y2/mu+S-Y3/mu);
150+
151+
function [M] = updateM(F)
152+
n=length(F);
153+
M=reshape(F',1,n*n);
154+
155+
function [E] = updateE(X,S,E,Y1,mu,lambda)
156+
E=l21(X*S-X-Y1/mu,lambda/mu);
157+
158+
function [Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec] = caculateTempVars(X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf)
159+
Xc=X-X*S-E;
160+
ZJc=Z-J;
161+
ZSc=Z-S;
162+
Xv=norm(Xc,'fro')/Xf;
163+
ZJv=norm(ZJc,'fro')/Xf;
164+
ZSv=norm(ZSc,'fro')/Xf;
165+
166+
Zc=norm(Zk-Z,'fro')/Xf;
167+
Jc=norm(Jk-J,'fro')/Xf;
168+
Sc=norm(Sk-S,'fro')/Xf;
169+
Ec=norm(Ek-E,'fro')/Xf;
170+
171+
function [Y1] = updateY1(Y1,mu,Xc)
172+
Y1=Y1+mu*Xc;
173+
174+
function [Y2] = updateY2(Y2,mu,ZJc)
175+
Y2=Y2+mu*ZJc;
176+
177+
function [Y3] = updateY3(Y3,mu,ZSc)
178+
Y3=Y3+mu*ZSc;

0 commit comments

Comments
 (0)