-
Notifications
You must be signed in to change notification settings - Fork 1
/
list-of-list-solution.rkt
119 lines (99 loc) · 3.5 KB
/
list-of-list-solution.rkt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#lang racket
#|
Attribution:
This implementation of decision trees in Racket was written by Daniel Prager and
was originally shared at:
https://groups.google.com/forum/#!topic/racket-users/cPuTr8lrXCs
With permission it was added to the project.
|#
;; string->data : string? [string?] -> (listof (listof (or/c number? #f)))
;; Parse delimited text into rows of numbers.
;;   s   - text holding one record per line (LF or CRLF line endings)
;;   sep - field separator used inside a line; defaults to a single space
;; Returns one list per non-blank line. Every field is passed through
;; string->number, so a field that is not a valid number becomes #f.
;; Blank (empty or whitespace-only) lines are skipped: turning them into
;; empty rows would make later `last` / `list-ref` calls on the row fail.
(define (string->data s [sep " "])
  (for/list ([line (in-list (string-split s #rx"\r?\n"))]
             #:unless (string=? "" (string-trim line)))
    (map string->number (string-split line sep))))
;; Rows parsed from the comma-separated file
;; "data_banknote_authentication.csv" (given as a relative path).
;; The file is read once, when this definition is evaluated.
(define banknote-data
  (let ([csv-text (file->string "data_banknote_authentication.csv")])
    (string->data csv-text ",")))
;; Small inline sample data set, parsed with string->data's default
;; single-space separator. Each of the ten lines holds two numeric values
;; followed by a 0/1 value in the last column; the first five rows end in 0
;; and the last five end in 1.
(define test-data
(string->data
"2.771244718 1.784783929 0
1.728571309 1.169761413 0
3.678319846 2.81281357 0
3.961043357 2.61995032 0
2.999208922 2.209014212 0
7.497545867 3.162953546 1
9.00220326 3.339047188 1
7.444542326 0.476683375 1
10.12493903 3.234550982 1
6.642287351 3.319983761 1"))
;; make-split : (listof list?) index number -> (list left right)
;; Partition `rows` on the column at position `index`:
;;   left  - rows whose value in that column is strictly less than `value`
;;   right - all remaining rows
;; Rows are consed onto their group as they are visited, so each group
;; ends up in the reverse of the input order.
(define (make-split rows index value)
  (let loop ([remaining rows] [below '()] [at-or-above '()])
    (cond
      [(null? remaining) (list below at-or-above)]
      [else
       (define row (car remaining))
       (if (< (list-ref row index) value)
           (loop (cdr remaining) (cons row below) at-or-above)
           (loop (cdr remaining) below (cons row at-or-above)))])))
;; gini-coefficient : (listof (listof row)) -> number
;; Score a candidate split. For every group, take the proportion p of rows
;; whose last element is zero and the proportion q of the other rows, and
;; add p*(1-p) + q*(1-q). The per-group scores are summed without weighting
;; by group size. An empty group contributes 0.
(define (gini-coefficient splits)
  ;; contribution of a single class proportion
  (define (impurity p) (* p (- 1.0 p)))
  (for/sum ([group (in-list splits)])
    (cond
      [(null? group) 0]
      [else
       ;; size is kept as a flonum so the proportions are computed in floating point
       (define size (* 1.0 (length group)))
       (define zero-count (count (λ (row) (zero? (last row))) group))
       (+ (impurity (/ zero-count size))
          (impurity (/ (- size zero-count) size)))])))
;; get-split : (listof row) -> (list index value (list left right))
;; Exhaustively search for the best split of `rows`. Every column except the
;; last one is tried, and every value occurring in that column is used as a
;; candidate threshold. The candidate with the lowest gini-coefficient wins;
;; on equal scores the earliest candidate found is kept (strict `<`).
;; `rows` must be non-empty, since the column count is taken from its first row.
(define (get-split rows)
  (define feature-count (sub1 (length (first rows))))
  (define-values (best-groups best-index best-value _best-score)
    (for*/fold ([groups null] [idx -1] [val -1] [lowest 999])
               ([col (in-range feature-count)]
                [candidate (in-list rows)])
      (define threshold (list-ref candidate col))
      (define trial (make-split rows col threshold))
      (define impurity (gini-coefficient trial))
      (cond
        [(< impurity lowest) (values trial col threshold impurity)]
        [else (values groups idx val lowest)])))
  (list best-index best-value best-groups))
;; to-terminal : (listof row) -> (or/c 0 1)
;; Leaf value for a group of rows: 0 when the rows whose last element is
;; zero strictly outnumber the others, otherwise 1 (so ties and an empty
;; group both yield 1).
(define (to-terminal group)
  (define-values (zero-tally other-tally)
    (for/fold ([z 0] [o 0])
              ([row (in-list group)])
      (if (zero? (last row))
          (values (add1 z) o)
          (values z (add1 o)))))
  (cond
    [(> zero-tally other-tally) 0]
    [else 1]))
;; split : node max-depth min-size depth -> tree
;; Turn a raw split `node` of the form (list index value (list left right)),
;; as produced by get-split, into a finished (sub)tree.
;;   - If either side is empty, the whole node collapses into a single leaf
;;     computed from the other side.
;;   - Otherwise the result is (list index value left-child right-child).
;;     Once `depth` has reached `max-depth`, both children are leaves;
;;     before that, a side with at most `min-size` rows becomes a leaf and a
;;     larger side is split again one level deeper.
(define (split node max-depth min-size depth)
  (match-define (list index value (list left right)) node)
  ;; Build the child for one side while depth still allows further splitting.
  (define (grow branch)
    (cond
      [(> (length branch) min-size)
       (split (get-split branch) max-depth min-size (add1 depth))]
      [else (to-terminal branch)]))
  (cond
    [(null? left) (to-terminal right)]
    [(null? right) (to-terminal left)]
    [else
     (define child (if (>= depth max-depth) to-terminal grow))
     (list index value (child left) (child right))]))
;; build-tree : (listof row) max-depth min-size -> tree
;; Build a decision tree for `rows`: find the best root split, then expand
;; it recursively with `split`, starting at depth 1.
(define (build-tree rows max-depth min-size)
  (let ([root (get-split rows)])
    (split root max-depth min-size 1)))
;; predict : tree row -> leaf-value
;; Walk `node` for `row`. A non-list node is a leaf and is returned as is.
;; A list node has the form (list index value left right): descend into
;; `left` when the row's element at `index` is strictly less than `value`,
;; otherwise into `right`.
(define (predict node row)
  (cond
    [(not (list? node)) node]
    [else
     (match-define (list index value left right) node)
     (define next-node
       (if (< (list-ref row index) value) left right))
     (predict next-node row)]))
;; check-model : tree (listof row) -> number
;; Fraction of rows in `validation-set` for which the model's prediction
;; equals the row's last element. The count of correct rows is divided by
;; the set size and then by 1.0.
(define (check-model model validation-set)
  (define hits
    (for/sum ([row (in-list validation-set)])
      (if (= (predict model row) (last row)) 1 0)))
  (/ hits (length validation-set) 1.0))
;(define test-model (build-tree test-data 1 1))
;(for/list ([row (in-list test-data)])
; (list row (predict test-model row)))
;; --- Demo / benchmark script: these forms run when the module is loaded ---
;; Shuffle the banknote rows and build a tree from the first 274 of them with
;; max-depth 5 and min-size 10; `time` reports how long the build took.
(define data (shuffle banknote-data))
(define model (time (build-tree (take data 274) 5 10)))
;; Module-level expressions: the tree itself, followed by the fraction of the
;; remaining rows (everything after the first 274) that it predicts correctly.
model
(check-model model (drop data 274))
;; Seed the random generator before the second shuffle.
;; NOTE(review): the first shuffle above runs before this call, so `data` and
;; `model` are not covered by the seed — confirm that is intended.
(random-seed 12345)
(define data2 (shuffle banknote-data))
;; Time a single tree build on the first 274 rows of data2; the tree is discarded.
(time
(void
(build-tree (take data2 274) 5 10)))
;; Time 20 consecutive builds on the same 274 rows.
(time
(for ([i (in-range 20)])
(build-tree (take data2 274) 5 10)))