-
Notifications
You must be signed in to change notification settings - Fork 0
/
knn.py
82 lines (71 loc) · 2.84 KB
/
knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
knn.py
~~~~~~~~~~~~~~~
k-Nearest Neighbors (kNN) implementation.
This module contains functions that can be used for kNN classification.
Call kNN(modelData, instance, k) on an existing modelData of training data to classify
a new instance based on the k nearest neighbors.
Example:
$ python knn.py
>>> import model
>>> m1 = Model("m1", 3)
>>> m1.load("model/test_model_1")
>>>
>>> instance = [ 0.432, 0.192, 0.416 ]
>>> print(kNN(m1.data, instance, 5))
{'A': 0.8, 'B': 0.2}
"""
def kNN(modelData, instance, k):
    """Classify ``instance`` by a vote among its k nearest neighbors.

    This is the entry point intended for use by other modules.

    Args:
        modelData: Iterable of (sym, data) training pairs, passed straight
            through to allNeighbors.
        instance: Sequence of coordinates to classify; its length is used
            as the dimensionality for every distance computation.
        k: Number of closest training points that take part in the vote.

    Returns:
        The dictionary produced by vote(): each class symbol found among
        the k closest points mapped to its share of those points.
    """
    dimensions = len(instance)
    distances = allNeighbors(modelData, instance, dimensions)
    return vote(kNearestNeighbors(distances, k))
def euclideanDistance(a, b, n):
    """Return the Euclidean distance between two points over n dimensions.

    Args:
        a: First point, an indexable sequence of numbers of length >= n.
        b: Second point, an indexable sequence of numbers of length >= n.
        n: Number of leading coordinates to compare; coordinates at index
            n and beyond are ignored.

    Returns:
        The square root of the summed squared coordinate differences
        (0.0 when n is 0).

    Raises:
        IndexError: If either point has fewer than n coordinates.
    """
    # range() rather than xrange(): xrange does not exist on Python 3,
    # while range works on both Python 2 and Python 3.
    squared = sum((a[i] - b[i]) ** 2 for i in range(n))
    return squared ** 0.5
def allNeighbors(modelData, instance, n):
    """Pair every training point's class with its distance to ``instance``.

    Args:
        modelData: Iterable of (sym, data) tuples, where sym is the class
            of the training point and data holds its coordinates.
        instance: Coordinates of the point being classified.
        n: Dimensionality used when computing each distance.

    Returns:
        A list of (sym, distance) tuples, one per entry of modelData and
        in the same order.
    """
    return [
        (label, euclideanDistance(instance, point, n))
        for label, point in modelData
    ]
def kNearestNeighbors(neighbors, k):
    """Return the k neighbors with the smallest distance.

    Args:
        neighbors: Iterable of (sym, distance) tuples.
        k: Number of neighbors to keep. Applied as a slice bound, so a k
            larger than the input simply returns every neighbor.

    Returns:
        A new list of at most k (sym, distance) tuples ordered by
        ascending distance. Neighbors with equal distance keep their
        original relative order because the sort is stable.
    """
    # A key function replaces the old cmp-style comparator: the positional
    # cmp argument of sorted() and the cmp() builtin are both gone on
    # Python 3, whereas key= works on Python 2.4+ and Python 3.
    ranked = sorted(neighbors, key=lambda pair: pair[1])
    return ranked[:k]
def vote(neighbors):
    """Return each class's share of the given neighbors.

    Args:
        neighbors: Iterable of (sym, distance) tuples; only the class
            symbol of each tuple contributes to the result.

    Returns:
        A dictionary mapping every class symbol seen to the fraction of
        neighbors carrying it, as a float. An empty input yields an empty
        dictionary.
    """
    tally = {}
    for label, _ in neighbors:
        tally[label] = tally.get(label, 0) + 1

    # Every neighbor was tallied exactly once, so the tallies sum to the
    # number of neighbors.
    total = sum(tally.values())
    return {label: float(hits) / total for label, hits in tally.items()}
def topNClasses(voteProportions, n):
    """Return the n classes with the highest vote proportion.

    Args:
        voteProportions: Dictionary mapping class symbol to its vote
            proportion, such as the one returned by vote().
        n: Number of classes to return. Applied as a slice bound, so an n
            larger than the number of classes returns all of them.

    Returns:
        A list of at most n (class, proportion) tuples in descending order
        of proportion.
    """
    # A key function replaces the old cmp-style comparator: the positional
    # cmp argument of sorted() and the cmp() builtin are both gone on
    # Python 3, whereas key= works on Python 2.4+ and Python 3.
    ascending = sorted(voteProportions.items(), key=lambda item: item[1])
    # Reverse the ascending result (rather than sorting with reverse=True)
    # so classes with equal proportions keep the same relative order that
    # the original sort-then-reverse produced.
    descending = ascending[::-1]
    return descending[:n]