-
Notifications
You must be signed in to change notification settings - Fork 45
/
kNN_cosine.py
executable file
·64 lines (50 loc) · 2.09 KB
/
kNN_cosine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#########################################
# kNN: k Nearest Neighbors
# Input: newInput: vector to compare to existing dataset (1xN)
# dataSet: data set of M known vectors, one sample per row (MxN)
# labels: data set labels (1xM vector)
# k: number of neighbors to use for comparison
# Output: the most popular class label
#########################################
from numpy import *
import operator
import math
import tensorflow as tf
import numpy as np
# create a dataset which contains 4 samples with 2 classes
def createDataSet():
    """Return a tiny hard-coded sample set for trying out the classifier.

    Returns:
        A ``(group, labels)`` pair: ``group`` is a 4x2 NumPy array with one
        sample per row, and ``labels`` is the list of class labels
        (``'A'`` or ``'B'``) for the corresponding rows.
    """
    # Use the explicit ``np`` alias instead of relying on the bare ``array``
    # name that is only in scope through ``from numpy import *``.
    group = np.array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A', 'A', 'B', 'B']  # four samples, two classes
    return group, labels
def cosine_distance(v1, v2):
    """Return the cosine distance between vectors ``v1`` and ``v2``.

    The result is ``1 - (v1 . v2) / (||v1|| * ||v2||)``, i.e. one minus the
    cosine similarity: 0 for vectors pointing the same way, 1 for orthogonal
    vectors and 2 for vectors pointing in opposite directions.

    Args:
        v1: 1-D numeric sequence or array.
        v2: 1-D numeric sequence or array of the same length as ``v1``.

    Returns:
        The cosine distance as a float. NaN is returned when the product of
        the norms is zero (e.g. either vector is all zeros), because the
        angle is undefined in that case.
    """
    # Work in floating point: with integer inputs np.inner returns
    # fixed-width integers, and the product of the two squared norms below
    # can silently overflow.
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    v1_sq = np.inner(v1, v1)
    v2_sq = np.inner(v2, v2)
    norm_product = math.sqrt(v1_sq * v2_sq)
    if norm_product == 0:
        # Same value the unguarded 0/0 division would produce, but without
        # triggering a NumPy RuntimeWarning.
        return float("nan")
    return 1 - np.inner(v1, v2) / norm_product
# classify using kNN
def kNNClassify(newInput, dataSet, labels, k):
    """Classify ``newInput`` by majority vote among its k nearest samples.

    Nearness is measured with ``cosine_distance`` (smaller means closer).

    Args:
        newInput: Vector to classify.
        dataSet: Array of known samples, one sample per row (it must expose
            ``.shape`` and integer row indexing).
        labels: Sequence giving the class label of each row of ``dataSet``.
        k: Number of nearest neighbours that take part in the vote; must
            satisfy ``1 <= k <= dataSet.shape[0]``.

    Returns:
        The label that received the most votes. When several labels tie,
        the one whose first vote came from the nearest neighbour wins.

    Raises:
        ValueError: If ``k`` is smaller than 1 or larger than the number of
            samples in ``dataSet``.

    Side effects:
        Rebinds the module-level name ``distance`` to the list of distances
        computed for this call.
    """
    # NOTE: the module-level ``distance`` is kept because code outside this
    # function may read it after a call; prefer the return value in new code.
    global distance

    sample_count = dataSet.shape[0]
    if not 1 <= k <= sample_count:
        raise ValueError(
            "k must be between 1 and the number of samples (%d), got %r"
            % (sample_count, k)
        )

    ## step 1: distance from newInput to every known sample
    distance = [cosine_distance(newInput, dataSet[i])
                for i in range(sample_count)]

    ## step 2: indices of the samples ordered by ascending distance
    # A stable sort keeps equally distant samples in data-set order, so the
    # outcome is deterministic. NaN distances sort last.
    nearest = np.argsort(distance, kind="stable")[:k]

    ## step 3: count how often each label occurs among the k nearest samples
    votes = {}
    for sample_index in nearest:
        vote_label = labels[sample_index]
        votes[vote_label] = votes.get(vote_label, 0) + 1

    ## step 4: return the label with the most votes
    # max() returns the first maximal key in insertion order, so ties go to
    # the label that was seen first (i.e. nearest).
    return max(votes, key=votes.get)