-
Notifications
You must be signed in to change notification settings - Fork 11
/
demo_classopath.m
93 lines (78 loc) · 2.04 KB
/
demo_classopath.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
%% lasso with a linear equality constraint
% Simulate a sparse linear regression problem whose true coefficient vector
% satisfies the equality constraint Aeq*beta = beq (here sum(beta) = 0).
clear;
% set seed so the simulated data are reproducible
s = RandStream('mt19937ar','Seed',1);
RandStream.setGlobalStream(s);
% dimension
n = 100;    % number of observations
p = 150;    % number of predictors
% truth: four consecutive blocks with values 0, 1, 0, -1
beta = zeros(p,1);
beta(round(p/4)+1:round(p/2)) = 1;
beta(round(3*p/4)+1:end) = -1;
% generate data
X = randn(n,p);
y = X*beta + randn(n,1);
% equality constraint sum(beta) = 0
Aeq = ones(1,p);
beq = 0;
% The rounding above balances the +1 and -1 blocks for p = 150 but not for
% every p (e.g. p = 5 gives sum(beta) = 1), so check that the truth is
% feasible before comparing solvers on this problem.
assert(norm(Aeq*beta - beq, Inf) < 1e-12, ...
    'true coefficients violate the equality constraint Aeq*beta = beq');
% penalize every coefficient (X has no unpenalized intercept column)
penidx = true(1,p);
%% obtain solution path by path following
% Trace the whole constrained-lasso solution path with lsq_classopath.
% Only the equality constraint Aeq*beta = beq is imposed (no inequalities).
tStart = tic;
[rhopath, betapath] = lsq_classopath(X, y, [], [], Aeq, beq, ...
    'qp_solver', 'matlab', 'penidx', penidx);
timing_path = toc(tStart);
% plot every coefficient against the tuning parameter rho
figure;
hold on;
set(gca, 'FontSize', 20);
plot(rhopath, betapath');
xlabel('\rho');
ylabel('\beta(\rho)');
rhoLimits = [min(rhopath), 1.05*max(rhopath)];
xlim(rhoLimits);
pathTitle = ['path following algorithm:', num2str(timing_path), ' s'];
title(pathTitle);
%% obtain solution path by GUROBI optimization at grid
% Re-solve the constrained lasso independently at every rho returned by the
% path algorithm, using Gurobi's QP solver. The same equality constraint
% and the same penalty index are passed as for the other methods, so that
% all three methods solve the same problem.
betapath_gurobi = zeros(size(betapath));
tic;
for k = 1:length(rhopath)
    display(k);   % progress indicator
    betapath_gurobi(:,k) = lsq_constrsparsereg(X, y, rhopath(k), ...
        'method', 'qp', 'qp_solver', 'gurobi', 'Aeq', Aeq, 'beq', beq, ...
        'penidx', penidx);
end
timing_gurobi = toc;
% plot solutions
figure; hold on;
set(gca,'FontSize', 20);
plot(rhopath, betapath_gurobi');
xlabel('\rho');
ylabel('\beta(\rho)');
xlim([min(rhopath) max(rhopath)*1.05]);
title(['Gurobi on grid:' num2str(timing_gurobi) ' s']);
%% obtain solution path by ADMM optimization at grid
% Re-solve the constrained lasso at every rho from the path algorithm with
% ADMM. ADMM enforces the constraint through projC, the Euclidean
% projection onto the affine set {x : Aeq*x = beq}:
%   projC(x) = x - Aeq' * (Aeq*Aeq')^{-1} * (Aeq*x - beq)
% For Aeq = ones(1,p) and beq = 0 this reduces to x - mean(x). Note that
% every coefficient is constrained: there is no free intercept entry.
AAt = Aeq*Aeq';
projC = @(x) x - Aeq' * (AAt \ (Aeq*x - beq));
betapath_admm = zeros(size(betapath));
tic;
for k = 1:length(rhopath)
    display(k);   % progress indicator
    % Warm start from ADMM's own solution at the previous rho, so that the
    % ADMM timing does not benefit from the path-following solutions.
    if k == 1
        x0 = zeros(p,1);
    else
        x0 = betapath_admm(:,k-1);
    end
    betapath_admm(:,k) = lsq_constrsparsereg(X, y, rhopath(k), ...
        'method', 'admm', 'projC', projC, 'x0', x0, 'penidx', penidx);
end
timing_admm = toc;
% plot solutions
figure; hold on;
set(gca,'FontSize', 20);
plot(rhopath, betapath_admm');
xlabel('\rho');
ylabel('\beta(\rho)');
xlim([min(rhopath) max(rhopath)*1.05]);
title(['ADMM on grid:' num2str(timing_admm) ' s']);