-
Notifications
You must be signed in to change notification settings - Fork 0
/
KNN图像识别.py
78 lines (73 loc) · 2.66 KB
/
KNN图像识别.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding:utf-8 -*-
import csv
import random
import math
import operator
# Load the dataset and split it randomly into training and test rows.
def loadDataset(filename, split, trainingSet=None, testSet=None, featureCount=4):
    """Read a CSV file and randomly partition its rows into two lists.

    The first ``featureCount`` columns of every non-empty row are converted
    to ``float`` in place; remaining columns (e.g. the class label) are left
    as the strings produced by ``csv.reader``.

    Args:
        filename: Path of the CSV file to read.
        split: Threshold compared against ``random.random()``; a row goes to
            ``trainingSet`` when the random draw is below it, otherwise to
            ``testSet``.
        trainingSet: Optional list that receives the training rows. A new
            list is created when omitted.
        testSet: Optional list that receives the test rows. A new list is
            created when omitted.
        featureCount: Number of leading columns to convert to ``float``.

    Returns:
        The ``(trainingSet, testSet)`` pair, i.e. the same list objects that
        were passed in (or the freshly created ones).

    Raises:
        ValueError: If one of the leading columns is not a valid float.
        IndexError: If a non-empty row has fewer than ``featureCount`` columns.
    """
    # Fresh lists per call: a mutable default would be shared between calls.
    if trainingSet is None:
        trainingSet = []
    if testSet is None:
        testSet = []

    with open(filename, 'r') as csvfile:
        dataset = list(csv.reader(csvfile))

    for row in dataset:
        # csv.reader yields [] for blank lines (such as a trailing one);
        # skip those instead of blindly dropping the last row of the file.
        if not row:
            continue
        for col in range(featureCount):
            row[col] = float(row[col])
        # Randomly assign the row to one of the two partitions.
        if random.random() < split:
            trainingSet.append(row)
        else:
            testSet.append(row)
    return trainingSet, testSet
# Compute the distance between two points over multiple dimensions.
def euclideanDistance(instance1, instance2, length):
    """Return the Euclidean distance between two instances.

    Only the first ``length`` components of each instance are compared, so
    trailing entries (such as a class label) are ignored.

    Args:
        instance1: Indexable sequence of numbers.
        instance2: Indexable sequence of numbers.
        length: Number of leading components to include.

    Returns:
        The distance as a float; ``0.0`` when ``length`` is 0.
    """
    # The accumulator starts from zero via sum(); the original never
    # initialised it and raised UnboundLocalError on every call.
    squared = sum(
        pow(instance1[i] - instance2[i], 2) for i in range(length)
    )
    return math.sqrt(squared)
# Get the K nearest neighbours.
def getNeighbors(trainingSet, testInstance, k):
    """Return the ``k`` training rows closest to ``testInstance``.

    The last element of ``testInstance`` is excluded from the distance
    computation (``len(testInstance) - 1`` components are compared).

    Args:
        trainingSet: Sequence of training rows.
        testInstance: The row to find neighbours for.
        k: Number of neighbours wanted.

    Returns:
        A list of training rows ordered from nearest to farthest. It holds
        fewer than ``k`` rows when the training set is smaller than ``k``,
        and is empty when ``k`` is not positive.
    """
    length = len(testInstance) - 1
    # Pair every training row with its distance to the test instance.
    distances = [
        (row, euclideanDistance(testInstance, row, length))
        for row in trainingSet
    ]
    # Stable sort by distance, nearest first.
    distances.sort(key=operator.itemgetter(1))
    # Slicing (instead of indexing 0..k-1) avoids IndexError when k exceeds
    # the number of training rows; max() keeps a negative k from slicing
    # from the end of the list.
    return [row for row, _ in distances[:max(k, 0)]]
# Pick the class that occurs most often among the K neighbours.
def getResponse(neighbors):
    """Return the majority class label among ``neighbors``.

    The label of each neighbour is its last element. When several labels
    share the highest count, the stable descending sort keeps whichever of
    them comes first in the vote dictionary's iteration order.

    Args:
        neighbors: Sequence of rows whose last element is the class label.

    Returns:
        The label with the most votes.

    Raises:
        IndexError: If ``neighbors`` is empty.
    """
    tally = {}
    for neighbor in neighbors:
        label = neighbor[-1]
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally.items(), key=lambda pair: pair[1], reverse=True)
    # The first entry carries the highest vote count.
    winner = ranked[0]
    return winner[0]
# Compute the accuracy of the predictions.
def getAccuracy(testSet, predictions):
    """Return the percentage of test rows whose label was predicted correctly.

    A prediction is correct when it equals the last element of the test row
    at the same position.

    Args:
        testSet: Sequence of rows whose last element is the true label.
        predictions: Sequence of predicted labels, indexed like ``testSet``.

    Returns:
        Accuracy as a float in the range 0.0-100.0. An empty ``testSet``
        yields ``0.0`` instead of raising ``ZeroDivisionError``.

    Raises:
        IndexError: If ``predictions`` is shorter than ``testSet``.
    """
    total = len(testSet)
    # A random split of a small dataset can leave the test set empty.
    if total == 0:
        return 0.0
    correct = sum(
        1 for i, row in enumerate(testSet) if row[-1] == predictions[i]
    )
    return (correct / float(total)) * 100.0
def main():
    """Load the data, classify every test row with KNN and print the results.

    Prints the size of both partitions, one line per prediction, the full
    list of predictions and finally the accuracy percentage.
    """
    # Prepare data.
    # NOTE(review): the path below looks like an MNIST IDX binary file, while
    # loadDataset parses its input with csv.reader -- confirm the intended
    # input file.
    trainingSet, testSet = [], []
    split = 0.9
    loadDataset(r't10k-images.idx3-ubyte', split, trainingSet, testSet)
    print('TrainSet:' + repr(len(trainingSet)))
    print('TestSet:' + repr(len(testSet)))

    # Generate one prediction per test row.
    k = 3
    predictions = []
    for sample in testSet:
        nearest = getNeighbors(trainingSet, sample, k)
        result = getResponse(nearest)
        predictions.append(result)
        print('predicted=' + repr(result) + ',actual=' + repr(sample[-1]))
    print('predictions:' + repr(predictions))

    accuracy = getAccuracy(testSet, predictions)
    print('Accuracy:' + repr(accuracy) + '%')
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()