1. The K-Nearest Neighbor Algorithm
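Here is a minimal sketch of the basic algorithm. Given a training set and a query point, find the k training points closest to the query, then predict the label by majority vote among them. The function name `knn_predict` and the toy data are illustrative assumptions, not the author's code. The sketch assumes Euclidean distance and a plain majority vote.

```python
from collections import Counter

import numpy as np


def knn_predict(train_x, train_y, query, k=3):
    """Brute-force k-NN classification (illustrative sketch).

    Computes the Euclidean distance from `query` to every training point,
    keeps the k smallest, and returns the most common label among them.
    """
    train_x = np.asarray(train_x, dtype=float)
    query = np.asarray(query, dtype=float)
    # Distance from the query to each training point (linear scan).
    dists = np.sqrt(((train_x - query) ** 2).sum(axis=1))
    # Indices of the k nearest points; a stable sort keeps ties in input order.
    nearest = np.argsort(dists, kind="stable")[:k]
    # Majority vote among the neighbours' labels.
    votes = Counter(train_y[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Hypothetical toy data: two clusters labelled "a" and "b".
X = [[1.0, 1.0], [1.2, 0.8], [0.9, 1.1], [5.0, 5.0], [5.2, 4.9], [4.8, 5.1]]
y = ["a", "a", "a", "b", "b", "b"]
print(knn_predict(X, y, [1.1, 1.0], k=3))  # a
print(knn_predict(X, y, [5.1, 5.0], k=3))  # b
```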
2. The K-Nearest Neighbor Model (Three Basic Elements)
1. Distance Metric
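The distance metric decides which training points count as "near". A common family is the L_p distance, L_p(x_i, x_j) = (sum over l of |x_i^(l) - x_j^(l)|^p)^(1/p). The `calDist` function in the implementation below computes the p = 2 (Euclidean) case. The hypothetical helper `lp_distance` below shows how the choice of p changes the distance for the same pair of points.

```python
import numpy as np


def lp_distance(x1, x2, p=2):
    """L_p (Minkowski) distance between two vectors.

    p=1 gives the Manhattan distance, p=2 the Euclidean distance, and
    p=np.inf the Chebyshev distance (largest per-coordinate difference).
    """
    diff = np.abs(np.asarray(x1, dtype=float) - np.asarray(x2, dtype=float))
    if np.isinf(p):
        return float(diff.max())
    return float((diff ** p).sum() ** (1.0 / p))


a, b = [0, 0], [3, 4]
print(lp_distance(a, b, 1))       # 7.0
print(lp_distance(a, b, 2))       # 5.0
print(lp_distance(a, b, np.inf))  # 4.0
```

Because the nearest neighbors can differ between metrics, the same query may get a different prediction depending on p.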
2. Choosing the Value of K
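The value of K trades off two failure modes. A small K bases each prediction on only a few close points, so the model is more complex and more sensitive to noise. A large K smooths the decision but lets farther, less similar points vote. Cross-validation is one common way to pick K.

The sketch below assumes plain k-fold cross-validation over a small candidate list on synthetic data. The names `choose_k` and `knn_predict_batch`, the fold count, and the data are all illustrative assumptions.

```python
import numpy as np


def knn_predict_batch(train_x, train_y, test_x, k):
    """Predict labels for all rows of test_x by majority vote of k nearest neighbours."""
    # Pairwise squared Euclidean distances, shape (n_test, n_train).
    d2 = ((test_x[:, None, :] - train_x[None, :, :]) ** 2).sum(axis=2)
    nearest = np.argsort(d2, axis=1, kind="stable")[:, :k]
    # bincount + argmax implements the majority vote (ties go to the smallest label).
    return np.array([np.bincount(train_y[row]).argmax() for row in nearest])


def choose_k(x, y, candidates, n_folds=5, seed=0):
    """Return (best K, its mean accuracy) under k-fold cross-validation."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(x)), n_folds)
    best_k, best_acc = None, -1.0
    for k in candidates:
        accs = []
        for i in range(n_folds):
            val = folds[i]
            train = np.concatenate([folds[j] for j in range(n_folds) if j != i])
            pred = knn_predict_batch(x[train], y[train], x[val], k)
            accs.append(np.mean(pred == y[val]))
        acc = float(np.mean(accs))
        if acc > best_acc:
            best_k, best_acc = k, acc
    return best_k, best_acc


# Hypothetical synthetic data: two Gaussian blobs with integer labels 0 and 1.
rng = np.random.default_rng(1)
x = np.vstack([rng.normal(0.0, 1.0, (50, 2)), rng.normal(3.0, 1.0, (50, 2))])
y = np.array([0] * 50 + [1] * 50)
print(choose_k(x, y, candidates=[1, 3, 5, 7, 9]))
```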
3. Classification Decision Rule
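The usual decision rule is majority voting: the query is assigned the class that occurs most often among its K neighbors. Under 0-1 loss, this choice also minimizes the empirical misclassification rate over the neighborhood. The sketch below illustrates both views. `majority_vote`, `empirical_01_risk`, and the neighbor labels are hypothetical.

```python
from collections import Counter


def majority_vote(neighbor_labels):
    """Return the most frequent label among the K neighbours.

    Counter.most_common orders equal counts by first occurrence, so ties
    go to the label that appears first in the input.
    """
    return Counter(neighbor_labels).most_common(1)[0][0]


def empirical_01_risk(neighbor_labels, c):
    """Fraction of neighbours whose label differs from candidate class c."""
    return sum(label != c for label in neighbor_labels) / len(neighbor_labels)


labels = [3, 7, 7, 1, 7]          # hypothetical labels of K = 5 neighbours
print(majority_vote(labels))      # 7
# The class chosen by majority vote also has the smallest 0-1 risk.
print(min(set(labels), key=lambda c: empirical_01_risk(labels, c)))  # 7
```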
3. kd-Tree
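Implementing k-NN with a linear scan requires computing the distance from the query point to every training instance, which is very time-consuming when the training set is large.
We therefore need a way to perform fast k-nearest-neighbor search over the training data.
To make k-NN search more efficient, the training data can be stored in a special structure that reduces the number of distance computations.
The kd-tree is one such method.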
1. Constructing a Balanced kd-Tree
See pp. 54-55.
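A balanced kd-tree can be built recursively:

- At depth j, split on coordinate l = j mod k, where k is the dimension.
- Store the instance with the median value of that coordinate at the node.
- Recurse on the points on either side of it.

The sketch below follows that scheme. `KDNode`, `build_kdtree`, and the six 2-D points are illustrative assumptions. With an even number of points it takes the upper median as the split point.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

Point = Tuple[float, ...]


@dataclass
class KDNode:
    point: Point                  # the instance stored at this node
    axis: int                     # coordinate used to split at this node
    left: Optional["KDNode"] = None
    right: Optional["KDNode"] = None


def build_kdtree(points: Sequence[Point], depth: int = 0) -> Optional[KDNode]:
    """Build a balanced kd-tree by splitting on the median along a cycling axis."""
    if not points:
        return None
    k = len(points[0])
    axis = depth % k                           # cycle through the coordinates
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2                        # median (upper median for even counts)
    return KDNode(
        point=tuple(pts[mid]),
        axis=axis,
        left=build_kdtree(pts[:mid], depth + 1),
        right=build_kdtree(pts[mid + 1:], depth + 1),
    )


data = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = build_kdtree(data)
print(root.point, root.left.point, root.right.point)  # (7, 2) (5, 4) (9, 6)
```

Choosing the median keeps both subtrees roughly the same size, so the tree depth stays around log2(N).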
2. Searching a kd-Tree
See pp. 56-57.
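Nearest-neighbor search works in two phases:

- Descend toward the region that contains the target.
- Backtrack. Each visited node may replace the current nearest point. The subtree on the other side of a splitting plane is searched only if the hypersphere around the target, with radius equal to the current nearest distance, crosses that plane.

The sketch below is self-contained, so it repeats a small tree builder. It handles only the single nearest neighbor (k = 1). All names and data are hypothetical.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

Point = Tuple[float, ...]


@dataclass
class KDNode:
    point: Point
    axis: int
    left: Optional["KDNode"] = None
    right: Optional["KDNode"] = None


def build_kdtree(points: List[Point], depth: int = 0) -> Optional[KDNode]:
    """Balanced kd-tree: median split on coordinate depth % k."""
    if not points:
        return None
    axis = depth % len(points[0])
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    return KDNode(pts[mid], axis,
                  build_kdtree(pts[:mid], depth + 1),
                  build_kdtree(pts[mid + 1:], depth + 1))


def nearest(node: Optional[KDNode], target: Point,
            best: Tuple[Optional[Point], float] = (None, math.inf)):
    """Return (nearest_point, distance) for target, pruning subtrees that cannot help."""
    if node is None:
        return best
    d = math.dist(node.point, target)
    if d < best[1]:
        best = (node.point, d)                      # update the current nearest point
    diff = target[node.axis] - node.point[node.axis]
    near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
    best = nearest(near, target, best)              # search the target's side first
    # Backtrack: the other side can only hold a closer point if the hypersphere
    # of radius best[1] around target crosses this node's splitting plane.
    if abs(diff) < best[1]:
        best = nearest(far, target, best)
    return best


data = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build_kdtree(data)
print(nearest(tree, (3, 4.5))[0])   # (2, 3)
```

For the query (3, 4.5), the search never enters the right subtree of the root (7, 2). The horizontal distance to that splitting plane is 4, which exceeds the current best distance of about 1.80.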
4. K-Nearest Neighbor Implementation
"""
@author: liujie
@software: PyCharm
@file: KNN.py
@time: 2020/10/20 22:25
"""
import time
import numpy
as np
from tqdm
import tqdm
def loaddata(filename
):
"""
加载数据
:param filename:文件路径
:return: 返回数据集与标签
"""
print('start to read file')
dataArr
= []
labelArr
= []
fr
= open(filename
)
for line
in tqdm
(fr
.readlines
()):
currentLine
= line
.strip
().split
(',')
dataArr
.append
([int(num
) for num
in currentLine
[1:]])
labelArr
.append
(int(currentLine
[0]))
return dataArr
, labelArr
def calDist(x1
,x2
):
"""
计算欧式距离
:param x1: 向量1
:param x2: 向量2
:return: 欧式距离
"""
return np
.sqrt
(np
.sum(np
.square
(x1
- x2
)))
def getClosest(trainDataMat
, trainLabelMat
, x
, topK
):
"""
预测x的标记
多数表决
:param trainDataMat:训练数据集
:param trainLabelMat: 训练数据标签
:param x: 预测样本x
:param topK: 选择参考最邻近样本的数目
:return: 预测的标记
"""
distDict
= {}
for i
in tqdm
(range(len(trainDataMat
))):
xi
= trainDataMat
[i
]
curDist
= calDist
(x
,xi
)
distDict
[i
] = curDist
dist_list_topK
= sorted(distDict
.items
(),key
= lambda x
:x
[1],reverse
=False)[:topK
]
dist_dict_topk
= dict(dist_list_topK
)
labelLict
= [0] * 10
for index
in dist_dict_topk
:
labelLict
[int(trainLabelMat
[index
])] += 1
return np
.argsort
(np
.array
(labelLict
))[-1]
def model_test(trainData
,trainLabel
,testData
,testLabel
,topK
):
"""
测试正确率
:param trainData:训练数据集
:param trainLabel: 训练标签
:param testData: 测试数据集
:param testLabel: 测试标签
:param topK: 选择多少个临近点参考
:return: 正确率
"""
print('start to test')
trainDataMat
= np
.mat
(trainData
)
trainLabelMat
= np
.mat
(trainLabel
).T
testDataMat
= np
.mat
(testData
)
testLabelMat
= np
.mat
(testLabel
).T
errorCnt
= 0
for i
in range(200):
print('test %d : %d'%(i
,200))
x
= testDataMat
[i
]
y
= getClosest
(trainDataMat
,trainLabelMat
,x
,topK
)
if y
!= testLabelMat
[i
]:errorCnt
+= 1
return 1 - errorCnt
/ 200
if __name__
== '__main__':
start
= time
.time
()
trainData
, trainLabel
= loaddata
('data/mnist_train.csv')
testData
, testLabel
= loaddata
('data/mnist_test.csv')
accur
= model_test
(trainData
,trainLabel
,testData
,testLabel
,25)
print('accur : %d'%(accur
*100),'%')
end
= time
.time
()
print('time span:',end
-start
)
Output:

```
accur : 97 %
time span: 307.68515515327454
```
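Much of the reported run time likely comes from the Python-level loop in `getClosest`, which calls `calDist` once per training sample for every test sample. As a hedged sketch (not the author's code), the distances from one query to all training samples can be computed in a single NumPy expression. `getClosestVectorized` is a hypothetical name. The sketch assumes `trainDataMat` is a 2-D array of shape (n_samples, n_features) and `trainLabelMat` holds integer labels 0-9, as with MNIST.

```python
import numpy as np


def getClosestVectorized(trainDataMat, trainLabelMat, x, topK):
    """Majority vote among the topK nearest training samples (vectorized sketch)."""
    trainDataMat = np.asarray(trainDataMat, dtype=np.float64)
    x = np.asarray(x, dtype=np.float64).ravel()
    # Broadcasting subtracts x from every row at once; dists has shape (n_samples,).
    dists = np.sqrt(((trainDataMat - x) ** 2).sum(axis=1))
    # argpartition places the topK smallest distances first, without a full sort.
    topKIndex = np.argpartition(dists, topK - 1)[:topK]
    labels = np.asarray(trainLabelMat).ravel()[topKIndex].astype(int)
    # bincount counts the votes per digit; argmax picks the most frequent label.
    return int(np.bincount(labels, minlength=10).argmax())


# Hypothetical usage on random MNIST-shaped data.
rng = np.random.default_rng(0)
train = rng.integers(0, 256, size=(1000, 784))
labels = rng.integers(0, 10, size=1000)
# With topK=1 the nearest sample to train[0] is itself, so this prints labels[0].
print(getClosestVectorized(train, labels, train[0], topK=1))
```

This should be noticeably faster than the per-sample loop, but it has not been timed here. Ties at the topK boundary may be broken differently from the sorted-list approach.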