|
| 1 | +function [] = multi_NNLRS(X,lambda,beta,alpha) |
| 2 | +% init vars |
| 3 | +k=length(X); |
| 4 | +[m,n]=size(X{1}); |
| 5 | + |
| 6 | +Z=cell(k,1); |
| 7 | +Z{1:k}=zeros(n); |
| 8 | +E=cell(k,1); |
| 9 | +E{1:k}=zeros(m,n) |
| 10 | +S=cell(k,1); |
| 11 | +S{1:k}=zeros(n); |
| 12 | +J=cell(k,1); |
| 13 | +J{1:k}=zeros(n); |
| 14 | +Y1=cell(k,1); |
| 15 | +Y1{1:k}=zeros(m,n); |
| 16 | +Y2=cell(k,1); |
| 17 | +Y2{1:k}=zeros(n); |
| 18 | +Y3=cell(k,1); |
| 19 | +Y3{1:k}=zeros(n); |
| 20 | +Zk=Z; |
| 21 | +Ek=E; |
| 22 | +Sk=S; |
| 23 | +Jk=J; |
| 24 | +svp=cell(k,1); |
| 25 | +svp{1:k}=0; |
| 26 | +F=Z; |
| 27 | +ZZ=zeros(k,n*n); |
| 28 | + |
| 29 | +% precomputed values |
| 30 | +xtx=cell(k,1); |
| 31 | +for i=1:k |
| 32 | + xtx{i}=X{i}'*X{i}; |
| 33 | +end |
| 34 | +invx=cell(k,1); |
| 35 | +for i=1:k |
| 36 | + invx{i}=inv(xtx{i}+eye(n)); |
| 37 | +end |
| 38 | +Xf=cell(k,1); |
| 39 | +for i=1:k |
| 40 | + Xf{i}=norm(X{i},'fro'); |
| 41 | +end |
| 42 | +% the residual error and the error between Z,J,S |
| 43 | +Xc=cell(k,1); |
| 44 | +ZJc=cell(k,1); |
| 45 | +ZSc=cell(k,1); |
| 46 | + |
| 47 | +% parameters |
| 48 | +norm2X=cell(k,1); |
| 49 | +for i=1:k |
| 50 | + norm2X{i}=norm(X{i},2); |
| 51 | +end |
| 52 | +eta1=cell(k,1); |
| 53 | +for i=1:k |
| 54 | + eta1{i}=norm2X{i}*norm2X{i}*1.02;%eta needs to be larger than ||X||_2^2, but need not be too large. |
| 55 | +end |
| 56 | +mu=1e-6; |
| 57 | +max_mu=10^10; |
| 58 | +rho=1.9; |
| 59 | +% epsilon=1e-4; |
| 60 | +% epsilon2=1e-5; % must be small! |
| 61 | +epsilon=1e-6; |
| 62 | +epsilon2=1e-5; % must be small! |
| 63 | +MAX_ITER=1000; |
| 64 | +iter=0; |
| 65 | +convergenced=false; |
| 66 | +clambda=cell(k,1); |
| 67 | +clambda{1:k}=lambda; |
| 68 | + |
| 69 | +while ~convergenced |
| 70 | + if iter>MAX_ITER |
| 71 | + fprintf(1,'max iter num reached!\n'); |
| 72 | + break; |
| 73 | + end |
| 74 | + cmu=cell(k,1); |
| 75 | + cmu(1:k)={mu}; |
| 76 | + % update S_i |
| 77 | + Sk=S; |
| 78 | + [S, svp]=cellfun(@updateS,xtx,X,E,Y1,Z,S,Sk,Y3,eta1,cmu,'UniformOutput',false); |
| 79 | + % update J_i |
| 80 | + Jk=J; |
| 81 | + [J]=cellfun(@updateJ,Z,J,Y2,cmu,'UniformOutput',false); |
| 82 | + % update Z |
| 83 | + [F]=cellfun(@updateF,J,Y2,S,Y3,cmu,'UniformOutput',false); |
| 84 | + [M]=cellfun(@updateM,F,'UniformOutput',false); |
| 85 | + for i=1:k |
| 86 | + ZZ(i,:)=M{i}; |
| 87 | + end |
| 88 | + ZZ=l21(ZZ,alpha/mu); |
| 89 | + % update Z_i |
| 90 | + Zk=Z; |
| 91 | + for i=1:k |
| 92 | + Z{i}=reshape(ZZ(i,:),n,n)'; |
| 93 | + end |
| 94 | + % update E_i |
| 95 | + [E]=cellfun(@updateE,X,S,E,Y1,cmu,clambda,'UniformOutput',false); |
| 96 | + |
| 97 | + % parameter update rule |
| 98 | + |
| 99 | + % check convergence |
| 100 | + [Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec] = cellfun(@caculateTempVars,X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf,'UniformOutput',false); |
| 101 | + changeX=max([Xv{:}]); |
| 102 | + changeZJ=max([ZJv{:}]); |
| 103 | + changeZS=max([ZSv{:}]); |
| 104 | + changeZ=max([Zc{:}]); |
| 105 | + changeJ=max([Jc{:}]); |
| 106 | + changeS=max([Sc{:}]); |
| 107 | + changeE=max([Ec{:}]); |
| 108 | + tmp=[changeZ changeJ changeS changeE ]; |
| 109 | + gap=mu*max(tmp); |
| 110 | + if mod(iter,50)==0 |
| 111 | + fprintf(1,'===========================================================================================================\n'); |
| 112 | + fprintf(1,'gap between two iteration is %f,mu is %f\n',gap,mu); |
| 113 | + fprintf(1,'iter %d,mu is %f,ResidualX is %f,changeZJ is %f,changeZS is %f\n',iter,mu,changeX,changeZJ,changeZS); |
| 114 | + for i=1:k |
| 115 | + fprintf(1,'svp%d %d,',i,svp{i}); |
| 116 | + end |
| 117 | + fprintf(1,'\n'); |
| 118 | + end |
| 119 | + % if changeX <= epsilon && changeZJ <= epsilon && changeZS <= epsilon |
| 120 | + if changeX <= epsilon && gap <=epsilon2 && changeZJ <= epsilon && changeZS <= epsilon |
| 121 | + convergenced=true; |
| 122 | + fprintf(2,'convergenced, iter is %d\n',iter); |
| 123 | + fprintf(2,'iter %d,mu is %f,ResidualX is %f,changeZJ is %f,changeZS is %f\n',iter,mu,changeX,changeZJ,changeZS); |
| 124 | + for i=1:k |
| 125 | + fprintf(1,'svp%d %d,',i,svp{i}); |
| 126 | + end |
| 127 | + fprintf(1,'\n'); |
| 128 | + end |
| 129 | + % update multipliers |
| 130 | + [Y1]=cellfun(@updateY1,Y1,cmu,Xc,'UniformOutput',false); |
| 131 | + [Y2]=cellfun(@updateY2,Y2,cmu,ZJc,'UniformOutput',false); |
| 132 | + [Y3]=cellfun(@updateY3,Y3,cmu,ZSc,'UniformOutput',false); |
| 133 | + % update parameters |
| 134 | + if gap < epsilon2 |
| 135 | + mu=min(rho*mu,max_mu); |
| 136 | + end |
| 137 | + iter=iter+1; |
| 138 | +end |
| 139 | + |
| 140 | +function [S,svp] = updateS(xtx,X,E,Y1,Z,S,Sk,Y3,eta1,mu) |
| 141 | + T=-mu*(xtx-xtx*S+X'*E+X'*Y1/mu+Z-S+Y3/mu); |
| 142 | + % argmin_{S} 1/(mu*eta1)||S||_*+1/2*||S-S_k+T/(mu*eta1)||_F^2 |
| 143 | + [S,svp]=singular_value_shrinkage(Sk-T/(mu*eta1),1/(mu*eta1)); % TODO: sometimes PROPACK is slower than full svd, and sometimes it will throw the following error |
| 144 | + |
| 145 | +function [J] = updateJ(Z,J,Y2,mu) |
| 146 | + J=wthresh(Z+Y2/mu,'s',2*beta); |
| 147 | + |
| 148 | +function [RET] = updateF(J,Y2,S,Y3,mu) |
| 149 | + RET=1/2*(J-Y2/mu+S-Y3/mu); |
| 150 | + |
| 151 | +function [M] = updateM(F) |
| 152 | + n=length(F); |
| 153 | + M=reshape(F',1,n*n); |
| 154 | + |
| 155 | +function [E] = updateE(X,S,E,Y1,mu,lambda) |
| 156 | + E=l21(X*S-X-Y1/mu,lambda/mu); |
| 157 | + |
| 158 | +function [Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec] = caculateTempVars(X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf) |
| 159 | + Xc=X-X*S-E; |
| 160 | + ZJc=Z-J; |
| 161 | + ZSc=Z-S; |
| 162 | + Xv=norm(Xc,'fro')/Xf; |
| 163 | + ZJv=norm(ZJc,'fro')/Xf; |
| 164 | + ZSv=norm(ZSc,'fro')/Xf; |
| 165 | + |
| 166 | + Zc=norm(Zk-Z,'fro')/Xf; |
| 167 | + Jc=norm(Jk-J,'fro')/Xf; |
| 168 | + Sc=norm(Sk-S,'fro')/Xf; |
| 169 | + Ec=norm(Ek-E,'fro')/Xf; |
| 170 | + |
| 171 | +function [Y1] = updateY1(Y1,mu,Xc) |
| 172 | + Y1=Y1+mu*Xc; |
| 173 | + |
| 174 | +function [Y2] = updateY2(Y2,mu,ZJc) |
| 175 | + Y2=Y2+mu*ZJc; |
| 176 | + |
| 177 | +function [Y3] = updateY3(Y3,mu,ZSc) |
| 178 | + Y3=Y3+mu*ZSc; |
0 commit comments