-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.m
30 lines (21 loc) · 894 Bytes
/
model.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
function accuracy=model(training_set, testing_set)
rng(0) % comment out the rng seed so we can get true cross-validation
train_input=prepare_input_dataset(training_set(:,1:end-1));
train_output=prepare_output_dataset(training_set(:,end));
test_input=prepare_input_dataset(testing_set(:,1:end-1));
test_output=prepare_output_dataset(testing_set(:,end));
% Initialize network
n_hidden = 4;
n_inputs= 8;
n_outputs= 4;
net=initialize_network(n_inputs,n_hidden,n_outputs);
% Set training parameters
l_rate = 5; n_epoch = 200;
[net] = train_network(net, train_input, train_output, l_rate, n_epoch);
for j=1:size(test_output,1)
true_val(j)=find(test_output(j,:)==max(test_output(j,:)));
end
[test]=predict(net, test_input);
correct=numel(find(test==true_val));
accuracy=correct/length(test);
end