---
title: "Random Forest Model"
author: "Nilai Vemula"
date: "December 7, 2020"
output:
  github_document:
    html_preview: no
---
## Random Forest Model
### Loading Random Forest Libraries
```{r, results='hide', message=FALSE, warning=FALSE}
library(randomForest)  # random forest implementation
library(caTools)       # sample.split() for the train/test split
library(tidyverse)     # data manipulation and plotting
set.seed(2020)         # fix the RNG seed for reproducibility
```
### Loading in Data
We read in the data with the `read_csv()` function. The `col_types` argument converts each categorical column into a factor, which simplifies the analysis later on.
```{r}
diabetes_data_upload <- read_csv(
  "../data/diabetes_data_upload.csv",
  col_types = cols(
    Gender = col_factor(levels = c("Male", "Female")),
    Polyuria = col_factor(levels = c("Yes", "No")),
    Polydipsia = col_factor(levels = c("Yes", "No")),
    `sudden weight loss` = col_factor(levels = c("Yes", "No")),
    weakness = col_factor(levels = c("Yes", "No")),
    Polyphagia = col_factor(levels = c("Yes", "No")),
    `Genital thrush` = col_factor(levels = c("Yes", "No")),
    `visual blurring` = col_factor(levels = c("Yes", "No")),
    Itching = col_factor(levels = c("Yes", "No")),
    Irritability = col_factor(levels = c("Yes", "No")),
    `delayed healing` = col_factor(levels = c("Yes", "No")),
    `partial paresis` = col_factor(levels = c("Yes", "No")),
    `muscle stiffness` = col_factor(levels = c("Yes", "No")),
    Alopecia = col_factor(levels = c("Yes", "No")),
    Obesity = col_factor(levels = c("Yes", "No")),
    class = col_factor(levels = c("Positive", "Negative"))
  )
)
diabetes_data_upload <- data.frame(diabetes_data_upload)
summary(diabetes_data_upload)
```
In the summary of our data, we see that there are more patients with diabetes than without, which we should keep in mind when judging model accuracy. While the dataset is unbalanced, 320 positive versus 200 negative cases is not an extreme imbalance, so the model should still be reasonably accurate.
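To quantify the imbalance directly, a quick sketch tabulating the class counts and proportions:
```{r}
# Raw counts and proportions of each class label
table(diabetes_data_upload$class)
prop.table(table(diabetes_data_upload$class))
```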
### Splitting Data into a Training and Testing Set
In preparation for modeling, we first split the data into a training set and a testing set with a split ratio of 0.8: 80% of the dataset will be used for training and the remaining 20% for testing the model at the end.
```{r}
sample <- sample.split(diabetes_data_upload$class, SplitRatio = 0.80)
train <- subset(diabetes_data_upload, sample == TRUE)
test <- subset(diabetes_data_upload, sample == FALSE)
dim(train)
dim(test)
```
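Because `sample.split()` stratifies on the label it is given, both subsets should preserve the original class proportions; a quick sanity check on the split created above:
```{r}
# Class proportions should be roughly equal across the two subsets
prop.table(table(train$class))
prop.table(table(test$class))
```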
### Building the Model
We build a random forest model using the `randomForest()` function with 100 trees (`ntree = 100`); this is fewer than the function's default of 500 but, as the error plot below shows, more than enough for this dataset.
```{r}
rf <- randomForest(
  class ~ .,         # predict diabetes class from all other columns
  data = train,
  ntree = 100,       # number of trees in the forest
  importance = TRUE  # store permutation-based variable importance
)
rf
```
```{r}
# Accuracy from the model's confusion matrix, over all training rows
train_accuracy <- (rf$confusion[1,1] + rf$confusion[2,2]) / (dim(train)[1])
```
After training the random forest model, we see an accuracy of `r round(train_accuracy,3)*100`%. Note that `rf$confusion` is built from out-of-bag (OOB) predictions, so this is an OOB estimate rather than a pure training accuracy; the OOB estimate tends to be the more honest of the two.
Plotting the model error against the number of trees in the forest below, we observe that the model likely does not need all 100 trees: the error drops sharply over the first few trees and then quickly plateaus.
```{r error}
plot(rf)
```
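To put a number on where the error plateaus, we can inspect `rf$err.rate`, the matrix of cumulative error rates that `plot(rf)` draws; a minimal sketch:
```{r}
# err.rate has one row per tree; the "OOB" column is the overall error curve
which.min(rf$err.rate[, "OOB"])  # number of trees at the lowest OOB error
min(rf$err.rate[, "OOB"])        # the corresponding OOB error rate
```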
### Variable Importance
```{r var}
par(bg="#f2f2f2")
varImpPlot(rf, main="Variable Importance Plot", type=1, pch=19)
dev.copy(png, filename = "../plots/variable_importance.png", width = 8, height = 6, units="in", res=400)
dev.off()
```
The random forest model is particularly interesting because it lets us investigate the relative importance of the variables used in the model. In the variable importance plot above, we see that polydipsia, polyuria, gender, and age are the four most important variables. Importance here is measured by `MeanDecreaseAccuracy`, which is computed by permuting the values of each feature in turn and observing how much that permutation decreases the accuracy of the model.

An interesting conclusion from this plot is that age matters a great deal to the model's accuracy. In our density plots above, we concluded that there was only a slight difference in mean age between patients with and without diabetes, and age was not statistically significant in the linear model either. It is therefore likely that age has a nonlinear relationship with the diabetes status of a patient, or that age is an important predictor only in combination with other features.
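The same ranking can be read off numerically with `importance()`; a minimal sketch, assuming the `rf` model fit above:
```{r}
# type = 1 selects the permutation-based MeanDecreaseAccuracy measure
imp <- importance(rf, type = 1)
imp[order(imp[, 1], decreasing = TRUE), , drop = FALSE]
```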
### Validating on the Testing Set
Given that the model has good accuracy on the training set and the conclusions from the variable importance plot are reasonable, we can now validate it on the remaining 20% of the dataset. We use the random forest model to generate predictions for the held-out samples, compare them with the observed diabetes status of those patients, and finally visualize the result with a confusion matrix.
```{r}
# Predict on the test set, dropping column 17 (the true class label)
prediction_for_table <- predict(rf, test[, -17])
confusion <- table(observed = test[[17]], predicted = prediction_for_table)
confusion
```
```{r confusion}
# Unroll the confusion matrix into one row per cell for plotting;
# `confusion` has rows = observed and columns = predicted,
# each with factor levels ordered (Positive, Negative)
true_class <- factor(c('Negative', 'Negative', 'Positive', 'Positive'))
predicted_class <- factor(c('Negative', 'Positive', 'Negative', 'Positive'))
counts <- c(confusion[2,2], confusion[2,1], confusion[1,2], confusion[1,1])
df <- data.frame(true_class, predicted_class, counts)
ggplot(data = df, mapping = aes(x = true_class, y = predicted_class)) +
geom_tile(aes(fill = counts), colour = "white") +
geom_text(aes(label = sprintf("%1.0f", counts)), vjust = 1, colour="white") +
scale_fill_gradient() +
theme_bw() + theme(legend.position = "none") +
labs(title="Confusion Matrix",x="True Class", y="Predicted Class")+
theme(plot.title = element_text(hjust = 0.5)) +
theme(plot.background = element_rect(fill = '#f2f2f2', colour = '#f2f2f2')) +
theme(panel.background = element_rect(fill = '#f2f2f2', colour = '#f2f2f2'))
ggsave("../plots/confusion_matrix.png", width = 4, height = 4, units = "in",dpi= 1200)
```
```{r}
# Overall test accuracy: (true positives + true negatives) / all test samples
test_accuracy <- (confusion[2,2] + confusion[1,1])/sum(confusion)
print(test_accuracy)
```
From our confusion matrix, we can see that the testing accuracy of the model is `r round(test_accuracy,3)*100`%. This is close to the training accuracy, indicating that the model is not over-fitting to the data. These accuracy values suggest the model has strong predictive power; further hyper-parameter tuning could improve it, but that is outside the scope of this project.
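Accuracy alone can hide class-specific errors, especially on an unbalanced dataset; a minimal sketch computing sensitivity and specificity from the same `confusion` table (rows are observed, columns are predicted, both ordered Positive then Negative):
```{r}
# Sensitivity: fraction of positive patients correctly identified
sensitivity <- confusion[1, 1] / sum(confusion[1, ])
# Specificity: fraction of negative patients correctly identified
specificity <- confusion[2, 2] / sum(confusion[2, ])
c(sensitivity = sensitivity, specificity = specificity)
```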