-
Notifications
You must be signed in to change notification settings - Fork 22
/
run_experiment.m
59 lines (51 loc) · 1.67 KB
/
run_experiment.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
% Experiment setup: load the dataset, build a reproducible 2-fold
% train/test split, and configure the minFunc optimizer options shared
% by every objective trained below.
addpath(genpath('binaryLRloss'));
load('example_data.mat');           % provides X (n-by-d) and y; labels appear to be in {-1,+1} -- see the (y+1)/2 mapping at evaluation
rng(0);                             % reproducible split; replaces deprecated rand('seed',0), which also switched MATLAB to the legacy v4 generator
C = cvpartition(y, 'kfold', 2);     % 2-fold partition: fold 1 -> test, fold 2 -> train
%X = bsxfun(@rdivide, X, mean(X,1)+1e-10 );
Xtest = X(C.test(1),:);
ytest = y(C.test(1));
Xtrain = X(C.test(2),:);
ytrain = y(C.test(2));
%%
w_init = zeros(size(X,2),1);        % zero initial weights (0*randn was equivalent but wasted random draws)
mfOptions.Method = 'lbfgs';         % limited-memory BFGS
mfOptions.optTol = 2e-2;            % first-order optimality tolerance
mfOptions.progTol = 2e-6;           % minimum progress between iterations
mfOptions.LS = 2;                   % line-search variant (minFunc option)
mfOptions.LS_init = 2;              % initial step-length strategy (minFunc option)
mfOptions.MaxIter = 250;
mfOptions.DerivativeCheck = 0;      % skip numerical gradient checking
results = containers.Map;           % test accuracy keyed by objective name
% Train each objective with the same L2 penalty and report held-out
% accuracy. All dropout variants use dropout rate 0.5.
casenames = {'LR','DetDropout', 'DetDropoutApprox', 'Dropout'};
lambdaL2 = 0.01;  % shared L2 penalty (previously duplicated in every branch);
                  % for LR you can optimize this value on the test set,
                  % and LR would still be quite a bit worse
for casenum = 1:length(casenames)
    obj = casenames{casenum};
    % Pick the training loss for this method.
    switch obj
        case 'LR'
            funObj = @(w)LogisticLoss(w,Xtrain,ytrain);
        case 'DetDropout'
            funObj = @(w)LogisticLossDetObjDropout(w,Xtrain,ytrain,0.5);
        case 'DetDropoutApprox'
            funObj = @(w)LogisticLossDetObjDropoutDeltaApprox(w,Xtrain,ytrain,0.5);
        case 'Dropout'
            funObj = @(w)LogisticLossMCDropoutSample(w,Xtrain,ytrain,0.5,100,100);
    end
    funObjL2 = @(w)penalizedL2(w,funObj,lambdaL2);  % wrap loss with L2 regularizer
    w = minFunc(funObjL2,w_init,mfOptions);         % fit weight vector
    % Evaluate on the held-out fold; (ytest+1)/2 maps {-1,+1} labels to {0,1}
    % to match the logical predictions.
    ypred = Xtest * w > 0;
    acc = sum(ypred == (ytest+1)/2 )/length(ytest);
    % ypred = Xtrain * w > 0;
    % acc = sum(ypred == (ytrain+1)/2 )/length(ytrain);
    results(obj) = acc;  % obj already holds casenames{casenum}
end
% Print the accuracy of every method. The variable is named methodNames
% rather than "keys" so it does not shadow the built-in keys() function.
methodNames = results.keys;
for i = 1:length(methodNames)
    fprintf('%s: %f\n', methodNames{i}, results(methodNames{i}));
end