鱼C论坛

 找回密码
 立即注册
查看: 1236|回复: 0

[技术交流] k-近邻算法识别手写数字

[复制链接]
发表于 2018-4-10 15:14:11 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
首先介绍k-近邻算法的思想及原理:

优点:精度高,对异常值不敏感,无数据输入假定
缺点:计算复杂度高,空间复杂度高
适用数据范围:数值型和标称型

工作原理:
存在一个样本数据集合,并且样本集中都存在标签,即已知每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,只选择样本数据集中的前k个最相似的数据,这就是k的出处。

一般流程:
(1)收集数据:可以使用任意方法
(2)准备数据:距离计算所需的数据,最好是结构化的数据格式
(3)分析数据:可以使用任何方法
(4)训练算法:此步骤不适用k-近邻算法
(5)测试算法:计算错误率
(6)使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判断输入数分别属于哪个分类,最后应用对计算出的分类执行后续的处理。

核心代码:
  1. #构建k-近邻算法分类器
  2. def classify0(inX,dataSet,labels,k):   #用于分类的输入向量inX,输入的训练样本集dataSet,标签向量labels,用于选择最近邻居数目的k值
  3.         dataSetSize = dataSet.shape[0]   #利用shape函数读取dataSet第一维度长度
  4.         diffMat = np.tile(inX,(dataSetSize,1)) - dataSet        #利用欧氏距离公式计算两向量点之间的距离
  5.         sqDiffMat = diffMat ** 2
  6.         sqDistance = sqDiffMat.sum(axis=1)          #将矩阵的每一行向量相加求和
  7.         distances = sqDistance**0.5
  8.         sortedDistIndicies = distances.argsort()        #返回数组值从小到大的索引值
  9.         classCount = {}                #建立空字典
  10.         for i in range(k):                #选择距离最小的k个点
  11.                 voteIlabel = labels[sortedDistIndicies[i]]
  12.                 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1        #将classCount字典分解成元组列表
  13.         sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)        #导入运算符模块的itemgetter方法,按照第二个元素的次序对元组进行排序
  14.         return sortedClassCount[0][0]
复制代码


手写识别系统完整代码:
  1. #准备数据,将图像转换为测试向量
  2. #识别手写数字系统,将数字图片统一处理,原先的32*32二进制图像矩阵转换为1*1024的向量,首先利用函数img2vector就将图像转换为向量
  3. import numpy as np
  4. from os import listdir
  5. import operator

  6. def img2vector(filename):
  7.     returnVect = np.zeros((1,1024))  #创建1*1024的numpy矩阵
  8.     fr = open(filename)     #打开指定的文件
  9.     for i in range(32):   #循环读出前32行
  10.         lineStr = fr.readline()
  11.         for j in range(32):
  12.             returnVect[0,32*i+j] = int(lineStr[j])  #将每行的头32个字符值存储到numpy数组中
  13.     return returnVect   #返回数组

  14. #构建k-近邻算法分类器
  15. def classify0(inX,dataSet,labels,k):   #用于分类的输入向量inX,输入的训练样本集dataSet,标签向量labels,用于选择最近邻居数目的k值
  16.         dataSetSize = dataSet.shape[0]   #利用shape函数读取dataSet第一维度长度
  17.         diffMat = np.tile(inX,(dataSetSize,1)) - dataSet        #利用欧氏距离公式计算两向量点之间的距离
  18.         sqDiffMat = diffMat ** 2
  19.         sqDistance = sqDiffMat.sum(axis=1)          #将矩阵的每一行向量相加求和
  20.         distances = sqDistance**0.5
  21.         sortedDistIndicies = distances.argsort()        #返回数组值从小到大的索引值
  22.         classCount = {}                #建立空字典
  23.         for i in range(k):                #选择距离最小的k个点
  24.                 voteIlabel = labels[sortedDistIndicies[i]]
  25.                 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1        #将classCount字典分解成元组列表
  26.         sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)        #导入运算符模块的itemgetter方法,按照第二个元素的次序对元组进行排序
  27.         return sortedClassCount[0][0]

  28. #测试算法,从os模块导入listdir,可以列出给定目录的文件名
  29. def handwritingClassTest():
  30.     hwLabels = []
  31.     trainingFileList = listdir('trainingDigits')    #获取目录内容
  32.     m = len(trainingFileList)
  33.     trainingMat = np.zeros((m,1024))    #创建一个m*1024的训练矩阵
  34.     for i in range(m):
  35.         fileNameStr = trainingFileList[i]
  36.         fileStr = fileNameStr.split('.')[0]
  37.         classNumStr = int(fileStr.split('_')[0])    #从文件名解析分类数字
  38.         hwLabels.append(classNumStr)
  39.         trainingMat[1,:] = img2vector('trainingDigits/%s' % fileNameStr)
  40.     testFileList = listdir('testDigits')
  41.     errorCount = 0.0
  42.     mTest = len(testFileList)
  43.     for i in range(mTest):
  44.         fileNameStr = testFileList[i]
  45.         fileStr = fileNameStr.split('.')[0]
  46.         classNumStr =int(fileStr.split('_')[0])
  47.         vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
  48.         classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)    #测试函数
  49.         print("the classifier came back with:%d,the real answer is :%d"% (classifierResult,classNumStr))
  50.         if (classifierResult != classNumStr):
  51.             errorCount += 1.0
  52.         print("\nthe total number of error is :%d" %errorCount)
  53.         print("\nthe total error rate is : %f" % (errorCount/float(mTest)))
复制代码


运行程序,命令框内执行:handwritingClassTest()
结果如下:
  1. the total number of error is :10

  2. the total error rate is : 0.010571
  3. the classifier came back with:9,the real answer is :9

  4. the total number of error is :10

  5. the total error rate is : 0.010571
  6. the classifier came back with:9,the real answer is :9

  7. the total number of error is :10

  8. the total error rate is : 0.010571
复制代码

错误率为1.06%。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-4-16 15:22

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表