【machine learning】KNN算法
适逢学习机器学习基础知识,就将书中内容读读记记,本博文代码参考书本Machine Learning in Action(《机器学习实战》)。
一、概述
kNN算法又称为k近邻分类(k-nearest neighbor classification)算法。
kNN算法则是从训练集中找到和新数据最接近的k条记录,然后根据他们的主要分类来决定新数据的类别。该算法涉及3个主要因素:训练集、距离或相似的衡量、k的大小。
二、算法要点
1、指导思想
kNN算法的指导思想是“近朱者赤,近墨者黑”,由你的邻居来推断出你的类别。
计算步骤如下:
1)算距离:给定测试对象,计算它与训练集中的每个对象的距离
2)找邻居:圈定距离最近的k个训练对象,作为测试对象的近邻
3)做分类:根据这k个近邻归属的主要类别,来对测试对象分类
2、距离或相似度的衡量
什么是合适的距离衡量?距离越近应该意味着这两个点属于一个分类的可能性越大。
距离衡量包括欧式距离、夹角余弦等。
对于文本分类来说,使用余弦(cosine)来计算相似度就比欧式(Euclidean)距离更合适。
3、类别的判定
投票决定:少数服从多数,近邻中哪个类别的点最多就分为该类,属于以频率为标准。
加权投票法:根据距离的远近,对近邻的投票进行加权,距离越近则权重越大(权重为距离平方的倒数),属于以量化为标准。
三、优缺点
1、优点
简单,易于理解,易于实现,无需估计参数,无需训练
适合对稀有事件进行分类(例如当流失率很低时,比如低于0.5%,构造流失预测模型)
特别适合于多分类问题(multi-modal,对象具有多个类别标签),例如根据基因特征来判断其功能分类,kNN比SVM的表现要好
2、缺点
懒惰算法,对测试样本分类时的计算量大,内存开销大,评分慢
可解释性较差,无法给出决策树那样的规则。
四、利用KNN进行手写识别
假如存在训练数据,都是二值得灰度图,来源于手写面板的采集图像数据。如下表示数字‘0’,所在文件夹下包括表示0~9的文件,文件夹命名A_B.txt,A表示真实数字,B表示该数字的第B个样本(一般数据越多有有利于接近预测值)
在另一个文件夹中,也存在同样命名的数据文件,用于检验有监督学习下的准确率,我们称为测试数据。
在代码中,我们需要三个函数def classify0(inX, dataSet, labels, k)——用于对输入单个样本inX进行分类,dataSet为训练数据,labels为训练数据的类别,K为近邻范围
def img2vector(filename)——将文件filename中的数据规格由32X32转换为1X1024的向量
def handwritingClassTest()——利用测试数据进行测试,得出错误率
我这里只用了0~9分别20个训练数据而已,提高速度。需要源代码可以到机器学习实战的配套代码中取http://vdisk.weibo.com/s/uEZesAafcjQgx?sudaref=www.baidu.com
代码中用到了numpy库,numpy库用在数据量大的计算较高效
numpy用法小抄:
>>> tile([0, 0], (1, 2))
array([[0, 0, 0, 0]])
>>> tile([0, 0], (2, 1))
array([[0, 0],
[0, 0]])
第一个是矩阵A
第二个参数是要 只有一个数字时,表示 对 A中元素重复的次数
两个参数时(x, y) y表示对A中元素重复的次数, x表示 对前面的操作执行x次。
>>> b= np.arange(12).reshape(3,4)
>>> b
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> b.sum(axis=0) # 计算每一列的和,注意理解轴的含义,参考数组的第一篇文章
array([12, 15, 18, 21])
>>> b.min(axis=1) # 获取每一行的最小值
array([0, 4, 8])
>>> b.cumsum(axis=1) # 计算每一行的累积和
array([[ 0, 1, 3, 6],
[ 4, 9, 15, 22],
[ 8, 17, 27, 38]])
KNN.py
#! /usr/bin/env python #coding=utf-8 from numpy import * import operator from os import listdir def classify0(inX, dataSet, labels, k): #inX------[x,x,x,x] #dataSet------array([[x,x,x,x],[x,x,x,x]]) #labels------[x,x] #k------n dataSetSize = dataSet.shape[0] diffMat = tile(inX, (dataSetSize,1)) - dataSet sqDiffMat = diffMat**2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances**0.5 sortedDistIndicies = distances.argsort() classCount={} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #将字典按value值大小降序排序,结果为二维列表 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0] def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename,'r') for i in range(32): lineStr = fr.readline() for j in range(32): returnVect[0,32*i+j] = int(lineStr[j]) return returnVect trainFile = 'F:\\python\\pyproject\\ML\\codes\\machinelearninginaction\\Ch02\\training20\\' testFile = 'F:\\python\\pyproject\\ML\\codes\\machinelearninginaction\\Ch02\\testDigits\\' def handwritingClassTest(): hwLabels = [] trainingFileList = listdir(trainFile) #load the training set m = len(trainingFileList) trainingMat = zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] #take off .txt classNumStr = int(fileStr.split('_')[0]) hwLabels.append(classNumStr) path = trainFile + '%s' trainingMat[i,:] = img2vector(path % fileNameStr) testFileList = listdir(testFile) #iterate through the test set errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] #take off .txt classNumStr = int(fileStr.split('_')[0]) path = testFile + '%s' vectorUnderTest = img2vector(path % fileNameStr) classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr) if (classifierResult != classNumStr): errorCount += 1.0 print "\nthe total number of errors is: %d" % errorCount print "\nthe total error rate is: %f" % (errorCount/float(mTest))
再在test.py中调用KNN.handwritingClassTest(),则程序开始运行
test.py
#! /usr/bin/env python #coding=utf-8 import KNN KNN.handwritingClassTest()
郑重声明:本站内容如果来自互联网及其他传播媒体,其版权均属原媒体及文章作者所有。转载目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,也不构成任何其他建议。